| | |
| | |
| |
|
| | import argparse |
| | from pathlib import Path |
| |
|
| | import torch |
| | import torch.distributed.checkpoint as DCP |
| | from transformers import AutoModelForCausalLM |
| |
|
| | import fla |
| | from torchtitan.tools.logging import init_logger, logger |
| |
|
| |
|
@torch.inference_mode()
def convert_hf_weights(model: str, checkpoint: "str | Path"):
    """Export a Hugging Face causal-LM's weights as a DCP checkpoint.

    The model is loaded with ``AutoModelForCausalLM.from_pretrained`` and its
    ``state_dict()`` is saved under the top-level key ``"model"`` with
    ``torch.distributed.checkpoint.save``.

    Args:
        model: Identifier passed verbatim to
            ``AutoModelForCausalLM.from_pretrained``.
        checkpoint: Output directory for the DCP checkpoint. Accepts either a
            ``str`` or a ``Path``; the directory (and any missing parents) is
            created if it does not exist.
    """
    # The original annotated ``checkpoint`` as ``str`` but called ``.mkdir`` on
    # it, which only works for ``Path``. Normalising here supports both.
    checkpoint_dir = Path(checkpoint)

    logger.info(f"Loading model from {model}")
    # NOTE(review): the file-level ``import fla`` presumably registers the
    # model classes that ``from_pretrained`` resolves here -- confirm before
    # removing that import.
    hf_model = AutoModelForCausalLM.from_pretrained(model)
    state_dict = hf_model.state_dict()

    logger.info(f"Writing to DCP at '{checkpoint_dir}'")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(checkpoint_dir, thread_count=8)
    DCP.save({"model": state_dict}, storage_writer=storage_writer)
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: set up logging, read the two required CLI options,
    # then run the conversion.
    init_logger()

    cli = argparse.ArgumentParser(
        description="Convert huggingface-style model weights to DCP format.",
    )
    # Both options are mandatory; ``--checkpoint`` is parsed into a ``Path``.
    for flag, value_type in (("--model", str), ("--checkpoint", Path)):
        cli.add_argument(flag, type=value_type, required=True)
    options = cli.parse_args()

    convert_hf_weights(options.model, options.checkpoint)
| |
|