| import argparse | |
| import os | |
| from huggingface_hub import HfApi, HfFolder, snapshot_download | |
def main(args):
    """Download a Hugging Face Hub repository snapshot into an experiment folder.

    The snapshot is written to ``<args.experiment_checkpoint_folder>/checkpoint``.
    That directory is created first if it does not already exist.

    Args:
        args: Parsed command-line namespace providing ``repo_id`` (the Hub
            repository ID) and ``experiment_checkpoint_folder`` (the local base
            directory under which ``checkpoint`` is created).

    Returns:
        Whatever ``snapshot_download`` returns. Per the huggingface_hub docs,
        this is the local folder path of the downloaded snapshot.
    """
    # Token stored locally / via environment; may be None when the user is not
    # logged in, in which case presumably only public repos are reachable.
    token = HfFolder.get_token()

    checkpoint_dir = os.path.join(args.experiment_checkpoint_folder, "checkpoint")
    # exist_ok=True makes re-running the script into the same folder safe.
    os.makedirs(checkpoint_dir, exist_ok=True)

    return snapshot_download(
        repo_id=args.repo_id,
        token=token,
        local_dir=checkpoint_dir,
    )
if __name__ == "__main__":
    # Command-line entry point: both options are mandatory string arguments,
    # registered in the order listed below.
    cli = argparse.ArgumentParser(
        description="Download a checkpoint from Hugging Face Hub."
    )
    cli_options = (
        ("--repo_id", "The repository ID on Hugging Face Hub."),
        (
            "--experiment_checkpoint_folder",
            "The local directory to save the downloaded checkpoint.",
        ),
    )
    for flag, help_text in cli_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    main(cli.parse_args())