File size: 31,717 Bytes
cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 b87d4ca cc1f176 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 |
#!/usr/bin/env python
# finetune_whisper_mix_datasets.py
"""
Fine-tune openai/whisper-large-v3 on mixed datasets from different languages:
- FLEURS Cebuano (ceb_ph)
- FLEURS Khmer (km_kh)
- Switchboard1 English
- WenetSpeech Chinese
- Eng-Indon-CS
- Eng-Malay-CS
Based on the Hugging Face blog: https://huggingface.co/blog/fine-tune-whisper
To run this script on multiple GPUs, you have several options:
1. **Automatic Multi-GPU (DataParallel-style):**
python finetune_whisper_mix_datasets.py
The script will automatically detect and use all available GPUs.
2. **Distributed Training with torchrun (Recommended for 2+ GPUs):**
torchrun --nproc_per_node=2 finetune_whisper_mix_datasets.py
This uses DistributedDataParallel which is more efficient.
3. **Distributed Training with accelerate (Alternative):**
accelerate launch --num_processes=2 finetune_whisper_mix_datasets.py
Requires: pip install accelerate
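Note: With multiple GPUs, the effective batch size becomes:
per_device_batch_size * num_gpus * gradient_accumulation_steps
For example, with a per-device batch size of 24, 2 GPUs and no gradient
accumulation: 24 * 2 * 1 = 48 (compared to 32 with a single GPU). The actual
per-device values are read from the `training.multi_gpu` / `training.single_gpu`
sections of the config file.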
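CPU Core Limiting:
The script limits CPU thread usage by setting OMP_NUM_THREADS, MKL_NUM_THREADS,
OPENBLAS_NUM_THREADS, VECLIB_MAXIMUM_THREADS and NUMEXPR_NUM_THREADS from the
`environment` section of the config file (e.g. 20 threads). The config values
overwrite anything already exported in the shell, so change the config to adjust
them. For reference, the equivalent shell settings are:
export OMP_NUM_THREADS=20
export MKL_NUM_THREADS=20
export NUMEXPR_NUM_THREADS=20
python finetune_whisper_mix_datasets.py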
"""
import os
import random
import io
import yaml
import argparse
from itertools import chain
import torch.distributed as dist
# Load configuration from YAML file
def load_config(config_path):
    """Read the YAML file at ``config_path`` and return its parsed contents.

    ``yaml.safe_load`` is used, so the file cannot construct arbitrary Python
    objects. The file is decoded as UTF-8 explicitly because the default
    ``open()`` encoding depends on the platform locale, and the config may
    contain non-ASCII text (e.g. language names or paths).
    """
    with open(config_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
# Parse command line arguments
parser = argparse.ArgumentParser(description='Fine-tune Whisper on mixed datasets')
parser.add_argument('--config', type=str, default='config.yaml',
                    help='Path to configuration YAML file')
args = parser.parse_args()
# Load configuration
config = load_config(args.config)


def _env_value(value):
    """Convert a YAML scalar into a string that ``os.environ`` accepts.

    ``os.environ`` only stores strings, but YAML parses ``20`` as an int and
    ``false`` as a bool; assigning those directly raises TypeError. Booleans
    are spelled lower-case ("true"/"false"), the conventional form for flags
    such as TOKENIZERS_PARALLELISM.
    """
    if isinstance(value, bool):
        return "true" if value else "false"
    return str(value)


# Set environment variables from config (environment variable -> config key).
# NOTE(review): `torch.distributed` is imported at the top of the file, which
# already loads torch before these thread-count variables are set; settings
# that are read when a library is loaded may therefore not take effect.
# Consider moving that import below this block.
env_config = config['environment']
_ENV_VAR_TO_CONFIG_KEY = {
    "OMP_NUM_THREADS": 'omp_num_threads',
    "MKL_NUM_THREADS": 'mkl_num_threads',
    "OPENBLAS_NUM_THREADS": 'openblas_num_threads',
    "VECLIB_MAXIMUM_THREADS": 'veclib_maximum_threads',
    "NUMEXPR_NUM_THREADS": 'numexpr_num_threads',
    "TOKENIZERS_PARALLELISM": 'tokenizers_parallelism',
    "TRANSFORMERS_NO_TF": 'transformers_no_tf',
}
for _env_name, _config_key in _ENV_VAR_TO_CONFIG_KEY.items():
    os.environ[_env_name] = _env_value(env_config[_config_key])
import torch
from datasets import load_dataset, Audio, concatenate_datasets, Dataset
from torch.utils.data import Dataset as TorchDataset
from transformers import (
WhisperProcessor,
WhisperForConditionalGeneration,
Seq2SeqTrainingArguments,
Seq2SeqTrainer,
)
import ipdb
import evaluate
import numpy as np
import ipdb
# Multi-GPU setup.
# LOCAL_RANK / WORLD_SIZE are deliberately NOT invented here. Launchers such as
# torchrun or `accelerate launch` export them (together with RANK, MASTER_ADDR
# and MASTER_PORT) for every worker process. Setting WORLD_SIZE=<gpu count>
# inside one plain `python` process claims ranks that are never started, so any
# process-group initialisation fails or waits for them. That would break the
# single-process multi-GPU mode described in the module docstring.
if torch.cuda.device_count() > 1:
    print(f"Setting up for {torch.cuda.device_count()} GPUs")
    if "LOCAL_RANK" not in os.environ:
        print(
            "No distributed launcher detected; running as a single process. "
            "Use torchrun or accelerate launch for DistributedDataParallel."
        )
from dataclasses import dataclass
from typing import Any, Dict, List, Union
class WhisperOnTheFlyDataset(TorchDataset):
    """Custom dataset that preprocesses audio on-the-fly during training.

    Wraps an indexable ``dataset`` whose items are mappings with:
      - ``"audio"``: audio in any form accepted by ``_process_audio``;
      - ``"language"``: a key of ``processors`` (e.g. ``"english"``);
      - ``"transcription"`` (cebuano, khmer) or ``"text"`` (all other languages).

    Each ``__getitem__`` call extracts features with ``main_processor`` and
    tokenizes the transcript with the processor registered for the item's
    language, truncating to ``max_target_length`` tokens. Nothing is cached.
    """

    # Whisper language-token ids attached to every item as "language_token_id".
    # NOTE(review): literal ids from the whisper-large-v3 vocabulary; confirm they
    # match MODEL_CHECKPOINT if a different checkpoint is configured.
    LANG_ID_MAP = {
        "english": 50259,  # <|en|>
        "chinese": 50260,  # <|zh|>
        "indonesian": 50275,  # <|id|>
        "malay": 50282,  # <|ms|>
        "khmer": 50323,  # <|km|>
        "cebuano": 50348,  # <|tl|> (using Tagalog as fallback for Cebuano)
    }
    # Languages whose transcript column is "transcription" (FLEURS schema)
    # instead of "text".
    TRANSCRIPTION_COLUMN_LANGS = ("cebuano", "khmer")

    def __init__(self, dataset, processors, main_processor, max_target_length, audio_config):
        self.dataset = dataset
        self.processors = processors
        self.main_processor = main_processor
        self.max_target_length = max_target_length
        self.sampling_rate = audio_config['sampling_rate']

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return input features, label ids and language info for item ``idx``.

        Raises:
            KeyError: if the item's language has no entry in ``processors``.
        """
        item = self.dataset[idx]
        audio_data = self._process_audio(item["audio"])
        # Feature extraction is shared across languages (main processor).
        inputs = self.main_processor.feature_extractor(
            audio_data,
            sampling_rate=self.sampling_rate,
            return_tensors="pt",
        )

        lang = item["language"]
        text_column = "transcription" if lang in self.TRANSCRIPTION_COLUMN_LANGS else "text"
        text = item[text_column]

        if lang not in self.processors:
            raise KeyError(
                f"No processor configured for language {lang!r}; "
                f"available: {sorted(self.processors)}"
            )
        # Truncate for every language (not just cebuano/khmer) so long
        # transcripts cannot exceed the configured target length.
        labels = self.processors[lang].tokenizer(
            text,
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=self.max_target_length,
        )

        return {
            "input_features": inputs.input_features.squeeze(0),
            "labels": labels.input_ids.squeeze(0),
            "language": lang,
            "language_token_id": self.LANG_ID_MAP.get(lang),
        }

    def _process_audio(self, audio_sample):
        """Return a waveform for ``audio_sample`` at ``self.sampling_rate``.

        Accepted forms:
          - dict with ``"array"``: already-decoded audio. It is resampled when the
            dict also carries a ``"sampling_rate"`` that differs from the target.
          - dict with non-None ``"bytes"``: encoded audio (``decode=False`` style),
            decoded and resampled by librosa.
          - dict with ``"path"`` or a plain ``str``: audio file path, loaded by librosa.
          - anything else: returned unchanged (assumed to already be a waveform).

        librosa is imported only on the paths that decode or resample.

        Raises:
            ValueError: for a dict that matches none of the forms above.
        """
        if isinstance(audio_sample, dict):
            if "array" in audio_sample:
                array = audio_sample["array"]
                source_rate = audio_sample.get("sampling_rate")
                if source_rate is not None and source_rate != self.sampling_rate:
                    import librosa
                    array = librosa.resample(
                        np.asarray(array, dtype=np.float32),
                        orig_sr=source_rate,
                        target_sr=self.sampling_rate,
                    )
                return array
            if audio_sample.get("bytes") is not None:
                import librosa
                audio_array, _ = librosa.load(io.BytesIO(audio_sample["bytes"]), sr=self.sampling_rate)
                return audio_array
            if "path" in audio_sample:
                import librosa
                audio_array, _ = librosa.load(audio_sample["path"], sr=self.sampling_rate)
                return audio_array
            raise ValueError(f"Unknown audio dict format: {audio_sample.keys()}")
        if isinstance(audio_sample, str):
            import librosa
            audio_array, _ = librosa.load(audio_sample, sr=self.sampling_rate)
            return audio_array
        return audio_sample
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate ``WhisperOnTheFlyDataset`` items into a Whisper training batch.

    Each feature is a dict with ``input_features``, ``labels`` (1-D token ids
    from a Whisper tokenizer) and, optionally, ``language_token_id``.

    The labels of every sample are rebuilt as::

        [<language token>, <task token>, <tokenizer output after its own prompt>]

    A leading ``decoder_start_token_id`` and the language/task control tokens
    the tokenizer placed right after it are removed first, so the prompt is not
    duplicated. The start token stays out of the labels because it is prepended
    again when labels are shifted right into decoder inputs. If
    ``language_token_id`` is missing or None, the language token found in the
    tokenizer's own prompt is reused. Positions beyond each sample's real length
    are -100 (ignored by the loss); the final end-of-text token is kept.
    """

    processor: Any
    decoder_start_token_id: int
    # <|transcribe|> id; 50360 is the literal this script uses elsewhere
    # (whisper-large-v3 vocabulary).
    task_token_id: int = 50360

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Audio features are padded/stacked by the feature extractor.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Label lengths come from the sequences themselves rather than from a
        # comparison with pad_token_id: Whisper tokenizers use <|endoftext|> as
        # the pad token, so such a comparison also masks every sample's final EOT.
        sequences = [self._build_label_sequence(feature) for feature in features]
        max_length = max((len(seq) for seq in sequences), default=0)
        labels = torch.full((len(sequences), max_length), -100, dtype=torch.long)
        for row, seq in enumerate(sequences):
            if seq:
                labels[row, : len(seq)] = torch.tensor(seq, dtype=torch.long)

        # Only labels are added. A label-shaped "attention_mask" must not be
        # returned: the model receives it as the mask for input_features.
        batch["labels"] = labels
        return batch

    def _build_label_sequence(self, feature):
        """Return one sample's label ids: language/task prompt followed by the transcript."""
        ids = torch.as_tensor(feature["labels"]).reshape(-1).tolist()
        control_tokens, remainder = self._split_tokenizer_prefix(ids)

        lang_token_id = feature.get("language_token_id")
        if lang_token_id is None:
            # Fall back to the language token of the tokenizer's own prompt,
            # i.e. the first control token that is not a task token.
            task_ids = {
                self.task_token_id,
                self.processor.tokenizer.convert_tokens_to_ids("<|translate|>"),
            }
            lang_token_id = next((tok for tok in control_tokens if tok not in task_ids), None)

        prompt = [] if lang_token_id is None else [int(lang_token_id)]
        prompt.append(self.task_token_id)
        return prompt + remainder

    def _split_tokenizer_prefix(self, ids):
        """Split ``ids`` into (tokenizer prompt control tokens, remainder).

        When ``ids`` starts with ``decoder_start_token_id``, that token is
        dropped and the following control tokens are returned separately; the
        remainder starts at ``<|notimestamps|>`` or the first transcript token.
        Sequences that do not start with the start token are returned as
        ``([], ids)``.

        NOTE(review): relies on Whisper's vocabulary layout, where language/task
        control tokens have ids strictly between <|startoftranscript|> and
        <|notimestamps|>.
        """
        if not ids or ids[0] != self.decoder_start_token_id:
            return [], ids
        no_timestamps_id = self.processor.tokenizer.convert_tokens_to_ids("<|notimestamps|>")
        end = 1
        if isinstance(no_timestamps_id, int):
            while end < len(ids) and self.decoder_start_token_id < ids[end] < no_timestamps_id:
                end += 1
        return ids[1:end], ids[end:]
# → Choose device (GPU if available)
# NOTE(review): `device` is not referenced by the code visible in this file.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Extract configuration values
MODEL_CHECKPOINT = config['model']['checkpoint']
OUTPUT_DIR = config['output']['output_dir']
MAX_TARGET_LENGTH = config['model']['max_target_length']
# CPU usage configuration for dataset preprocessing
MAX_CPU_CORES = config['environment']['max_cpu_cores']
TEST_CPU_CORES = config['environment']['test_cpu_cores']
# Language configurations for each dataset
DATASET_CONFIGS = config['languages']
print("Loading datasets...")
# Load datasets for each language dynamically based on configuration
datasets = {}
dataset_configs = config['datasets']
audio_config = config['audio']
# Enabled languages = those present in BOTH `languages` and `datasets`, kept
# in the order they appear under `languages`. A set would iterate in an order
# that depends on per-process string-hash randomization. Every later loop
# (dataset concatenation, processor choice) would then differ between runs
# and between DDP worker processes, which must all see identical datasets.
enabled_languages = [lang for lang in config['languages'] if lang in config['datasets']]
print(f"Enabled languages: {list(enabled_languages)}")
def load_fleurs_dataset(lang_name, lang_config, dataset_config):
    """Load FLEURS dataset for a language and return its splits.

    Args:
        lang_name: config key of the language (used for logging only).
        lang_config: ``config['languages'][lang_name]``; an optional
            ``train_subset_ratio`` in (0, 1] keeps that fraction of the train split.
        dataset_config: ``config['datasets'][lang_name]`` providing ``source``,
            ``language_code``, ``splits`` and ``trust_remote_code``.

    Returns:
        What ``load_dataset`` returns for the ``splits`` mapping, keyed by split
        name. Every split's ``audio`` column is cast to ``Audio(decode=False)`` at
        ``audio_config['sampling_rate']``.

    Raises:
        ValueError: if ``train_subset_ratio`` lies outside (0, 1].
    """
    print(f"Loading FLEURS {lang_name.title()}...")
    split_spec = dict(dataset_config['splits'])
    lang_datasets = load_dataset(
        dataset_config['source'],
        dataset_config['language_code'],
        split=split_spec,
        trust_remote_code=dataset_config['trust_remote_code'],
    )
    # DON'T decode audio yet - keep it compressed until preprocessing
    for split in split_spec:
        lang_datasets[split] = lang_datasets[split].cast_column(
            "audio", Audio(sampling_rate=audio_config['sampling_rate'], decode=False)
        )
    # Use subset of training data if specified
    ratio = lang_config.get('train_subset_ratio')
    if ratio is not None:
        if not 0 < ratio <= 1:
            raise ValueError(
                f"train_subset_ratio for {lang_name} must be in (0, 1], got {ratio}"
            )
        # A ratio of 1 means "keep everything"; train_test_split would reject
        # the resulting test_size of 0.
        if ratio < 1:
            lang_datasets["train"] = lang_datasets["train"].train_test_split(
                test_size=1 - ratio,
                seed=config['data_processing']['seed'],
            )["train"]
    return lang_datasets
def load_simple_dataset(lang_name, dataset_config):
    """Load a dataset whose config lists its splits directly.

    ``dataset_config['splits']`` maps our split names to the source dataset's
    split specifiers; the result of ``load_dataset`` for that mapping is
    returned unchanged (one dataset per configured split name).
    """
    print(f"Loading {lang_name.title()}...")
    return load_dataset(dataset_config['source'], split=dict(dataset_config['splits']))
def load_english_dataset(lang_config, dataset_config):
    """Load the English corpus and carve a validation split off its training data.

    Train and test come from separate dataset ids in ``dataset_config``. The
    first ``lang_config['validation_size']`` training examples become the
    validation split; the remaining ones stay in train.

    NOTE(review): later code calls len() and integer indexing on these splits,
    which streaming (iterable) datasets do not provide, so ``streaming`` is
    presumably expected to be false in the config.
    """
    print("Loading English...")
    full_train = load_dataset(
        dataset_config['train_dataset'],
        split=dataset_config['train_split'],
        streaming=dataset_config['streaming'],
    )
    test_split = load_dataset(
        dataset_config['test_dataset'],
        split=dataset_config['test_split'],
        streaming=dataset_config['streaming'],
    )
    # Split into train/validation: head -> validation, tail -> train.
    n_validation = lang_config['validation_size']
    validation_split = full_train.take(n_validation)
    train_split = full_train.skip(n_validation)
    return {
        "train": train_split,
        "validation": validation_split,
        "test": test_split,
    }
def load_chinese_dataset(dataset_config):
    """Load the Chinese corpus: one train split, one validation split, two test splits.

    The training data is loaded without a split argument and its ``"train"``
    entry is used. Validation, TEST_NET and TEST_MEETING each come from their
    own dataset id + config name in ``dataset_config``.
    """
    print("Loading Chinese...")

    def _load_eval_split(dataset_key, config_key, split_name):
        # Evaluation splits share the same streaming / remote-code options.
        return load_dataset(
            dataset_config[dataset_key],
            dataset_config[config_key],
            split=split_name,
            streaming=dataset_config['streaming'],
            trust_remote_code=dataset_config['trust_remote_code'],
        )

    train_bundle = load_dataset(dataset_config['train_dataset'], streaming=dataset_config['streaming'])
    validation = _load_eval_split('validation_dataset', 'validation_config', "validation")
    test_net = _load_eval_split('test_net_dataset', 'test_net_config', "test")
    test_meeting = _load_eval_split('test_meeting_dataset', 'test_meeting_config', "test")
    return {
        "train": train_bundle["train"],
        "validation": validation,
        "test_net": test_net,
        "test_meeting": test_meeting,
    }
# Load datasets for each enabled language, dispatching on the language name:
# FLEURS languages, English (custom validation split), Chinese (two test
# splits); everything else goes through the simple loader.
for lang in enabled_languages:
    lang_config = config['languages'][lang]
    dataset_config = dataset_configs[lang]
    if lang in ('cebuano', 'khmer'):
        loaded = load_fleurs_dataset(lang, lang_config, dataset_config)
    elif lang == 'english':
        loaded = load_english_dataset(lang_config, dataset_config)
    elif lang == 'chinese':
        loaded = load_chinese_dataset(dataset_config)
    else:
        if lang not in ('indonesian', 'malay'):
            print(f"Warning: Unknown language {lang}, treating as simple dataset")
        loaded = load_simple_dataset(lang, dataset_config)
    datasets[lang] = loaded
print("Setting up processors...")
# One WhisperProcessor per enabled language, each created with that language's
# `whisper_language` setting from the config.
processors = {
    lang: WhisperProcessor.from_pretrained(
        MODEL_CHECKPOINT,
        language=config['languages'][lang]["whisper_language"],
    )
    for lang in enabled_languages
}
if not processors:
    raise ValueError("No processors created. Check your language configuration.")
# Shared processor: English when enabled, otherwise the first one created.
main_processor = processors["english"] if "english" in processors else next(iter(processors.values()))
model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
# Multi-GPU handling.
# A launcher (torchrun / accelerate launch) exports LOCAL_RANK together with
# RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for every worker; a plain
# `python` run exports none of them. The process group is only initialised when
# that launcher environment is present. Otherwise the script falls back to GPU 0
# (or CPU) instead of failing on a missing LOCAL_RANK / RANK.
local_rank = int(os.environ.get("LOCAL_RANK", -1))
_DIST_ENV_KEYS = ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")
launched_distributed = local_rank != -1 and all(key in os.environ for key in _DIST_ENV_KEYS)
if torch.cuda.is_available():
    gpu_index = max(local_rank, 0)
    torch.cuda.set_device(gpu_index)
    print(f"Using GPU {gpu_index} for training")
    if launched_distributed and not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs for training")
    # The model will be automatically distributed by the Trainer; here it is
    # only placed on this process's GPU.
    model.to(torch.device("cuda", gpu_index))
else:
    print("CUDA is not available; training on CPU")
    if launched_distributed and not dist.is_initialized():
        dist.init_process_group(backend="gloo")
print("Adding language labels to raw datasets...")
# For every split of every enabled language: drop any pre-existing
# "language"/"lang" column (case-insensitive) and add a uniform "language"
# column holding our config key, so all sources label languages the same way.
for lang in enabled_languages:
    lang_datasets = datasets[lang]
    if not isinstance(lang_datasets, dict):
        # Single dataset object - this shouldn't happen with current structure but handle gracefully
        print(f"Warning: Unexpected dataset structure for {lang}")
        continue
    for split_name, split_dataset in lang_datasets.items():
        if split_dataset is None:
            continue
        stale_columns = [name for name in split_dataset.column_names if name.lower() in ("language", "lang")]
        if stale_columns:
            print(f"Removing existing language column(s) {stale_columns} from {lang} {split_name}")
            split_dataset = split_dataset.remove_columns(stale_columns)
        # Replacing a value (not a key) while iterating the dict is safe.
        lang_datasets[split_name] = split_dataset.add_column("language", [lang] * len(split_dataset))
print("Combining raw datasets before preprocessing...")
# Ensure all datasets have compatible schemas before concatenation
def standardize_dataset_schema(dataset, dataset_name):
    """Return ``dataset`` reshaped to the common schema used for concatenation.

    - An ``audio`` column, if present, is cast to
      ``Audio(sampling_rate=audio_config['sampling_rate'], decode=False)`` so the
      audio stays encoded until ``__getitem__`` time.
    - Columns listed in ``config['data_processing']['columns_to_remove']`` are dropped.
    """
    print(f"Standardizing schema for {dataset_name}...")
    if "audio" in dataset.column_names:
        target_rate = audio_config['sampling_rate']
        print(f" Setting audio feature type to {target_rate}Hz (compressed) for {dataset_name}")
        dataset = dataset.cast_column("audio", Audio(sampling_rate=target_rate, decode=False))
    doomed = [
        name for name in dataset.column_names
        if name in config['data_processing']['columns_to_remove']
    ]
    if doomed:
        print(f" Removing incompatible columns: {doomed}")
        dataset = dataset.remove_columns(doomed)
    return dataset
# Standardize the train and validation splits of every enabled language that
# has them, then concatenate and shuffle each group with the configured seed.
print("Standardizing training datasets...")
raw_train_datasets = [
    standardize_dataset_schema(datasets[lang]["train"], f"{lang}_train")
    for lang in enabled_languages
    if "train" in datasets[lang]
]
print("Standardizing validation datasets...")
raw_val_datasets = [
    standardize_dataset_schema(datasets[lang]["validation"], f"{lang}_val")
    for lang in enabled_languages
    if "validation" in datasets[lang]
]
if not raw_train_datasets:
    raise ValueError("No training datasets found. Check your configuration.")
print("Combining training datasets...")
combined_raw_train = concatenate_datasets(raw_train_datasets).shuffle(seed=config['data_processing']['seed'])
if raw_val_datasets:
    print("Combining validation datasets...")
    combined_raw_val = concatenate_datasets(raw_val_datasets).shuffle(seed=config['data_processing']['seed'])
else:
    print("Warning: No validation datasets found. Training without validation.")
    combined_raw_val = None
print("Creating on-the-fly datasets (no preprocessing stored to disk)...")


def _wrap_on_the_fly(raw_dataset):
    """Wrap a raw split in WhisperOnTheFlyDataset using the shared processors and config."""
    return WhisperOnTheFlyDataset(
        raw_dataset,
        processors,
        main_processor,
        MAX_TARGET_LENGTH,
        audio_config,
    )


combined_train_dataset = _wrap_on_the_fly(combined_raw_train)
# Only create validation dataset if we have validation data
combined_val_dataset = None if combined_raw_val is None else _wrap_on_the_fly(combined_raw_val)

print("Creating on-the-fly test datasets...")
# Chinese (WenetSpeech) has two test splits; every other language has "test".
processed_datasets = {}
for lang in enabled_languages:
    test_split_names = ("test_net", "test_meeting") if lang == "chinese" else ("test",)
    processed_datasets[lang] = {
        split_name: _wrap_on_the_fly(datasets[lang][split_name])
        for split_name in test_split_names
        if split_name in datasets[lang]
    }
# Data collator shared by training and evaluation; the decoder start token is
# taken from the model config.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    decoder_start_token_id=model.config.decoder_start_token_id,
    processor=main_processor,
)
# Word and character error rates, loaded through Hugging Face Evaluate.
wer_metric, cer_metric = evaluate.load("wer"), evaluate.load("cer")
def compute_metrics(pred):
"""
Compute WER and CER metrics for predictions
"""
pred_ids = pred.predictions
# Decode predictions, skipping special tokens
pred_str = main_processor.batch_decode(pred_ids, skip_special_tokens=True)
label_ids = pred.label_ids
# Replace -100 with pad token ID for decoding
label_ids[label_ids == -100] = main_processor.tokenizer.pad_token_id
# Decode reference texts, also skipping special tokens
ref_str = main_processor.batch_decode(label_ids, skip_special_tokens=True)
# lowercase & strip
pred_str = [s.lower().strip() for s in pred_str]
ref_str = [s.lower().strip() for s in ref_str]
wer_score = wer_metric.compute(predictions=pred_str, references=ref_str)
cer_score = cer_metric.compute(predictions=pred_str, references=ref_str)
# Combine metrics
metrics = {"wer": wer_score, "cer": cer_score}
# Log example predictions
if len(pred_str) > 0:
num_examples = min(3, len(pred_str))
for i in range(num_examples):
print(f"Example {i}:")
print(f" Reference: {ref_str[i]}")
print(f" Prediction: {pred_str[i]}")
return metrics
# Check for multi-GPU setup
num_gpus = torch.cuda.device_count()
print(f"Number of available GPUs: {num_gpus}")
# Get training configuration
training_config = config['training']
# Batch size and gradient accumulation come from the `multi_gpu` or
# `single_gpu` sub-section, depending on how many GPUs are visible.
gpu_config = training_config['multi_gpu' if num_gpus > 1 else 'single_gpu']
per_device_batch_size = gpu_config['per_device_train_batch_size']
per_device_eval_batch_size = gpu_config['per_device_eval_batch_size']
gradient_accumulation_steps = gpu_config['gradient_accumulation_steps']
if num_gpus > 1:
    print(f"Multi-GPU training detected. Using {num_gpus} GPUs.")
else:
    print("Single GPU training.")
# Training Arguments
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=per_device_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=training_config['learning_rate'],
    warmup_steps=training_config['warmup_steps'],
    max_steps=training_config['max_steps'],
    gradient_checkpointing=training_config['gradient_checkpointing'],
    fp16=training_config['fp16'],
    eval_strategy=training_config['eval_strategy'],
    per_device_eval_batch_size=per_device_eval_batch_size,
    predict_with_generate=training_config['predict_with_generate'],
    generation_max_length=training_config['generation_max_length'],
    save_steps=training_config['save_steps'],
    eval_steps=training_config['eval_steps'],
    logging_steps=training_config['logging_steps'],
    report_to=training_config['report_to'],
    load_best_model_at_end=training_config['load_best_model_at_end'],
    metric_for_best_model=training_config['metric_for_best_model'],
    greater_is_better=training_config['greater_is_better'],
    push_to_hub=training_config['push_to_hub'],
    hub_private_repo=training_config['hub_private_repo'],  # private/public is set by the config
    save_total_limit=training_config['save_total_limit'],
    # Multi-GPU specific settings
    dataloader_drop_last=training_config['dataloader_drop_last'],
    ddp_find_unused_parameters=training_config['ddp_find_unused_parameters'],
    # Keep every key produced by WhisperOnTheFlyDataset. With the default
    # (True), the Trainer drops per-sample keys that are not arguments of the
    # model's forward() before the data collator runs. That strips
    # `language_token_id`, and the labels lose their language token.
    remove_unused_columns=training_config.get('remove_unused_columns', False),
)
# Initialize Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=combined_train_dataset,
    eval_dataset=combined_val_dataset,
    data_collator=data_collator,
    tokenizer=main_processor.feature_extractor,
    compute_metrics=compute_metrics,
)
def evaluate_on_test_sets():
    """Evaluate the model on all test sets from enabled languages.

    Uses the module-level ``trainer``, ``processed_datasets`` and
    ``enabled_languages``. For languages with a Whisper language-token id below,
    generation is given ``forced_decoder_ids`` that pin position 1 to the
    language token and position 2 to the transcribe token. Per-split WER/CER
    (as percentages) are printed, followed by a summary table.

    NOTE(review): newer transformers releases deprecate ``forced_decoder_ids``
    for Whisper generation in favour of ``language=``/``task=``; confirm
    against the installed version.

    Returns:
        dict: language -> {split name -> metrics dict from ``trainer.predict``}.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("EVALUATING ON ALL TEST SETS")
    print(rule)

    # Define language-specific generation parameters
    language_tokens = {
        "english": 50259,  # <|en|>
        "chinese": 50260,  # <|zh|>
        "indonesian": 50275,  # <|id|>
        "malay": 50282,  # <|ms|>
        "khmer": 50323,  # <|km|>
        "cebuano": 50348,  # <|tl|> (using Tagalog as fallback for Cebuano)
    }
    transcribe_token = 50360  # transcribe task

    def split_specs(lang):
        """(split key, metric prefix, heading, per-split label, summary tag) for each test split of ``lang``."""
        if lang == "chinese":
            return [
                ("test_net", f"test_{lang}_net", "WenetSpeech Chinese TEST_NET",
                 "Chinese TEST_NET", "Chinese-NET"),
                ("test_meeting", f"test_{lang}_meeting", "WenetSpeech Chinese TEST_MEETING",
                 "Chinese TEST_MEETING", "Chinese-MTG"),
            ]
        title = lang.title()
        return [("test", f"test_{lang}", f"{title} test set", f"{title} Test", f"{title:12}")]

    results = {}
    for lang in enabled_languages:
        if lang not in processed_datasets:
            continue
        token = language_tokens.get(lang)
        forced_ids = [[1, token], [2, transcribe_token]] if token else None
        if forced_ids is not None:
            print(f"Using forced_decoder_ids for {lang}: {forced_ids}")
        per_split = {}
        for split, prefix, heading, label, _tag in split_specs(lang):
            if split not in processed_datasets[lang]:
                continue
            print(f"\n***** Evaluating on {heading} *****")
            output = trainer.predict(
                processed_datasets[lang][split],
                metric_key_prefix=prefix,
                forced_decoder_ids=forced_ids,
            )
            print(f"{label} WER: {output.metrics[f'{prefix}_wer']*100:.2f}%")
            print(f"{label} CER: {output.metrics[f'{prefix}_cer']*100:.2f}%")
            per_split[split] = output.metrics
        results[lang] = per_split

    # Summary
    print("\n" + rule)
    print("SUMMARY OF ALL TEST RESULTS")
    print(rule)
    for lang in enabled_languages:
        if lang not in results:
            continue
        for split, prefix, _heading, _label, tag in split_specs(lang):
            if split in results[lang]:
                wer = results[lang][split][f"{prefix}_wer"] * 100
                cer = results[lang][split][f"{prefix}_cer"] * 100
                print(f"{tag}: WER={wer:.2f}% | CER={cer:.2f}%")
    return results
if __name__ == "__main__":
    print(f"Total training samples: {len(combined_train_dataset)}")
    # combined_val_dataset is None when no enabled language has a validation split.
    if combined_val_dataset is not None:
        print(f"Total validation samples: {len(combined_val_dataset)}")
    else:
        print("Total validation samples: 0")
    print("Starting training...")
    # Fine-tune the model
    trainer.train()
    # Evaluate on all test sets (disabled; uncomment to run after training)
    # evaluate_on_test_sets()
|