Training and fine-tuning example

#13 opened by shimohan

I found that the training example config in NeMo (https://github.com/NVIDIA-NeMo/NeMo/blob/main/examples/speechlm2/salm_train.py, https://github.com/NVIDIA-NeMo/NeMo/blob/main/examples/speechlm2/conf/salm.yaml) is different from the one used for canary-qwen. Even the dataset config does not match. As a result, when I use the NeMo example to fine-tune canary-qwen, the resulting performance is poor. Could you provide a training example that matches this canary-qwen model?

Hi there,
Yes, it seems you're right. Some of the hparams of Canary-Qwen are available in its config.json: https://huggingface.co/nvidia/canary-qwen-2.5b/blob/main/config.json

The actual config used is this:

model:
  # Every name/path here starting with 'pretrained' is used to initialize the model weights.
  pretrained_llm: Qwen/Qwen3-1.7B
  pretrained_asr: nvidia/canary-1b-flash

  pretrained_weights: True  # When False, the pretrained_* names above are used only to build the architecture, with random init

  # Regexp (re.compile) patterns matching parameters to be frozen.
  freeze_params:
    # Frozen LLM
    - "^llm\\..+$"  # LLM
    - "^embed_tokens\\..+$"  # LLM embedding is moved
    # Frozen pretrained ASR (only the modality adapter layers are trainable)
    # - "^perception\\.preprocessor\\..+$"
    # - "^perception\\.encoder\\..+$"
  prevent_freeze_params: []  # Use to make specific submodules trainable; overrides freeze_params

  lora:
    task_type: CAUSAL_LM
    r: 128
    lora_alpha: 256
    lora_dropout: 0.01
    target_modules: ["q_proj", "v_proj"]

  prompt_format: qwen
  audio_locator_tag: "<|audioplaceholder|>"  # placeholder token marking where the audio turn is inserted in the prompt

  perception:
    target: nemo.collections.speechlm2.modules.perception.AudioPerceptionModule
    output_dim: 2048
    modality_adapter:
      _target_: nemo.collections.speechlm2.modules.perception.IdentityConnector
      d_model: 1024

#     spec_augment:
#       _target_: nemo.collections.asr.modules.SpectrogramAugmentation
#       freq_masks: 2 # set to zero to disable it
#       time_masks: 10 # set to zero to disable it
#       freq_width: 27
#       time_width: 5  # 5 frames = 50ms

  optimizer:
    _target_: torch.optim.AdamW
    lr: 5e-4
    betas: [0.9, 0.98]
    weight_decay: 1e-3
    foreach: true

  lr_scheduler:
    _target_: nemo.core.optim.lr_scheduler.CosineAnnealing
    warmup_steps: 1000
    min_lr: 1e-6
    max_steps: ${trainer.max_steps}

trainer:
  devices: -1
  accelerator: gpu
  num_nodes: 1
  precision: bf16-true
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  use_distributed_sampler: False
  max_steps: 100000
  limit_train_batches: 5000
  val_check_interval: ${trainer.limit_train_batches}
  limit_val_batches: 10
  log_every_n_steps: 10
  num_sanity_val_steps: 1
  gradient_clip_val: 1.0
  accumulate_grad_batches: 1
  strategy:
    _target_: lightning.pytorch.strategies.DDPStrategy
    gradient_as_bucket_view: true
    find_unused_parameters: true

data:
  train_ds:
    sample_rate: 16000
    prompt_format: ${model.prompt_format}
    token_equivalent_duration: 0.08
    text_field: answer
    lang_field: target_lang
    input_cfg:
      - type: lhotse_as_conversation
        input_cfg: /path/to/input_cfg.yaml
        audio_locator_tag: ${model.audio_locator_tag}
        tags:
          context: 'Transcribe the following:'
    seed: 42
    shuffle: true
    shard_seed: "randomized"
    num_workers: 4

    use_multimodal_sampling: true
    min_duration: 0.1
    min_tokens: 2
    max_tokens: 1024
    bucket_duration_bins: [99,110,117,124,184,247,324,391,457,520,555,579,600,618,638,1024]
    bucket_batch_size: [69, 64, 60, 40, 28, 22, 16, 14, 12, 11, 10, 6, 4, 4, 4, 2]
    use_bucketing: true
    num_buckets: 16
    bucket_buffer_size: 20000

  validation_ds:
    # The entries under 'datasets' are a list of separate dataloaders.
    # The structure is <dataset-name>: {<dataloader-dict-config>}
    # They inherit all settings from validation_ds, but can individually override them.
    prompt_format: ${model.prompt_format}
    token_equivalent_duration: 0.08
    datasets:
      devset_nemo_manifest:
        input_cfg:
        - audio_locator_tag: ${model.audio_locator_tag}
          manifest_filepath: /path/to/devset_nemo_manifest.json
          tags:
            context: 'Transcribe the following:'
          type: lhotse_as_conversation

    sample_rate: 16000
    batch_size: 8
    seed: 42
    shard_seed: "randomized"

exp_manager:
  exp_dir: null
  explicit_log_dir: canary-qwen-2.5b-results/
  name: canary-qwen-2.5b
  create_tensorboard_logger: false
  create_checkpoint_callback: true
  use_datetime_version: true
  max_time_per_run: 00:03:50:00

  resume_from_checkpoint: null  # Path to a checkpoint file to continue training from; restores the whole state, including the epoch, step, LR schedulers, apex, etc.
  # You need to set these two to True to continue training.
  resume_if_exists: true
  resume_ignore_no_checkpoint: true

  # You may use this section to create a W&B logger
  create_wandb_logger: true
  wandb_logger_kwargs:
    name: canary-qwen-2.5b-release
    project: canary-qwen-2.5b
    resume: true

  checkpoint_callback_params:
    filename: "{step}"
    monitor: val_acc
    mode: max
    every_n_train_steps: null
    every_n_epochs: 1
    save_top_k: 1
    always_save_nemo: false

It was run as a series of consecutive 4-hour training jobs until 100k steps were reached (~1 day on 32x A100 80GB). If you do the same, you will need to set a different random seed (train_ds.seed) for each job. You might also want to remove the text_field and lang_field options if you have standard NeMo ASR manifests (this run used custom-formatted data).
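In case it's useful, here is a rough sketch of what the input_cfg.yaml referenced under data.train_ds.input_cfg could look like. It follows NeMo's Lhotse dataloading format for blending sources, but the entry types, paths, and weights below are placeholders I made up for illustration:

# Hypothetical blend of two data sources; adjust types/paths/weights to your data.
- type: nemo_tarred
  manifest_filepath: /data/source_a/tarred_audio_manifest__OP_0..127_CL_.json
  tarred_audio_filepaths: /data/source_a/audio__OP_0..127_CL_.tar
  weight: 0.7
- type: nemo
  manifest_filepath: /data/source_b/manifest.json
  weight: 0.3

The lhotse_as_conversation entry in the main config then wraps each example into a single-turn conversation: as far as I understand, tags.context becomes the user prompt, the audio is inserted at the audio_locator_tag, and the transcript (text_field) becomes the assistant's answer.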

If you wish to fine-tune, you will need to set pretrained_weights: False and add one line of code to manually load the Canary-Qwen pretrained weights before training starts (this capability isn't yet exposed in the YAML config); see the sketch below. Hopefully this helps you get started.
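To make that one line concrete, here is a minimal sketch of what it could look like in examples/speechlm2/salm_train.py, placed after the model is constructed and before trainer.fit(...) runs. Treat it as an illustration under my assumptions (the exact variable names depend on the script), not the exact upstream code:

from nemo.collections.speechlm2.models import SALM

# `model` was built from the YAML above with pretrained_weights: False,
# so it has the right architecture but freshly initialized weights.
# Copy the released Canary-Qwen weights into it before training starts.
pretrained = SALM.from_pretrained("nvidia/canary-qwen-2.5b")
# strict=False tolerates key mismatches (e.g. if your LoRA setup differs from the released one);
# inspect the returned missing/unexpected keys to sanity-check the load.
missing, unexpected = model.load_state_dict(pretrained.state_dict(), strict=False)
print(f"missing keys: {missing}\nunexpected keys: {unexpected}")
del pretrained  # free the duplicate copy before calling trainer.fit(...)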

I'll try to push a complete example/tutorial, but I can't make any promises about the timeline.
