#!/usr/bin/env python # finetune_whisper_mix_datasets.py """ Fine-tune openai/whisper-large-v3 on mixed datasets from different languages: - FLEURS Cebuano (ceb_ph) - FLEURS Khmer (km_kh) - Switchboard1 English - WenetSpeech Chinese - Eng-Indon-CS - Eng-Malay-CS Based on the Hugging Face blog: https://huggingface.co/blog/fine-tune-whisper To run this script on multiple GPUs, you have several options: 1. **Automatic Multi-GPU (DataParallel-style):** python finetune_whisper_mix_datasets.py The script will automatically detect and use all available GPUs. 2. **Distributed Training with torchrun (Recommended for 2+ GPUs):** torchrun --nproc_per_node=2 finetune_whisper_mix_datasets.py This uses DistributedDataParallel which is more efficient. 3. **Distributed Training with accelerate (Alternative):** accelerate launch --num_processes=2 finetune_whisper_mix_datasets.py Requires: pip install accelerate Note: With 2 GPUs, the effective batch size becomes: per_device_batch_size * num_gpus * gradient_accumulation_steps = 24 * 2 * 1 = 48 (compared to 32 with single GPU) CPU Core Limiting: The script automatically limits CPU usage to 20 cores using environment variables. You can also set these manually before running: export OMP_NUM_THREADS=20 export MKL_NUM_THREADS=20 export NUMEXPR_NUM_THREADS=20 python finetune_whisper_mix_datasets.py """ import os import random import io import yaml import argparse from itertools import chain import torch.distributed as dist # Load configuration from YAML file def load_config(config_path): with open(config_path, 'r') as file: return yaml.safe_load(file) # Parse command line arguments parser = argparse.ArgumentParser(description='Fine-tune Whisper on mixed datasets') parser.add_argument('--config', type=str, default='config.yaml', help='Path to configuration YAML file') args = parser.parse_args() # Load configuration config = load_config(args.config) # Set environment variables from config env_config = config['environment'] os.environ["OMP_NUM_THREADS"] = env_config['omp_num_threads'] os.environ["MKL_NUM_THREADS"] = env_config['mkl_num_threads'] os.environ["OPENBLAS_NUM_THREADS"] = env_config['openblas_num_threads'] os.environ["VECLIB_MAXIMUM_THREADS"] = env_config['veclib_maximum_threads'] os.environ["NUMEXPR_NUM_THREADS"] = env_config['numexpr_num_threads'] os.environ["TOKENIZERS_PARALLELISM"] = env_config['tokenizers_parallelism'] os.environ["TRANSFORMERS_NO_TF"] = env_config['transformers_no_tf'] import torch from datasets import load_dataset, Audio, concatenate_datasets, Dataset from torch.utils.data import Dataset as TorchDataset from transformers import ( WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer, ) import ipdb import evaluate import numpy as np import ipdb # Multi-GPU setup if torch.cuda.device_count() > 1: print(f"Setting up for {torch.cuda.device_count()} GPUs") # Enable distributed training environment variables if not already set if "LOCAL_RANK" not in os.environ: os.environ["LOCAL_RANK"] = "0" if "WORLD_SIZE" not in os.environ: os.environ["WORLD_SIZE"] = str(torch.cuda.device_count()) from dataclasses import dataclass from typing import Any, Dict, List, Union class WhisperOnTheFlyDataset(TorchDataset): """Custom dataset that preprocesses audio on-the-fly during training""" def __init__(self, dataset, processors, main_processor, max_target_length, audio_config): self.dataset = dataset self.processors = processors self.main_processor = main_processor self.max_target_length = max_target_length self.sampling_rate = audio_config['sampling_rate'] def __len__(self): return len(self.dataset) def __getitem__(self, idx): item = self.dataset[idx] # Process audio audio_sample = item["audio"] audio_data = self._process_audio(audio_sample) # Extract with main processor inputs = self.main_processor.feature_extractor( audio_data, sampling_rate=self.sampling_rate, return_tensors="pt" ) # Process text with appropriate processor lang = item["language"] if lang in ["cebuano", "khmer"]: text = item["transcription"] else: # english, chinese text = item["text"] # Map language to Whisper language token ID lang_id_map = { "english": 50259, # <|en|> "chinese": 50260, # <|zh|> "indonesian": 50275, # <|id|> "malay": 50282, # <|ms|> "khmer": 50323, # <|km|> "cebuano": 50348, # <|tl|> (using Tagalog as fallback for Cebuano) } # Get language token ID lang_token_id = lang_id_map.get(lang) # Tokenize with appropriate processor if lang == "cebuano": labels = self.processors["cebuano"].tokenizer( text, return_tensors="pt", padding=False, truncation=True, max_length=self.max_target_length ) elif lang == "khmer": labels = self.processors["khmer"].tokenizer( text, return_tensors="pt", padding=False, truncation=True, max_length=self.max_target_length ) elif lang == "english": labels = self.processors["english"].tokenizer( text, return_tensors="pt", padding=False ) elif lang == "chinese": labels = self.processors["chinese"].tokenizer( text, return_tensors="pt", padding=False ) elif lang == "indonesian": labels = self.processors["indonesian"].tokenizer( text, return_tensors="pt", padding=False ) else: # Malay labels = self.processors["malay"].tokenizer( text, return_tensors="pt", padding=False ) return { "input_features": inputs.input_features.squeeze(0), "labels": labels.input_ids.squeeze(0), "language": lang, "language_token_id": lang_token_id } def _process_audio(self, audio_sample): """Process audio sample into numpy array""" import librosa if isinstance(audio_sample, dict): if "array" in audio_sample: return audio_sample["array"] elif "bytes" in audio_sample and audio_sample["bytes"] is not None: audio_array, _ = librosa.load(io.BytesIO(audio_sample["bytes"]), sr=self.sampling_rate) return audio_array elif "path" in audio_sample: audio_array, _ = librosa.load(audio_sample["path"], sr=self.sampling_rate) return audio_array else: raise ValueError(f"Unknown audio dict format: {audio_sample.keys()}") elif isinstance(audio_sample, str): audio_array, _ = librosa.load(audio_sample, sr=self.sampling_rate) return audio_array else: return audio_sample @dataclass class DataCollatorSpeechSeq2SeqWithPadding: processor: Any decoder_start_token_id: int def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: # split inputs and labels since they have to be of different lengths and need different padding methods # first treat the audio inputs by simply returning torch tensors input_features = [{"input_features": feature["input_features"]} for feature in features] batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") # get the tokenized label sequences label_features = [{"input_ids": feature["labels"]} for feature in features] # pad the labels to max length labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") # Get original labels before modification labels = labels_batch["input_ids"] # Task ID is fixed for transcription (50360) task_token_id = 50360 # transcribe task # Create a tensor to store new labels with language and task tokens prepended batch_size = labels.size(0) seq_length = labels.size(1) # Add 2 tokens (lang token and task token) at the beginning new_labels = torch.full((batch_size, seq_length + 2), self.processor.tokenizer.pad_token_id, dtype=labels.dtype, device=labels.device) # Add the language token and task token at the beginning for each sample for i, feature in enumerate(features): # Add SOT token as first token (50258) # new_labels[i, 0] = 50258 # SOT token # Add language token as second token if available if "language_token_id" in feature and feature["language_token_id"] is not None: new_labels[i, 0] = feature["language_token_id"] # Add task token as third token new_labels[i, 1] = task_token_id # Copy the original label tokens after the special tokens token_length = min(seq_length, labels.size(1)) new_labels[i, 2:2+token_length] = labels[i, :token_length] # Create new attention mask for padded sequences new_attention_mask = torch.zeros_like(new_labels, dtype=torch.long) for i in range(batch_size): # Find the last non-padding token in the original sequence orig_seq_len = (labels[i] != self.processor.tokenizer.pad_token_id).sum().item() # Set attention mask to 1 for all tokens up to the end of the sequence + 2 special tokens new_attention_mask[i, :orig_seq_len+2] = 1 # Replace padding with -100 to ignore loss correctly new_labels = new_labels.masked_fill(new_attention_mask.ne(1), -100) # if bos token is appended in previous tokenization step, # cut bos token here as it's append later anyways if (new_labels[:, 0] == self.decoder_start_token_id).all().cpu().item(): new_labels = new_labels[:, 1:] batch["labels"] = new_labels batch["attention_mask"] = new_attention_mask return batch # → Choose device (GPU if available) device = "cuda" if torch.cuda.is_available() else "cpu" # Extract configuration values MODEL_CHECKPOINT = config['model']['checkpoint'] OUTPUT_DIR = config['output']['output_dir'] MAX_TARGET_LENGTH = config['model']['max_target_length'] # CPU usage configuration for dataset preprocessing MAX_CPU_CORES = config['environment']['max_cpu_cores'] TEST_CPU_CORES = config['environment']['test_cpu_cores'] # Language configurations for each dataset DATASET_CONFIGS = config['languages'] print("Loading datasets...") # Load datasets for each language dynamically based on configuration datasets = {} dataset_configs = config['datasets'] audio_config = config['audio'] # Get list of enabled languages from both languages and datasets config enabled_languages = set(config['languages'].keys()) & set(config['datasets'].keys()) print(f"Enabled languages: {list(enabled_languages)}") def load_fleurs_dataset(lang_name, lang_config, dataset_config): """Load FLEURS dataset for a language""" print(f"Loading FLEURS {lang_name.title()}...") lang_datasets = load_dataset( dataset_config['source'], dataset_config['language_code'], split={k: v for k, v in dataset_config['splits'].items()}, trust_remote_code=dataset_config['trust_remote_code'] ) # DON'T decode audio yet - keep it compressed until preprocessing for split in dataset_config['splits'].keys(): lang_datasets[split] = lang_datasets[split].cast_column("audio", Audio(sampling_rate=audio_config['sampling_rate'], decode=False)) # Use subset of training data if specified if 'train_subset_ratio' in lang_config: train_subset_ratio = lang_config['train_subset_ratio'] lang_datasets["train"] = lang_datasets["train"].train_test_split(test_size=1-train_subset_ratio, seed=config['data_processing']['seed'])["train"] return lang_datasets def load_simple_dataset(lang_name, dataset_config): """Load simple dataset with train/validation/test splits""" print(f"Loading {lang_name.title()}...") lang_dataset = load_dataset(dataset_config['source'], split={k: v for k, v in dataset_config['splits'].items()}) return lang_dataset def load_english_dataset(lang_config, dataset_config): """Load English dataset with custom train/validation split""" print("Loading English...") swb_train = load_dataset(dataset_config['train_dataset'], split=dataset_config['train_split'], streaming=dataset_config['streaming']) swb_test = load_dataset(dataset_config['test_dataset'], split=dataset_config['test_split'], streaming=dataset_config['streaming']) # Split into train/validation validation_size = lang_config['validation_size'] swb_val = swb_train.take(validation_size) swb_train = swb_train.skip(validation_size) return { "train": swb_train, "validation": swb_val, "test": swb_test } def load_chinese_dataset(dataset_config): """Load Chinese dataset with multiple test splits""" print("Loading Chinese...") wenet_train = load_dataset(dataset_config['train_dataset'], streaming=dataset_config['streaming']) wenet_valid = load_dataset(dataset_config['validation_dataset'], dataset_config['validation_config'], split="validation", streaming=dataset_config['streaming'], trust_remote_code=dataset_config['trust_remote_code']) wenet_testnet = load_dataset(dataset_config['test_net_dataset'], dataset_config['test_net_config'], split="test", streaming=dataset_config['streaming'], trust_remote_code=dataset_config['trust_remote_code']) wenet_testmeeting = load_dataset(dataset_config['test_meeting_dataset'], dataset_config['test_meeting_config'], split="test", streaming=dataset_config['streaming'], trust_remote_code=dataset_config['trust_remote_code']) return { "train": wenet_train["train"], "validation": wenet_valid, "test_net": wenet_testnet, "test_meeting": wenet_testmeeting } # Load datasets for each enabled language for lang in enabled_languages: lang_config = config['languages'][lang] dataset_config = dataset_configs[lang] if lang in ['cebuano', 'khmer']: # FLEURS datasets datasets[lang] = load_fleurs_dataset(lang, lang_config, dataset_config) elif lang == 'english': # English with custom validation split datasets[lang] = load_english_dataset(lang_config, dataset_config) elif lang == 'chinese': # Chinese with multiple test splits datasets[lang] = load_chinese_dataset(dataset_config) elif lang in ['indonesian', 'malay']: # Simple datasets with standard splits datasets[lang] = load_simple_dataset(lang, dataset_config) else: print(f"Warning: Unknown language {lang}, treating as simple dataset") datasets[lang] = load_simple_dataset(lang, dataset_config) print("Setting up processors...") # Create processors for each enabled language processors = {} for lang in enabled_languages: lang_config = config['languages'][lang] processors[lang] = WhisperProcessor.from_pretrained( MODEL_CHECKPOINT, language=lang_config["whisper_language"] ) # Use the first available processor as the main one, preferring English if available if "english" in processors: main_processor = processors["english"] elif processors: main_processor = processors[list(processors.keys())[0]] else: raise ValueError("No processors created. Check your language configuration.") model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT) # Multi-GPU handling local_rank = int(os.environ["LOCAL_RANK"]) torch.cuda.set_device(local_rank) print(f"Using GPU {local_rank} for training") dist.init_process_group(backend="nccl") if torch.cuda.device_count() > 1: print(f"Using {torch.cuda.device_count()} GPUs for training") # The model will be automatically distributed by the Trainer model.to(torch.device("cuda", local_rank)) else: model.to(torch.device("cuda", local_rank)) print("Adding language labels to raw datasets...") # Remove existing language columns and add our own consistent language labels for each enabled language for lang in enabled_languages: lang_datasets = datasets[lang] # Handle different dataset structures if isinstance(lang_datasets, dict): # Dataset with explicit splits (train/validation/test) for split_name, split_dataset in lang_datasets.items(): if split_dataset is not None: # Remove existing language column if it exists columns_to_remove = [col for col in split_dataset.column_names if col.lower() in ["language", "lang"]] if columns_to_remove: print(f"Removing existing language column(s) {columns_to_remove} from {lang} {split_name}") datasets[lang][split_name] = split_dataset.remove_columns(columns_to_remove) # Add our consistent language label datasets[lang][split_name] = datasets[lang][split_name].add_column("language", [lang] * len(datasets[lang][split_name])) else: # Single dataset object - this shouldn't happen with current structure but handle gracefully print(f"Warning: Unexpected dataset structure for {lang}") continue print("Combining raw datasets before preprocessing...") # Ensure all datasets have compatible schemas before concatenation def standardize_dataset_schema(dataset, dataset_name): """Standardize dataset schema to ensure compatibility for concatenation""" print(f"Standardizing schema for {dataset_name}...") # Keep audio compressed until preprocessing - only set sampling rate if "audio" in dataset.column_names: print(f" Setting audio feature type to {audio_config['sampling_rate']}Hz (compressed) for {dataset_name}") dataset = dataset.cast_column("audio", Audio(sampling_rate=audio_config['sampling_rate'], decode=False)) # Remove problematic columns that might have different types columns_to_remove = [] for col in dataset.column_names: if col in config['data_processing']['columns_to_remove']: columns_to_remove.append(col) if columns_to_remove: print(f" Removing incompatible columns: {columns_to_remove}") dataset = dataset.remove_columns(columns_to_remove) return dataset # Standardize all training datasets dynamically print("Standardizing training datasets...") raw_train_datasets = [] for lang in enabled_languages: if "train" in datasets[lang]: std_dataset = standardize_dataset_schema(datasets[lang]["train"], f"{lang}_train") raw_train_datasets.append(std_dataset) # Standardize validation datasets dynamically print("Standardizing validation datasets...") raw_val_datasets = [] for lang in enabled_languages: if "validation" in datasets[lang]: std_dataset = standardize_dataset_schema(datasets[lang]["validation"], f"{lang}_val") raw_val_datasets.append(std_dataset) # Combine datasets if we have any if raw_train_datasets: print("Combining training datasets...") combined_raw_train = concatenate_datasets(raw_train_datasets) combined_raw_train = combined_raw_train.shuffle(seed=config['data_processing']['seed']) else: raise ValueError("No training datasets found. Check your configuration.") if raw_val_datasets: print("Combining validation datasets...") combined_raw_val = concatenate_datasets(raw_val_datasets) combined_raw_val = combined_raw_val.shuffle(seed=config['data_processing']['seed']) else: print("Warning: No validation datasets found. Training without validation.") combined_raw_val = None print("Creating on-the-fly datasets (no preprocessing stored to disk)...") # Create on-the-fly datasets instead of preprocessing and storing # Create on-the-fly datasets instead of preprocessing and storing combined_train_dataset = WhisperOnTheFlyDataset( combined_raw_train, processors, main_processor, MAX_TARGET_LENGTH, audio_config ) # Only create validation dataset if we have validation data if combined_raw_val is not None: combined_val_dataset = WhisperOnTheFlyDataset( combined_raw_val, processors, main_processor, MAX_TARGET_LENGTH, audio_config ) else: combined_val_dataset = None print("Creating on-the-fly test datasets...") # Create on-the-fly test datasets dynamically processed_datasets = {} for lang in enabled_languages: processed_datasets[lang] = {} # Handle different test split structures for different languages if lang == "chinese": # Chinese has multiple test splits if "test_net" in datasets[lang]: processed_datasets[lang]["test_net"] = WhisperOnTheFlyDataset( datasets[lang]["test_net"], processors, main_processor, MAX_TARGET_LENGTH, audio_config ) if "test_meeting" in datasets[lang]: processed_datasets[lang]["test_meeting"] = WhisperOnTheFlyDataset( datasets[lang]["test_meeting"], processors, main_processor, MAX_TARGET_LENGTH, audio_config ) else: # Standard test split if "test" in datasets[lang]: processed_datasets[lang]["test"] = WhisperOnTheFlyDataset( datasets[lang]["test"], processors, main_processor, MAX_TARGET_LENGTH, audio_config ) # Data Collator data_collator = DataCollatorSpeechSeq2SeqWithPadding( processor=main_processor, decoder_start_token_id=model.config.decoder_start_token_id, ) # Metrics: WER & CER (using Hugging Face Evaluate) wer_metric = evaluate.load("wer") cer_metric = evaluate.load("cer") def compute_metrics(pred): """ Compute WER and CER metrics for predictions """ pred_ids = pred.predictions # Decode predictions, skipping special tokens pred_str = main_processor.batch_decode(pred_ids, skip_special_tokens=True) label_ids = pred.label_ids # Replace -100 with pad token ID for decoding label_ids[label_ids == -100] = main_processor.tokenizer.pad_token_id # Decode reference texts, also skipping special tokens ref_str = main_processor.batch_decode(label_ids, skip_special_tokens=True) # lowercase & strip pred_str = [s.lower().strip() for s in pred_str] ref_str = [s.lower().strip() for s in ref_str] wer_score = wer_metric.compute(predictions=pred_str, references=ref_str) cer_score = cer_metric.compute(predictions=pred_str, references=ref_str) # Combine metrics metrics = {"wer": wer_score, "cer": cer_score} # Log example predictions if len(pred_str) > 0: num_examples = min(3, len(pred_str)) for i in range(num_examples): print(f"Example {i}:") print(f" Reference: {ref_str[i]}") print(f" Prediction: {pred_str[i]}") return metrics # Check for multi-GPU setup num_gpus = torch.cuda.device_count() print(f"Number of available GPUs: {num_gpus}") # Get training configuration training_config = config['training'] # Adjust batch size and gradient accumulation for multi-GPU if num_gpus > 1: # With multiple GPUs, use multi-GPU configuration gpu_config = training_config['multi_gpu'] per_device_batch_size = gpu_config['per_device_train_batch_size'] per_device_eval_batch_size = gpu_config['per_device_eval_batch_size'] gradient_accumulation_steps = gpu_config['gradient_accumulation_steps'] print(f"Multi-GPU training detected. Using {num_gpus} GPUs.") else: # Single GPU configuration gpu_config = training_config['single_gpu'] per_device_batch_size = gpu_config['per_device_train_batch_size'] per_device_eval_batch_size = gpu_config['per_device_eval_batch_size'] gradient_accumulation_steps = gpu_config['gradient_accumulation_steps'] print("Single GPU training.") # Training Arguments training_args = Seq2SeqTrainingArguments( output_dir=OUTPUT_DIR, per_device_train_batch_size=per_device_batch_size, gradient_accumulation_steps=gradient_accumulation_steps, learning_rate=training_config['learning_rate'], warmup_steps=training_config['warmup_steps'], max_steps=training_config['max_steps'], gradient_checkpointing=training_config['gradient_checkpointing'], fp16=training_config['fp16'], eval_strategy=training_config['eval_strategy'], per_device_eval_batch_size=per_device_eval_batch_size, predict_with_generate=training_config['predict_with_generate'], generation_max_length=training_config['generation_max_length'], save_steps=training_config['save_steps'], eval_steps=training_config['eval_steps'], logging_steps=training_config['logging_steps'], report_to=training_config['report_to'], load_best_model_at_end=training_config['load_best_model_at_end'], metric_for_best_model=training_config['metric_for_best_model'], greater_is_better=training_config['greater_is_better'], push_to_hub=training_config['push_to_hub'], hub_private_repo=training_config['hub_private_repo'], # Always push to private repo save_total_limit=training_config['save_total_limit'], # Multi-GPU specific settings dataloader_drop_last=training_config['dataloader_drop_last'], ddp_find_unused_parameters=training_config['ddp_find_unused_parameters'], ) # Initialize Seq2SeqTrainer trainer = Seq2SeqTrainer( model=model, args=training_args, train_dataset=combined_train_dataset, eval_dataset=combined_val_dataset, data_collator=data_collator, tokenizer=main_processor.feature_extractor, compute_metrics=compute_metrics, ) def evaluate_on_test_sets(): """Evaluate the model on all test sets from enabled languages""" print("\n" + "="*60) print("EVALUATING ON ALL TEST SETS") print("="*60) results = {} # Define language-specific generation parameters lang_id_map = { "english": 50259, # <|en|> "chinese": 50260, # <|zh|> "indonesian": 50275, # <|id|> "malay": 50282, # <|ms|> "khmer": 50323, # <|km|> "cebuano": 50348, # <|tl|> (using Tagalog as fallback for Cebuano) } for lang in enabled_languages: if lang in processed_datasets: lang_results = {} # Set language-specific generation parameters lang_token_id = lang_id_map.get(lang) task_token_id = 50360 # transcribe task # Define forced decoder IDs for generation if language is supported forced_decoder_ids = None if lang_token_id: forced_decoder_ids = [[1, lang_token_id], [2, task_token_id]] print(f"Using forced_decoder_ids for {lang}: {forced_decoder_ids}") if lang == "chinese": # Chinese has multiple test splits if "test_net" in processed_datasets[lang]: print(f"\n***** Evaluating on WenetSpeech Chinese TEST_NET *****") chi_testnet_metrics = trainer.predict( processed_datasets[lang]["test_net"], metric_key_prefix=f"test_{lang}_net", forced_decoder_ids=forced_decoder_ids ) print(f"Chinese TEST_NET WER: {chi_testnet_metrics.metrics[f'test_{lang}_net_wer']*100:.2f}%") print(f"Chinese TEST_NET CER: {chi_testnet_metrics.metrics[f'test_{lang}_net_cer']*100:.2f}%") lang_results["test_net"] = chi_testnet_metrics.metrics if "test_meeting" in processed_datasets[lang]: print(f"\n***** Evaluating on WenetSpeech Chinese TEST_MEETING *****") chi_testmeet_metrics = trainer.predict( processed_datasets[lang]["test_meeting"], metric_key_prefix=f"test_{lang}_meeting", forced_decoder_ids=forced_decoder_ids ) print(f"Chinese TEST_MEETING WER: {chi_testmeet_metrics.metrics[f'test_{lang}_meeting_wer']*100:.2f}%") print(f"Chinese TEST_MEETING CER: {chi_testmeet_metrics.metrics[f'test_{lang}_meeting_cer']*100:.2f}%") lang_results["test_meeting"] = chi_testmeet_metrics.metrics else: # Standard test split if "test" in processed_datasets[lang]: print(f"\n***** Evaluating on {lang.title()} test set *****") test_metrics = trainer.predict( processed_datasets[lang]["test"], metric_key_prefix=f"test_{lang}", forced_decoder_ids=forced_decoder_ids ) print(f"{lang.title()} Test WER: {test_metrics.metrics[f'test_{lang}_wer']*100:.2f}%") print(f"{lang.title()} Test CER: {test_metrics.metrics[f'test_{lang}_cer']*100:.2f}%") lang_results["test"] = test_metrics.metrics results[lang] = lang_results # Summary print("\n" + "="*60) print("SUMMARY OF ALL TEST RESULTS") print("="*60) for lang in enabled_languages: if lang in results: if lang == "chinese": if "test_net" in results[lang]: wer = results[lang]["test_net"][f"test_{lang}_net_wer"] * 100 cer = results[lang]["test_net"][f"test_{lang}_net_cer"] * 100 print(f"Chinese-NET: WER={wer:.2f}% | CER={cer:.2f}%") if "test_meeting" in results[lang]: wer = results[lang]["test_meeting"][f"test_{lang}_meeting_wer"] * 100 cer = results[lang]["test_meeting"][f"test_{lang}_meeting_cer"] * 100 print(f"Chinese-MTG: WER={wer:.2f}% | CER={cer:.2f}%") else: if "test" in results[lang]: wer = results[lang]["test"][f"test_{lang}_wer"] * 100 cer = results[lang]["test"][f"test_{lang}_cer"] * 100 print(f"{lang.title():12}: WER={wer:.2f}% | CER={cer:.2f}%") return results if __name__ == "__main__": print(f"Total training samples: {len(combined_train_dataset)}") print(f"Total validation samples: {len(combined_val_dataset)}") print("Starting training...") # Fine-tune the model trainer.train() # Evaluate on all test sets # evaluate_on_test_sets()