| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| |
|
| | import nemo_run as run |
| |
|
| | from nemo.collections import avlm |
| |
|
| |
|
def configure_recipe(
    nodes: int = 1,
    gpus_per_node: int = 8,
    pretrain=False,
    language_model_from_pretrained=None,
    checkpoint_path=None,
    output_dir=None,
    freeze_modules=None,
    max_steps: int = 20,
    val_check_interval: int = 20,
):
    """Build the AVLM 8B pretraining or finetuning recipe.

    Args:
        nodes: Forwarded to the recipe factory as ``num_nodes``.
        gpus_per_node: Forwarded to the recipe factory as ``num_gpus_per_node``.
        pretrain: When truthy, ``avlm.avlm_8b.pretrain_recipe`` is used;
            otherwise ``avlm.avlm_8b.finetune_recipe`` is used.
        language_model_from_pretrained: Forwarded to the pretrain recipe only;
            ignored when finetuning.
        checkpoint_path: Forwarded to the finetune recipe only; ignored when
            pretraining.
        output_dir: Forwarded to the recipe factory as ``dir``.
        freeze_modules: Forwarded unchanged to the recipe factory.
        max_steps: Value assigned to ``recipe.trainer.max_steps``.
        val_check_interval: Value assigned to
            ``recipe.trainer.val_check_interval``.

    Returns:
        The recipe object returned by the selected factory, with the two
        trainer overrides applied.
    """
    # Keyword arguments that both recipe factories receive.
    shared_kwargs = {
        "dir": output_dir,
        "num_nodes": nodes,
        "num_gpus_per_node": gpus_per_node,
        "freeze_modules": freeze_modules,
    }

    if pretrain:
        recipe = avlm.avlm_8b.pretrain_recipe(
            name="avlm_pretrain",
            language_model_from_pretrained=language_model_from_pretrained,
            **shared_kwargs,
        )
    else:
        recipe = avlm.avlm_8b.finetune_recipe(
            name="avlm_finetune",
            checkpoint_path=checkpoint_path,
            peft_scheme="none",
            **shared_kwargs,
        )

    # Applied to both modes; the defaults (20 / 20) keep the run short.
    recipe.trainer.max_steps = max_steps
    recipe.trainer.val_check_interval = val_check_interval
    return recipe
| |
|
| |
|
def local_executor_torchrun(nodes: int = 1, devices: int = 8) -> run.LocalExecutor:
    """Create a local NeMo-Run executor that launches through torchrun.

    Args:
        nodes: Accepted but not used by this function.
        devices: Passed to the executor as ``ntasks_per_node``.

    Returns:
        A ``run.LocalExecutor`` using the ``"torchrun"`` launcher and an
        empty set of extra environment variables.
    """
    return run.LocalExecutor(
        ntasks_per_node=devices,
        launcher="torchrun",
        env_vars={},  # no additional environment variables are set
    )
| |
|
| |
|
def run_pretraining(language_model_from_pretrained=None, checkpoint_path=None, output_dir=None, freeze_modules=None):
    """Configure the pretraining recipe and run it on a local torchrun executor.

    All arguments are forwarded to ``configure_recipe`` with ``pretrain=True``.
    The executor is sized from the recipe's own ``trainer.num_nodes`` and
    ``trainer.devices``.
    """
    pretrain_recipe = configure_recipe(
        pretrain=True,
        language_model_from_pretrained=language_model_from_pretrained,
        checkpoint_path=checkpoint_path,
        output_dir=output_dir,
        freeze_modules=freeze_modules,
    )

    trainer_cfg = pretrain_recipe.trainer
    torchrun_executor = local_executor_torchrun(
        nodes=trainer_cfg.num_nodes,
        devices=trainer_cfg.devices,
    )
    run.run(pretrain_recipe, executor=torchrun_executor)
| |
|
| |
|
def run_finetuning(checkpoint_path=None, output_dir=None, freeze_modules=None):
    """Configure the finetuning recipe and run it on a local torchrun executor.

    All arguments are forwarded to ``configure_recipe`` with ``pretrain=False``.
    The executor is sized from the recipe's own ``trainer.num_nodes`` and
    ``trainer.devices``.
    """
    finetune_recipe = configure_recipe(
        pretrain=False,
        checkpoint_path=checkpoint_path,
        output_dir=output_dir,
        freeze_modules=freeze_modules,
    )

    trainer_cfg = finetune_recipe.trainer
    torchrun_executor = local_executor_torchrun(
        nodes=trainer_cfg.num_nodes,
        devices=trainer_cfg.devices,
    )
    run.run(finetune_recipe, executor=torchrun_executor)
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Command-line entry point: parse the options, derive the freeze flags and
    # dispatch to the pretraining or finetuning runner.

    # Sub-modules that stay frozen unless the matching --unfreeze_<name> flag
    # is given. The CLI flags and the freeze_modules keys are both derived
    # from this tuple so they cannot drift apart.
    freezable_modules = (
        "language_model",
        "vision_model",
        "audio_model",
        "vision_projection",
        "audio_projection",
    )

    parser = argparse.ArgumentParser(
        description="Run AVLM pretraining or finetuning locally via NeMo-Run and torchrun."
    )
    parser.add_argument(
        "--training_mode",
        type=str,
        required=True,
        choices=["pretrain", "finetune"],
        help="Training mode - either 'pretrain' or 'finetune'",
    )
    parser.add_argument(
        "--language_model_from_pretrained",
        type=str,
        default=None,
        required=False,
        help="Path to pretrained language model (optional).",
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default=None, required=False, help="Path to checkpoint (optional)."
    )
    parser.add_argument(
        "--output_dir", type=str, default="./outputs/checkpoints/avlm", help="Path to store checkpoints (optional)."
    )
    for module_name in freezable_modules:
        # e.g. --unfreeze_language_model / "Unfreeze language model (optional)."
        parser.add_argument(
            f"--unfreeze_{module_name}",
            action="store_true",
            help=f"Unfreeze {module_name.replace('_', ' ')} (optional).",
        )
    args = parser.parse_args()

    # A module is frozen exactly when its --unfreeze_* flag was NOT passed.
    freeze_modules = {
        f"freeze_{module_name}": not getattr(args, f"unfreeze_{module_name}") for module_name in freezable_modules
    }

    if args.training_mode == "pretrain":
        run_pretraining(
            language_model_from_pretrained=args.language_model_from_pretrained,
            checkpoint_path=args.checkpoint_path,
            output_dir=args.output_dir,
            freeze_modules=freeze_modules,
        )
    elif args.training_mode == "finetune":
        run_finetuning(
            checkpoint_path=args.checkpoint_path,
            output_dir=args.output_dir,
            freeze_modules=freeze_modules,
        )
| |
|