pengyizhou committed
Commit b1b221d · 1 Parent(s): 702a986

update README

Files changed (5)
  1. README.md +138 -0
  2. config.yaml +107 -0
  3. finetune-backup.py +243 -0
  4. finetune.py +673 -0
  5. inference.py +89 -0
README.md ADDED
# Whisper Fine-tuning for Khmer Language

This project provides a configurable way to fine-tune OpenAI's Whisper model specifically on the Khmer language using the Google FLEURS dataset (km_kh).

## Features

- **Flexible Configuration**: All parameters are configurable through YAML files
- **Multi-GPU Support**: Automatic detection and support for multiple GPUs
- **Dynamic Language Selection**: Train on any subset of supported languages
- **On-the-fly Processing**: Efficient memory usage with dynamic audio preprocessing
- **Comprehensive Evaluation**: Automatic evaluation on test sets

## Configuration

All parameters are configurable through the `config.yaml` file. The provided configuration is set up specifically for Khmer language training using the Google FLEURS dataset.

### Model Configuration
- Model checkpoint (default: `openai/whisper-large-v3`)
- Maximum target length for sequences

### Dataset Configuration
- Uses the Google FLEURS Khmer (km_kh) dataset
- Dataset sources and splits
- Language-specific settings
- Training subset ratio (25% of the data, for faster training)

### Training Configuration
- Learning rate, batch sizes, training steps
- Multi-GPU vs. single-GPU settings
- Evaluation and logging parameters

### Environment Configuration
- CPU core limits
- Environment variables for optimization

### Pushing to Hub
- I have set the configuration to not push to the Hugging Face Hub by default; you can enable this by setting `push_to_hub: true` in the `training` section of your config file.

## Usage

### Basic Usage
```bash
python finetune.py --config config.yaml
```

### Custom Configuration
```bash
python finetune.py --config my_custom_config.yaml
```

### Multi-GPU Training
Since there is only a small amount of training data (around 2.5 hours), multi-GPU training is not recommended. If you still want to try it, the docstring in `finetune.py` describes launching with `torchrun --nproc_per_node=2` or `accelerate launch`.

## Configuration File Structure

The `config.yaml` file is organized into the following sections (a minimal sketch of how the script reads them follows this list):

1. **model**: Model checkpoint and sequence length settings
2. **output**: Output directory configuration
3. **environment**: Environment variables and CPU settings
4. **audio**: Audio processing settings (sampling rate)
5. **languages**: Khmer language configuration
6. **datasets**: Google FLEURS Khmer dataset configuration
7. **training**: All training hyperparameters
8. **data_processing**: Data processing settings

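The snippet below is a minimal sketch of how `finetune.py` consumes this file: it loads the YAML with PyYAML (`yaml.safe_load`, as in the script's `load_config`) and reads a few of the keys listed above. The printouts are illustrative only; the section and key names match the provided `config.yaml`.

```python
# Minimal sketch: load config.yaml the same way finetune.py does (yaml.safe_load)
# and read a few of the documented sections. The print statements are illustrative.
import yaml

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

print(config["model"]["checkpoint"])                     # e.g. "openai/whisper-large-v3"
print(config["output"]["output_dir"])                    # e.g. "./whisper-fleurs-km_kh-small"
print(config["training"]["learning_rate"])               # e.g. 1.0e-5
print(config["languages"]["khmer"]["fleurs_language"])   # "km_kh"
```
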
## Customizing Your Training

### Adjusting Training Parameters
Modify the `training` section in `config.yaml` (a sketch for generating a variant config follows this section):
- Change learning rate, batch sizes, or training steps
- Adjust evaluation frequency
- Configure multi-GPU settings

### Environment Optimization
Adjust the `environment` section to optimize for your system:
- Set CPU core limits
- Configure memory usage settings

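As a hedged example of the "Custom Configuration" workflow above, the sketch below copies `config.yaml`, overrides a couple of `training` values, and writes a new file that can be passed to `finetune.py --config`. The overridden values and the `my_custom_config.yaml` filename are illustrations, not recommendations.

```python
# Illustrative sketch: derive a custom config from config.yaml using PyYAML.
# The overridden values and the output filename are examples only.
import yaml

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

config["training"]["learning_rate"] = 5.0e-6   # smaller learning rate
config["training"]["max_steps"] = 1600         # train for more steps
config["training"]["eval_steps"] = 200         # evaluate less often

with open("my_custom_config.yaml", "w") as f:
    yaml.safe_dump(config, f, sort_keys=False)

# Then run: python finetune.py --config my_custom_config.yaml
```
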
## Training Commands

### Basic Training
The provided `config.yaml` is specifically configured for Khmer language training on the Google FLEURS dataset, and `finetune.py` falls back to it when no `--config` argument is given, so basic (single-GPU) training is simply:

```bash
python finetune.py
```

## Inference Guide

After training your model, you can use the provided `inference.py` script for speech recognition:

```bash
python inference.py
```

The inference script includes:
- Model loading from the trained checkpoint
- Audio preprocessing pipeline
- Text generation with proper formatting
- Support for Khmer language transcription

### Using the Trained Model

The inference script automatically handles the following (a standalone single-file sketch is shown after this list):
- Loading the fine-tuned model weights
- Audio preprocessing with the proper sampling rate
- Generating transcriptions for Khmer speech
- Output formatting for evaluation metrics

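The following is a minimal, hedged sketch of transcribing a single Khmer audio file with the fine-tuned weights, modeled on what `inference.py` does with the `transformers` pipeline. The model ID matches the one used in `inference.py`; `audio.wav` is a placeholder path.

```python
# Minimal transcription sketch based on inference.py: load the fine-tuned model,
# reuse the base Whisper processor for Khmer, and run the ASR pipeline on one file.
# "audio.wav" is a placeholder; change model_id if you trained your own checkpoint.
import torch
from transformers import AutoModelForSpeechSeq2Seq, WhisperProcessor, pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "pengyizhou/whisper-fleurs-km_kh"
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3", language="khmer")

asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=dtype,
    chunk_length_s=30,
    device=device,
)

print(asr("audio.wav")["text"])
```
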
## Dependencies

Install required packages:
```bash
pip install -r requirements.txt
```

Key dependencies:
- PyYAML (for configuration loading)
- torch, transformers, datasets
- librosa (for audio processing)
- evaluate (for metrics)

## Evaluation Results

| Language | Metric | Error Rate |
|----------|:------:|-----------:|
| Khmer    |  CER   |     33.18% |

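For reference, the error rate is computed the way `inference.py` does it: corpus-level CER (and WER) with `jiwer` over lower-cased, stripped reference/prediction pairs. A condensed sketch with placeholder strings:

```python
# Sketch of the metric computation used in inference.py: corpus-level CER/WER
# with jiwer over normalized (lower-cased, stripped) reference/prediction pairs.
from jiwer import cer, wer

refs  = ["សួស្តី", "អរគុណ"]   # placeholder reference transcriptions
preds = ["សួស្តី", "អរគុន"]   # placeholder predictions (second has a character error)

refs  = [r.lower().strip() for r in refs]
preds = [p.lower().strip() for p in preds]

print(f"CER: {cer(refs, preds) * 100:.2f}%")
print(f"WER: {wer(refs, preds) * 100:.2f}%")
```
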
**Note**: If you encounter issues running `finetune.py`, you can use the `finetune-backup.py` file, which contains the original hardcoded configuration.
config.yaml ADDED
# Configuration for training only on Khmer (FLEURS km_kh) data
# Fine-tuning Whisper on the Khmer language using the Google FLEURS dataset

# Model Configuration
model:
  checkpoint: "openai/whisper-large-v3"
  max_target_length: 448

# Output Configuration
output:
  output_dir: "./whisper-fleurs-km_kh-small"

# Environment Configuration
environment:
  max_cpu_cores: 20
  test_cpu_cores: 20
  omp_num_threads: "20"
  mkl_num_threads: "20"
  openblas_num_threads: "20"
  veclib_maximum_threads: "20"
  numexpr_num_threads: "20"
  tokenizers_parallelism: "false"
  transformers_no_tf: "1"

# Audio Processing Configuration
audio:
  sampling_rate: 16000

# Language Configurations - Khmer only
languages:

  khmer:
    whisper_language: "khmer"
    fleurs_language: "km_kh"
    text_key: "transcription"
    train_subset_ratio: 0.25  # Use only 25% of the training data for faster training/experimentation

# Dataset Configurations - Khmer FLEURS
datasets:

  khmer:
    source: "google/fleurs"
    language_code: "km_kh"
    splits:
      train: "train"
      validation: "validation"
      test: "test"
    trust_remote_code: true

# Training Configuration
training:
  # Basic training parameters
  learning_rate: 1.0e-5
  warmup_steps: 100
  max_steps: 800

  # Batch size and accumulation
  # Note: finetune.py reads a `multi_gpu:` block with the same keys when more
  # than one GPU is visible; only the single-GPU settings are provided here.
  single_gpu:
    per_device_train_batch_size: 16
    per_device_eval_batch_size: 16
    gradient_accumulation_steps: 1

  # Optimization settings
  gradient_checkpointing: true
  fp16: true

  # Evaluation settings
  eval_strategy: "steps"
  eval_steps: 100
  predict_with_generate: true
  generation_max_length: 225

  # Saving and logging
  save_steps: 100
  logging_steps: 25
  save_total_limit: 3

  # Model selection
  load_best_model_at_end: true
  metric_for_best_model: "cer"  # Using CER for Khmer (character-based language)
  greater_is_better: false

  # Reporting
  report_to:
    - "tensorboard"

  # Hub settings
  push_to_hub: false

  # Multi-GPU specific settings
  dataloader_drop_last: true
  ddp_find_unused_parameters: false

# Data Processing Configuration
data_processing:
  # Random seed for reproducibility
  seed: 42

  # Columns to remove during standardization
  columns_to_remove:
    - "id"
    - "num_samples"
    - "path"
    - "speaker_id"
    - "chapter_id"
    - "segment_id"
finetune-backup.py ADDED
#!/usr/bin/env python
# finetune-backup.py (original hardcoded configuration)

"""
Fine-tune openai/whisper-large-v3 on the FLEURS Khmer dataset (config: "km_kh").
Based on the Hugging Face blog: https://huggingface.co/blog/fine-tune-whisper
"""

import os

os.environ["TRANSFORMERS_NO_TF"] = "1"

import torch
from datasets import load_dataset, Audio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
import ipdb
import evaluate


from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's appended later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


# → Choose device (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Configuration
LANGUAGE = "km_kh"           # FLEURS config for Khmer
LANGUAGE_WHISPER = "khmer"   # Whisper config for Khmer
MODEL_CHECKPOINT = "openai/whisper-large-v3"
OUTPUT_DIR = f"./whisper-fleurs-{LANGUAGE}-small"
TRAIN_SPLIT = "train"
VALID_SPLIT = "validation"
TEST_SPLIT = "test"
MAX_TARGET_LENGTH = 448

# 2. Load FLEURS Dataset (audio at 16 kHz)
raw_datasets = load_dataset("google/fleurs", LANGUAGE,
                            split={"train": TRAIN_SPLIT,
                                   "validation": VALID_SPLIT,
                                   "test": TEST_SPLIT})


# Cast “audio” column to 16 kHz
for split in ["train", "validation", "test"]:
    raw_datasets[split] = raw_datasets[split].cast_column("audio", Audio(sampling_rate=16_000))

# Subsample the training split for faster experimentation
# (this keeps the 75% "test" side of the split; finetune.py's config keeps 25%)
raw_datasets["train"] = raw_datasets["train"].train_test_split(test_size=0.75, seed=42)["test"]

# 3. Load Whisper Processor & Model
processor = WhisperProcessor.from_pretrained(MODEL_CHECKPOINT, language=LANGUAGE_WHISPER)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
model.to(device)

# 4. Preprocessing Function
#    - Extract log-Mel features from audio
#    - Tokenize the target transcription
def preprocess_batch(batch):
    # batch["audio"] is a list of {"array": np.ndarray, ...} entries at 16 kHz
    audio_arrays = [example["array"] for example in batch["audio"]]
    # 4a. Feature extraction (log-Mel + normalization)
    inputs = processor.feature_extractor(
        audio_arrays,
        sampling_rate=16_000,
        return_tensors="pt"
    )
    # 4b. Tokenize (labels) using the Whisper tokenizer
    #     We could prefix with a target-language ID token if necessary,
    #     but for FLEURS the default Whisper language-ID tokens should suffice.
    labels = processor.tokenizer(
        batch["transcription"],
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=MAX_TARGET_LENGTH
    )
    # ipdb.set_trace()
    # keep the keys the trainer expects:
    inputs["labels"] = labels.input_ids
    return inputs

# 5. Apply preprocessing to train/validation/test
#    - Remove all non-audio columns after mapping
train_dataset = raw_datasets["train"].map(
    preprocess_batch,
    remove_columns=raw_datasets["train"].column_names,
    batched=True,
    batch_size=16,  # adjust batch_size to your memory
)

# ipdb.set_trace()
eval_dataset = raw_datasets["validation"].map(
    preprocess_batch,
    remove_columns=raw_datasets["validation"].column_names,
    batched=True,
    batch_size=8,
)

test_dataset = raw_datasets["test"].map(
    preprocess_batch,
    remove_columns=raw_datasets["test"].column_names,
    batched=True,
    batch_size=8,
)

# 6. Data Collator
#    This pads input_features and labels to the maximum length in the batch,
#    and replaces padding token IDs in labels with -100 to ignore them in the loss computation.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

# 7. Metrics: WER & CER (using Hugging Face Evaluate)
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

def compute_metrics(pred):
    """
    pred.predictions: raw token IDs from generate()
    pred.label_ids: token IDs used as labels
    """
    # 7a. decode predictions → strings
    pred_ids = pred.predictions
    # ensure we skip special tokens
    pred_str = processor.batch_decode(pred_ids,
                                      skip_special_tokens=True)
    # 7b. decode references → strings, replacing -100 with the padding token ID
    label_ids = pred.label_ids
    # replace -100 with pad_token_id so that the tokenizer does not crash
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    ref_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # lowercase & strip
    pred_str = [s.lower().strip() for s in pred_str]
    ref_str = [s.lower().strip() for s in ref_str]

    wer_score = wer_metric.compute(predictions=pred_str, references=ref_str)
    cer_score = cer_metric.compute(predictions=pred_str, references=ref_str)
    return {"wer": wer_score, "cer": cer_score}


"""
# 8. Training Arguments (earlier epoch-based variant, kept for reference)
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=4,   # reduce if you OOM; or increase if large GPU
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,   # to simulate a larger batch
    evaluation_strategy="steps",
    eval_steps=500,                  # evaluate every 500 steps
    logging_steps=250,
    save_steps=1000,
    num_train_epochs=3,
    learning_rate=1e-5,
    warmup_steps=500,
    fp16=True,                       # use mixed precision if supported
    predict_with_generate=True,      # for computing WER/CER we need generate()
    save_total_limit=2,
    push_to_hub=False,
)
"""
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=100,
    max_steps=800,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=448,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True
)


# 9. Initialize Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,  # feature_extractor + tokenizer packed in processor
    compute_metrics=compute_metrics,
)

# 10. Fine-tune
if __name__ == "__main__":
    # 10a. Train
    trainer.train()

    # 10b. Evaluate on TEST split
    print("\n***** Evaluating on TEST split *****")
    test_metrics = trainer.predict(test_dataset, metric_key_prefix="test")
    print(f"Test WER: {test_metrics.metrics['test_wer']*100:.2f}%")
    print(f"Test CER: {test_metrics.metrics['test_cer']*100:.2f}%")
finetune.py ADDED
#!/usr/bin/env python
# finetune_whisper_mix_datasets.py

"""
Fine-tune openai/whisper-large-v3 on mixed datasets from different languages:
- FLEURS Cebuano (ceb_ph)
- FLEURS Khmer (km_kh)
- Switchboard1 English
- WenetSpeech Chinese
- Eng-Indon-CS
- Eng-Malay-CS
Based on the Hugging Face blog: https://huggingface.co/blog/fine-tune-whisper

To run this script on multiple GPUs, you have several options:

1. **Automatic Multi-GPU (DataParallel-style):**
   python finetune_whisper_mix_datasets.py

   The script will automatically detect and use all available GPUs.

2. **Distributed Training with torchrun (Recommended for 2+ GPUs):**
   torchrun --nproc_per_node=2 finetune_whisper_mix_datasets.py

   This uses DistributedDataParallel, which is more efficient.

3. **Distributed Training with accelerate (Alternative):**
   accelerate launch --num_processes=2 finetune_whisper_mix_datasets.py

   Requires: pip install accelerate

Note: With 2 GPUs, the effective batch size becomes:
   per_device_batch_size * num_gpus * gradient_accumulation_steps
   = 24 * 2 * 1 = 48 (compared to 32 with a single GPU)

CPU Core Limiting:
   The script automatically limits CPU usage to 20 cores using environment variables.
   You can also set these manually before running:
   export OMP_NUM_THREADS=20
   export MKL_NUM_THREADS=20
   export NUMEXPR_NUM_THREADS=20
   python finetune_whisper_mix_datasets.py
"""

import os
import random
import io
import yaml
import argparse
from itertools import chain

# Load configuration from YAML file
def load_config(config_path):
    with open(config_path, 'r') as file:
        return yaml.safe_load(file)

# Parse command line arguments
parser = argparse.ArgumentParser(description='Fine-tune Whisper on mixed datasets')
parser.add_argument('--config', type=str, default='config.yaml',
                    help='Path to configuration YAML file')
args = parser.parse_args()

# Load configuration
config = load_config(args.config)

# Set environment variables from config
env_config = config['environment']
os.environ["OMP_NUM_THREADS"] = env_config['omp_num_threads']
os.environ["MKL_NUM_THREADS"] = env_config['mkl_num_threads']
os.environ["OPENBLAS_NUM_THREADS"] = env_config['openblas_num_threads']
os.environ["VECLIB_MAXIMUM_THREADS"] = env_config['veclib_maximum_threads']
os.environ["NUMEXPR_NUM_THREADS"] = env_config['numexpr_num_threads']
os.environ["TOKENIZERS_PARALLELISM"] = env_config['tokenizers_parallelism']
os.environ["TRANSFORMERS_NO_TF"] = env_config['transformers_no_tf']

import torch
from datasets import load_dataset, Audio, concatenate_datasets, Dataset
from torch.utils.data import Dataset as TorchDataset
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
import ipdb
import evaluate
import numpy as np

# Multi-GPU setup
if torch.cuda.device_count() > 1:
    print(f"Setting up for {torch.cuda.device_count()} GPUs")
    # Enable distributed training environment variables if not already set
    if "LOCAL_RANK" not in os.environ:
        os.environ["LOCAL_RANK"] = "0"
    if "WORLD_SIZE" not in os.environ:
        os.environ["WORLD_SIZE"] = str(torch.cuda.device_count())


from dataclasses import dataclass
from typing import Any, Dict, List, Union

class WhisperOnTheFlyDataset(TorchDataset):
    """Custom dataset that preprocesses audio on-the-fly during training"""

    def __init__(self, dataset, processors, main_processor, max_target_length, audio_config):
        self.dataset = dataset
        self.processors = processors
        self.main_processor = main_processor
        self.max_target_length = max_target_length
        self.sampling_rate = audio_config['sampling_rate']

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Process audio
        audio_sample = item["audio"]
        audio_data = self._process_audio(audio_sample)

        # Extract features with the main processor
        inputs = self.main_processor.feature_extractor(
            audio_data,
            sampling_rate=self.sampling_rate,
            return_tensors="pt"
        )

        # Pick the text field for the language
        lang = item["language"]
        if lang in ["cebuano", "khmer"]:
            text = item["transcription"]
        else:  # english, chinese
            text = item["text"]

        # Tokenize with the appropriate processor
        if lang == "cebuano":
            labels = self.processors["cebuano"].tokenizer(
                text,
                return_tensors="pt",
                padding=False,
                truncation=True,
                max_length=self.max_target_length
            )
        elif lang == "khmer":
            labels = self.processors["khmer"].tokenizer(
                text,
                return_tensors="pt",
                padding=False,
                truncation=True,
                max_length=self.max_target_length
            )
        elif lang == "english":
            labels = self.processors["english"].tokenizer(
                text,
                return_tensors="pt",
                padding=False
            )
        elif lang == "chinese":
            labels = self.processors["chinese"].tokenizer(
                text,
                return_tensors="pt",
                padding=False
            )
        elif lang == "indonesian":
            labels = self.processors["indonesian"].tokenizer(
                text,
                return_tensors="pt",
                padding=False
            )
        else:  # Malay
            labels = self.processors["malay"].tokenizer(
                text,
                return_tensors="pt",
                padding=False
            )

        return {
            "input_features": inputs.input_features.squeeze(0),
            "labels": labels.input_ids.squeeze(0),
            "language": lang
        }

    def _process_audio(self, audio_sample):
        """Process audio sample into a numpy array"""
        import librosa

        if isinstance(audio_sample, dict):
            if "array" in audio_sample:
                return audio_sample["array"]
            elif "bytes" in audio_sample and audio_sample["bytes"] is not None:
                audio_array, _ = librosa.load(io.BytesIO(audio_sample["bytes"]), sr=self.sampling_rate)
                return audio_array
            elif "path" in audio_sample:
                audio_array, _ = librosa.load(audio_sample["path"], sr=self.sampling_rate)
                return audio_array
            else:
                raise ValueError(f"Unknown audio dict format: {audio_sample.keys()}")
        elif isinstance(audio_sample, str):
            audio_array, _ = librosa.load(audio_sample, sr=self.sampling_rate)
            return audio_array
        else:
            return audio_sample

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's appended later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

# → Choose device (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Extract configuration values
MODEL_CHECKPOINT = config['model']['checkpoint']
OUTPUT_DIR = config['output']['output_dir']
MAX_TARGET_LENGTH = config['model']['max_target_length']

# CPU usage configuration for dataset preprocessing
MAX_CPU_CORES = config['environment']['max_cpu_cores']
TEST_CPU_CORES = config['environment']['test_cpu_cores']

# Language configurations for each dataset
DATASET_CONFIGS = config['languages']

print("Loading datasets...")

# Load datasets for each language dynamically based on configuration
datasets = {}
dataset_configs = config['datasets']
audio_config = config['audio']

# Get the list of enabled languages from both the languages and datasets config
enabled_languages = set(config['languages'].keys()) & set(config['datasets'].keys())
print(f"Enabled languages: {list(enabled_languages)}")

def load_fleurs_dataset(lang_name, lang_config, dataset_config):
    """Load FLEURS dataset for a language"""
    print(f"Loading FLEURS {lang_name.title()}...")
    lang_datasets = load_dataset(
        dataset_config['source'],
        dataset_config['language_code'],
        split={k: v for k, v in dataset_config['splits'].items()},
        trust_remote_code=dataset_config['trust_remote_code']
    )
    # DON'T decode audio yet - keep it compressed until preprocessing
    for split in dataset_config['splits'].keys():
        lang_datasets[split] = lang_datasets[split].cast_column("audio", Audio(sampling_rate=audio_config['sampling_rate'], decode=False))

    # Use a subset of the training data if specified
    if 'train_subset_ratio' in lang_config:
        train_subset_ratio = lang_config['train_subset_ratio']
        lang_datasets["train"] = lang_datasets["train"].train_test_split(test_size=1-train_subset_ratio, seed=config['data_processing']['seed'])["train"]

    return lang_datasets

def load_simple_dataset(lang_name, dataset_config):
    """Load a simple dataset with train/validation/test splits"""
    print(f"Loading {lang_name.title()}...")
    lang_dataset = load_dataset(dataset_config['source'], split={k: v for k, v in dataset_config['splits'].items()})
    return lang_dataset

def load_english_dataset(lang_config, dataset_config):
    """Load the English dataset with a custom train/validation split"""
    print("Loading English...")
    swb_train = load_dataset(dataset_config['train_dataset'], split=dataset_config['train_split'], streaming=dataset_config['streaming'])
    swb_test = load_dataset(dataset_config['test_dataset'], split=dataset_config['test_split'], streaming=dataset_config['streaming'])
    # Split into train/validation
    validation_size = lang_config['validation_size']
    swb_val = swb_train.take(validation_size)
    swb_train = swb_train.skip(validation_size)
    return {
        "train": swb_train,
        "validation": swb_val,
        "test": swb_test
    }

def load_chinese_dataset(dataset_config):
    """Load the Chinese dataset with multiple test splits"""
    print("Loading Chinese...")
    wenet_train = load_dataset(dataset_config['train_dataset'], streaming=dataset_config['streaming'])
    wenet_valid = load_dataset(dataset_config['validation_dataset'], dataset_config['validation_config'], split="validation", streaming=dataset_config['streaming'])
    wenet_testnet = load_dataset(dataset_config['test_net_dataset'], dataset_config['test_net_config'], split="test", streaming=dataset_config['streaming'])
    wenet_testmeeting = load_dataset(dataset_config['test_meeting_dataset'], dataset_config['test_meeting_config'], split="test", streaming=dataset_config['streaming'])
    return {
        "train": wenet_train["train"],
        "validation": wenet_valid,
        "test_net": wenet_testnet,
        "test_meeting": wenet_testmeeting
    }

# Load datasets for each enabled language
for lang in enabled_languages:
    lang_config = config['languages'][lang]
    dataset_config = dataset_configs[lang]

    if lang in ['cebuano', 'khmer']:
        # FLEURS datasets
        datasets[lang] = load_fleurs_dataset(lang, lang_config, dataset_config)
    elif lang == 'english':
        # English with custom validation split
        datasets[lang] = load_english_dataset(lang_config, dataset_config)
    elif lang == 'chinese':
        # Chinese with multiple test splits
        datasets[lang] = load_chinese_dataset(dataset_config)
    elif lang in ['indonesian', 'malay']:
        # Simple datasets with standard splits
        datasets[lang] = load_simple_dataset(lang, dataset_config)
    else:
        print(f"Warning: Unknown language {lang}, treating as simple dataset")
        datasets[lang] = load_simple_dataset(lang, dataset_config)

print("Setting up processors...")

# Create processors for each enabled language
processors = {}
for lang in enabled_languages:
    lang_config = config['languages'][lang]
    processors[lang] = WhisperProcessor.from_pretrained(
        MODEL_CHECKPOINT,
        language=lang_config["whisper_language"]
    )

# Use the first available processor as the main one, preferring English if available
if "english" in processors:
    main_processor = processors["english"]
elif processors:
    main_processor = processors[list(processors.keys())[0]]
else:
    raise ValueError("No processors created. Check your language configuration.")
model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)

# Multi-GPU handling
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs for training")
    # The model will be automatically distributed by the Trainer
    model.to(device)
else:
    model.to(device)
print("Adding language labels to raw datasets...")

# Remove existing language columns and add our own consistent language labels for each enabled language
for lang in enabled_languages:
    lang_datasets = datasets[lang]

    # Handle different dataset structures
    if isinstance(lang_datasets, dict):
        # Dataset with explicit splits (train/validation/test)
        for split_name, split_dataset in lang_datasets.items():
            if split_dataset is not None:
                # Remove existing language column if it exists
                columns_to_remove = [col for col in split_dataset.column_names if col.lower() in ["language", "lang"]]
                if columns_to_remove:
                    print(f"Removing existing language column(s) {columns_to_remove} from {lang} {split_name}")
                    datasets[lang][split_name] = split_dataset.remove_columns(columns_to_remove)

                # Add our consistent language label
                datasets[lang][split_name] = datasets[lang][split_name].add_column("language", [lang] * len(datasets[lang][split_name]))
    else:
        # Single dataset object - this shouldn't happen with the current structure but handle gracefully
        print(f"Warning: Unexpected dataset structure for {lang}")
        continue


print("Combining raw datasets before preprocessing...")

# Ensure all datasets have compatible schemas before concatenation
def standardize_dataset_schema(dataset, dataset_name):
    """Standardize dataset schema to ensure compatibility for concatenation"""
    print(f"Standardizing schema for {dataset_name}...")

    # Keep audio compressed until preprocessing - only set the sampling rate
    if "audio" in dataset.column_names:
        print(f"  Setting audio feature type to {audio_config['sampling_rate']}Hz (compressed) for {dataset_name}")
        dataset = dataset.cast_column("audio", Audio(sampling_rate=audio_config['sampling_rate'], decode=False))

    # Remove problematic columns that might have different types
    columns_to_remove = []
    for col in dataset.column_names:
        if col in config['data_processing']['columns_to_remove']:
            columns_to_remove.append(col)

    if columns_to_remove:
        print(f"  Removing incompatible columns: {columns_to_remove}")
        dataset = dataset.remove_columns(columns_to_remove)

    return dataset

# Standardize all training datasets dynamically
print("Standardizing training datasets...")
raw_train_datasets = []
for lang in enabled_languages:
    if "train" in datasets[lang]:
        std_dataset = standardize_dataset_schema(datasets[lang]["train"], f"{lang}_train")
        raw_train_datasets.append(std_dataset)

# Standardize validation datasets dynamically
print("Standardizing validation datasets...")
raw_val_datasets = []
for lang in enabled_languages:
    if "validation" in datasets[lang]:
        std_dataset = standardize_dataset_schema(datasets[lang]["validation"], f"{lang}_val")
        raw_val_datasets.append(std_dataset)

# Combine datasets if we have any
if raw_train_datasets:
    print("Combining training datasets...")
    combined_raw_train = concatenate_datasets(raw_train_datasets)
    combined_raw_train = combined_raw_train.shuffle(seed=config['data_processing']['seed'])
else:
    raise ValueError("No training datasets found. Check your configuration.")

if raw_val_datasets:
    print("Combining validation datasets...")
    combined_raw_val = concatenate_datasets(raw_val_datasets)
    combined_raw_val = combined_raw_val.shuffle(seed=config['data_processing']['seed'])
else:
    print("Warning: No validation datasets found. Training without validation.")
    combined_raw_val = None

print("Creating on-the-fly datasets (no preprocessing stored to disk)...")

# Create on-the-fly datasets instead of preprocessing and storing
combined_train_dataset = WhisperOnTheFlyDataset(
    combined_raw_train,
    processors,
    main_processor,
    MAX_TARGET_LENGTH,
    audio_config
)

# Only create a validation dataset if we have validation data
if combined_raw_val is not None:
    combined_val_dataset = WhisperOnTheFlyDataset(
        combined_raw_val,
        processors,
        main_processor,
        MAX_TARGET_LENGTH,
        audio_config
    )
else:
    combined_val_dataset = None

print("Creating on-the-fly test datasets...")

# Create on-the-fly test datasets dynamically
processed_datasets = {}

for lang in enabled_languages:
    processed_datasets[lang] = {}

    # Handle different test split structures for different languages
    if lang == "chinese":
        # Chinese has multiple test splits
        if "test_net" in datasets[lang]:
            processed_datasets[lang]["test_net"] = WhisperOnTheFlyDataset(
                datasets[lang]["test_net"],
                processors,
                main_processor,
                MAX_TARGET_LENGTH,
                audio_config
            )
        if "test_meeting" in datasets[lang]:
            processed_datasets[lang]["test_meeting"] = WhisperOnTheFlyDataset(
                datasets[lang]["test_meeting"],
                processors,
                main_processor,
                MAX_TARGET_LENGTH,
                audio_config
            )
    else:
        # Standard test split
        if "test" in datasets[lang]:
            processed_datasets[lang]["test"] = WhisperOnTheFlyDataset(
                datasets[lang]["test"],
                processors,
                main_processor,
                MAX_TARGET_LENGTH,
                audio_config
            )

# Data Collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=main_processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

# Metrics: WER & CER (using Hugging Face Evaluate)
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

def compute_metrics(pred):
    """
    Compute WER and CER metrics for predictions
    """
    pred_ids = pred.predictions
    pred_str = main_processor.batch_decode(pred_ids, skip_special_tokens=True)

    label_ids = pred.label_ids
    label_ids[label_ids == -100] = main_processor.tokenizer.pad_token_id
    ref_str = main_processor.batch_decode(label_ids, skip_special_tokens=True)

    # lowercase & strip
    pred_str = [s.lower().strip() for s in pred_str]
    ref_str = [s.lower().strip() for s in ref_str]

    wer_score = wer_metric.compute(predictions=pred_str, references=ref_str)
    cer_score = cer_metric.compute(predictions=pred_str, references=ref_str)
    return {"wer": wer_score, "cer": cer_score}

# Check for multi-GPU setup
num_gpus = torch.cuda.device_count()
print(f"Number of available GPUs: {num_gpus}")

# Get training configuration
training_config = config['training']

# Adjust batch size and gradient accumulation for multi-GPU
if num_gpus > 1:
    # With multiple GPUs, use the multi-GPU configuration
    # (requires a `multi_gpu:` block under `training` in the config file)
    gpu_config = training_config['multi_gpu']
    per_device_batch_size = gpu_config['per_device_train_batch_size']
    per_device_eval_batch_size = gpu_config['per_device_eval_batch_size']
    gradient_accumulation_steps = gpu_config['gradient_accumulation_steps']
    print(f"Multi-GPU training detected. Using {num_gpus} GPUs.")
else:
    # Single GPU configuration
    gpu_config = training_config['single_gpu']
    per_device_batch_size = gpu_config['per_device_train_batch_size']
    per_device_eval_batch_size = gpu_config['per_device_eval_batch_size']
    gradient_accumulation_steps = gpu_config['gradient_accumulation_steps']
    print("Single GPU training.")

# Training Arguments
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=per_device_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=training_config['learning_rate'],
    warmup_steps=training_config['warmup_steps'],
    max_steps=training_config['max_steps'],
    gradient_checkpointing=training_config['gradient_checkpointing'],
    fp16=training_config['fp16'],
    eval_strategy=training_config['eval_strategy'],
    per_device_eval_batch_size=per_device_eval_batch_size,
    predict_with_generate=training_config['predict_with_generate'],
    generation_max_length=training_config['generation_max_length'],
    save_steps=training_config['save_steps'],
    eval_steps=training_config['eval_steps'],
    logging_steps=training_config['logging_steps'],
    report_to=training_config['report_to'],
    load_best_model_at_end=training_config['load_best_model_at_end'],
    metric_for_best_model=training_config['metric_for_best_model'],
    greater_is_better=training_config['greater_is_better'],
    push_to_hub=training_config['push_to_hub'],
    save_total_limit=training_config['save_total_limit'],
    # Multi-GPU specific settings
    dataloader_drop_last=training_config['dataloader_drop_last'],
    ddp_find_unused_parameters=training_config['ddp_find_unused_parameters'],
)

# Initialize Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=combined_train_dataset,
    eval_dataset=combined_val_dataset,
    data_collator=data_collator,
    tokenizer=main_processor.feature_extractor,
    compute_metrics=compute_metrics,
)

def evaluate_on_test_sets():
    """Evaluate the model on all test sets from enabled languages"""
    print("\n" + "="*60)
    print("EVALUATING ON ALL TEST SETS")
    print("="*60)

    results = {}

    for lang in enabled_languages:
        if lang in processed_datasets:
            lang_results = {}

            if lang == "chinese":
                # Chinese has multiple test splits
                if "test_net" in processed_datasets[lang]:
                    print("\n***** Evaluating on WenetSpeech Chinese TEST_NET *****")
                    chi_testnet_metrics = trainer.predict(processed_datasets[lang]["test_net"], metric_key_prefix=f"test_{lang}_net")
                    print(f"Chinese TEST_NET WER: {chi_testnet_metrics.metrics[f'test_{lang}_net_wer']*100:.2f}%")
                    print(f"Chinese TEST_NET CER: {chi_testnet_metrics.metrics[f'test_{lang}_net_cer']*100:.2f}%")
                    lang_results["test_net"] = chi_testnet_metrics.metrics

                if "test_meeting" in processed_datasets[lang]:
                    print("\n***** Evaluating on WenetSpeech Chinese TEST_MEETING *****")
                    chi_testmeet_metrics = trainer.predict(processed_datasets[lang]["test_meeting"], metric_key_prefix=f"test_{lang}_meeting")
                    print(f"Chinese TEST_MEETING WER: {chi_testmeet_metrics.metrics[f'test_{lang}_meeting_wer']*100:.2f}%")
                    print(f"Chinese TEST_MEETING CER: {chi_testmeet_metrics.metrics[f'test_{lang}_meeting_cer']*100:.2f}%")
                    lang_results["test_meeting"] = chi_testmeet_metrics.metrics
            else:
                # Standard test split
                if "test" in processed_datasets[lang]:
                    print(f"\n***** Evaluating on {lang.title()} test set *****")
                    test_metrics = trainer.predict(processed_datasets[lang]["test"], metric_key_prefix=f"test_{lang}")
                    print(f"{lang.title()} Test WER: {test_metrics.metrics[f'test_{lang}_wer']*100:.2f}%")
                    print(f"{lang.title()} Test CER: {test_metrics.metrics[f'test_{lang}_cer']*100:.2f}%")
                    lang_results["test"] = test_metrics.metrics

            results[lang] = lang_results

    # Summary
    print("\n" + "="*60)
    print("SUMMARY OF ALL TEST RESULTS")
    print("="*60)

    for lang in enabled_languages:
        if lang in results:
            if lang == "chinese":
                if "test_net" in results[lang]:
                    wer = results[lang]["test_net"][f"test_{lang}_net_wer"] * 100
                    cer = results[lang]["test_net"][f"test_{lang}_net_cer"] * 100
                    print(f"Chinese-NET: WER={wer:.2f}% | CER={cer:.2f}%")
                if "test_meeting" in results[lang]:
                    wer = results[lang]["test_meeting"][f"test_{lang}_meeting_wer"] * 100
                    cer = results[lang]["test_meeting"][f"test_{lang}_meeting_cer"] * 100
                    print(f"Chinese-MTG: WER={wer:.2f}% | CER={cer:.2f}%")
            else:
                if "test" in results[lang]:
                    wer = results[lang]["test"][f"test_{lang}_wer"] * 100
                    cer = results[lang]["test"][f"test_{lang}_cer"] * 100
                    print(f"{lang.title():12}: WER={wer:.2f}% | CER={cer:.2f}%")

    return results

if __name__ == "__main__":
    print(f"Total training samples: {len(combined_train_dataset)}")
    if combined_val_dataset is not None:
        print(f"Total validation samples: {len(combined_val_dataset)}")
    print("Starting training...")

    # Fine-tune the model
    trainer.train()

    # Evaluate on all test sets
    evaluate_on_test_sets()
inference.py ADDED
#!/usr/bin/env python

# pip install transformers datasets torch soundfile jiwer

from datasets import load_dataset, Audio
from transformers import AutoModelForSpeechSeq2Seq, WhisperProcessor, pipeline
import torch
from jiwer import wer as jiwer_wer
from jiwer import cer as jiwer_cer

# 1. Load the FLEURS Khmer (km_kh) test set, cast to 16 kHz audio
ds = load_dataset("google/fleurs", "km_kh", split="test")
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))


device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# model_id = "openai/whisper-large-v3"
model_id = "pengyizhou/whisper-fleurs-km_kh"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
whisper_model = "openai/whisper-large-v3"
processor = WhisperProcessor.from_pretrained(whisper_model, language="khmer")

# 2. Build the ASR pipeline around the fine-tuned model and the base Whisper processor
asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    chunk_length_s=30,
    batch_size=64,
    max_new_tokens=440,
    device=device,
    no_repeat_ngram_size=3,   # Prevent repeating 3-grams
    repetition_penalty=1.0,   # Penalize repetitions (>1.0 reduces repetition)
    length_penalty=1.0,       # Control length preference
    num_beams=1,              # Greedy decoding (set >1 for beam search)
    do_sample=False,          # Disable sampling for deterministic output
    early_stopping=False,     # Only relevant when num_beams > 1
    suppress_tokens=[],
)


# 3. Batch-wise transcription function
def transcribe_batch(batch):
    # `batch["audio"]` is a list of {"array": np.ndarray, ...}
    inputs = [ex["array"] for ex in batch["audio"]]
    outputs = asr(inputs)  # returns a list of dicts with "text"
    # lower-case and strip to normalize for CER
    preds = [out["text"].lower().strip() for out in outputs]
    return {"prediction": preds}

# 4. Map over the dataset in chunks of 64 examples at a time
result = ds.map(
    transcribe_batch,
    batched=True,
    batch_size=64,  # feed 64 audios at a time; the pipeline sub-batches internally
    remove_columns=ds.column_names
)

# 5. Compute corpus-level CER and WER with jiwer
# refs = "\n".join(t.lower().strip() for t in ds["transcription"])
# preds = "\n".join(t for t in result["prediction"])
# score = jiwer_cer(refs, preds)
refs = [t.lower().strip() for t in ds["transcription"]]
preds = [t for t in result["prediction"]]
score_cer = jiwer_cer(refs, preds)
score_wer = jiwer_wer(refs, preds)

print(f"CER on FLEURS km_kh: {score_cer*100:.2f}%")
print(f"WER on FLEURS km_kh: {score_wer*100:.2f}%")

with open("./km_kh_finetune.pred", "w") as pred_results:
    for pred in preds:
        pred_results.write("{}\n".format(pred))

with open("./km_kh.ref", "w") as ref_results:
    for ref in refs:
        ref_results.write("{}\n".format(ref))