Commit b1b221d · Parent: 702a986

update README

Files changed:
- README.md +138 -0
- config.yaml +107 -0
- finetune-backup.py +243 -0
- finetune.py +673 -0
- inference.py +89 -0
README.md
ADDED
@@ -0,0 +1,138 @@
# Whisper Fine-tuning for Khmer Language

This project provides a configurable way to fine-tune OpenAI's Whisper model on the Khmer language using the Google FLEURS dataset (km_kh).

## Features

- **Flexible Configuration**: All parameters are configurable through YAML files
- **Multi-GPU Support**: Automatic detection and support for multiple GPUs
- **Dynamic Language Selection**: Train on any subset of supported languages
- **On-the-fly Processing**: Efficient memory usage with dynamic audio preprocessing
- **Comprehensive Evaluation**: Automatic evaluation on test sets

## Configuration

All parameters are configurable through the `config.yaml` file. This configuration is specifically set up for Khmer language training using the Google FLEURS dataset.

### Model Configuration
- Model checkpoint (default: `openai/whisper-large-v3`)
- Maximum target length for sequences

### Dataset Configuration
- Uses the Google FLEURS Khmer (km_kh) dataset
- Dataset sources and splits
- Language-specific settings
- Training subset ratio (25% of the data, for faster training)

### Training Configuration
- Learning rate, batch sizes, training steps
- Multi-GPU vs. single-GPU settings
- Evaluation and logging parameters

### Environment Configuration
- CPU core limits
- Environment variables for optimization

### Pushing to Hub
- The configuration does not push to the Hugging Face Hub by default. You can enable this by setting `push_to_hub: true` in your config file, as shown below.

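For example, to publish checkpoints to the Hub, flip the flag inside the `training` section of `config.yaml` (the surrounding keys below are copied from the shipped config; everything else stays unchanged):

```yaml
training:
  # ... other training parameters unchanged ...
  # Hub settings
  push_to_hub: true   # default in this repository is false
```
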
## Usage

### Basic Usage
```bash
python finetune.py --config config.yaml
```

### Custom Configuration
```bash
python finetune.py --config my_custom_config.yaml
```

### Multi-GPU Training
Since there is very little training data (around 2.5 hours), multi-GPU training is not recommended. If you still want to try it, see the launch example below.

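The module docstring of `finetune.py` documents a `torchrun` launch for distributed training; a minimal sketch (the 2-GPU process count is an assumption, and you would also need to add a `multi_gpu` block to the `training` section, since the shipped `config.yaml` only defines `single_gpu`):

```bash
# DistributedDataParallel launch with one process per GPU (assumes 2 GPUs)
torchrun --nproc_per_node=2 finetune.py --config config.yaml
```
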
## Configuration File Structure

The `config.yaml` file is organized into the following sections (a trimmed skeleton follows the list):

1. **model**: Model checkpoint and sequence length settings
2. **output**: Output directory configuration
3. **environment**: Environment variables and CPU settings
4. **audio**: Audio processing settings (sampling rate)
5. **languages**: Khmer language configuration
6. **datasets**: Google FLEURS Khmer dataset configuration
7. **training**: All training hyperparameters
8. **data_processing**: Data processing settings

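The skeleton below shows one or two representative keys per section, with values copied from the shipped file; see the full `config.yaml` for every option:

```yaml
model:
  checkpoint: "openai/whisper-large-v3"
  max_target_length: 448
output:
  output_dir: "./whisper-fleurs-km_kh-small"
environment:
  max_cpu_cores: 20
audio:
  sampling_rate: 16000
languages:
  khmer:
    whisper_language: "khmer"
    fleurs_language: "km_kh"
    train_subset_ratio: 0.25
datasets:
  khmer:
    source: "google/fleurs"
    language_code: "km_kh"
training:
  learning_rate: 1.0e-5
  max_steps: 800
  metric_for_best_model: "cer"
data_processing:
  seed: 42
```
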
## Customizing Your Training

### Adjusting Training Parameters
Modify the `training` section in `config.yaml`:
- Change the learning rate, batch sizes, or training steps
- Adjust the evaluation frequency
- Configure multi-GPU settings

### Environment Optimization
Adjust the `environment` section to optimize for your system:
- Set CPU core limits
- Configure memory usage settings

## Provided Configuration

The provided `config.yaml` is specifically configured for Khmer language training using the Google FLEURS dataset.

## Training Commands

### Basic Training
```bash
python finetune.py
```

### Single GPU Training
```bash
python finetune.py
```

## Inference Guide

After training your model, you can use the provided `inference.py` script for speech recognition:

```bash
python inference.py
```

The inference script includes:
- Model loading from the trained checkpoint
- Audio preprocessing pipeline
- Text generation with proper formatting
- Support for Khmer language transcription

### Using the Trained Model

The inference script automatically handles the following (a minimal sketch follows the list):
- Loading the fine-tuned model weights
- Audio preprocessing with the proper sampling rate
- Generating transcriptions for Khmer speech
- Output formatting for evaluation metrics

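A minimal sketch of that flow with the `transformers` pipeline, mirroring what `inference.py` does; the checkpoint path and the audio file name are placeholders you should replace with your own output directory (or Hub ID) and a 16 kHz recording:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, WhisperProcessor, pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Placeholder: your fine-tuned checkpoint directory or Hub model ID
model_id = "./whisper-fleurs-km_kh-small"

# The processor (tokenizer + feature extractor) comes from the base Whisper model,
# configured for Khmer, just as in inference.py
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=dtype)
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3", language="khmer")

asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=dtype,
    chunk_length_s=30,
    device=device,
)

# Placeholder audio file; any 16 kHz Khmer recording works
print(asr("sample_khmer.wav")["text"])
```
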
## Dependencies

Install the required packages:
```bash
pip install -r requirements.txt
```

Key dependencies (a one-line install covering them is sketched below):
- PyYAML (for configuration loading)
- torch, transformers, datasets
- librosa (for audio processing)
- evaluate (for metrics)

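If your checkout does not already include a `requirements.txt`, the following one-liner covers the imports used by the scripts in this commit (version pins omitted; `tensorboard` is needed for the configured `report_to` setting):

```bash
pip install torch transformers datasets evaluate jiwer librosa soundfile pyyaml tensorboard ipdb
```
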
## Evaluation Results

| Language | Metric | Error Rate |
|----------|:------:|-----------:|
| Khmer    |  CER   |     33.18% |

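The CER above comes from decoding the FLEURS km_kh test split with `inference.py`; the metric step itself boils down to the same `jiwer` calls that script makes (the prediction and reference lists below are placeholders):

```python
from jiwer import cer, wer

preds = ["..."]  # placeholder: decoded hypotheses, lower-cased and stripped
refs = ["..."]   # placeholder: FLEURS transcriptions, lower-cased and stripped

print(f"CER: {cer(refs, preds) * 100:.2f}%")
print(f"WER: {wer(refs, preds) * 100:.2f}%")
```
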
**Note**: If you encounter issues running `finetune.py`, you can use the `finetune-backup.py` file, which contains the original hardcoded configuration.
config.yaml
ADDED
@@ -0,0 +1,107 @@
# Configuration for training only on Khmer (FLEURS km_kh) data
# Fine-tuning Whisper on the Khmer language using the Google FLEURS dataset

# Model Configuration
model:
  checkpoint: "openai/whisper-large-v3"
  max_target_length: 448

# Output Configuration
output:
  output_dir: "./whisper-fleurs-km_kh-small"

# Environment Configuration
environment:
  max_cpu_cores: 20
  test_cpu_cores: 20
  omp_num_threads: "20"
  mkl_num_threads: "20"
  openblas_num_threads: "20"
  veclib_maximum_threads: "20"
  numexpr_num_threads: "20"
  tokenizers_parallelism: "false"
  transformers_no_tf: "1"

# Audio Processing Configuration
audio:
  sampling_rate: 16000

# Language Configurations - Khmer only
languages:

  khmer:
    whisper_language: "khmer"
    fleurs_language: "km_kh"
    text_key: "transcription"
    train_subset_ratio: 0.25  # Use only 25% of training data for faster training/experimentation

# Dataset Configurations - Khmer FLEURS
datasets:

  khmer:
    source: "google/fleurs"
    language_code: "km_kh"
    splits:
      train: "train"
      validation: "validation"
      test: "test"
    trust_remote_code: true

# Training Configuration
training:
  # Basic training parameters
  learning_rate: 1.0e-5
  warmup_steps: 100
  max_steps: 800

  # Batch size and accumulation
  single_gpu:
    per_device_train_batch_size: 16
    per_device_eval_batch_size: 16
    gradient_accumulation_steps: 1

  # Optimization settings
  gradient_checkpointing: true
  fp16: true

  # Evaluation settings
  eval_strategy: "steps"
  eval_steps: 100
  predict_with_generate: true
  generation_max_length: 225

  # Saving and logging
  save_steps: 100
  logging_steps: 25
  save_total_limit: 3

  # Model selection
  load_best_model_at_end: true
  metric_for_best_model: "cer"  # Using CER for Khmer (character-based language)
  greater_is_better: false

  # Reporting
  report_to:
    - "tensorboard"

  # Hub settings
  push_to_hub: false

  # Multi-GPU specific settings
  dataloader_drop_last: true
  ddp_find_unused_parameters: false

# Data Processing Configuration
data_processing:
  # Random seed for reproducibility
  seed: 42

  # Columns to remove during standardization
  columns_to_remove:
    - "id"
    - "num_samples"
    - "path"
    - "speaker_id"
    - "chapter_id"
    - "segment_id"
finetune-backup.py
ADDED
@@ -0,0 +1,243 @@
#!/usr/bin/env python
# finetune_whisper_fleurs_my_mm.py

"""
Fine-tune openai/whisper-large-v3 on the FLEURS Khmer dataset (config: "km_kh").
Based on the Hugging Face blog: https://huggingface.co/blog/fine-tune-whisper
"""

import os

os.environ["TRANSFORMERS_NO_TF"] = "1"

import torch
from datasets import load_dataset, Audio
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
import ipdb
import evaluate


from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if the bos token was appended in the previous tokenization step,
        # cut it here since it is appended again later anyway
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


# Choose device (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Configuration
LANGUAGE = "km_kh"  # FLEURS config for Khmer
LANGUAGE_WHISPER = "khmer"  # Whisper config for Khmer
MODEL_CHECKPOINT = "openai/whisper-large-v3"
OUTPUT_DIR = f"./whisper-fleurs-{LANGUAGE}-small"
TRAIN_SPLIT = "train"
VALID_SPLIT = "validation"
TEST_SPLIT = "test"
MAX_TARGET_LENGTH = 448

# 2. Load FLEURS Dataset (audio at 16 kHz)
raw_datasets = load_dataset("google/fleurs", LANGUAGE,
                            split={"train": TRAIN_SPLIT,
                                   "validation": VALID_SPLIT,
                                   "test": TEST_SPLIT})


# Cast "audio" column to 16 kHz
for split in ["train", "validation", "test"]:
    raw_datasets[split] = raw_datasets[split].cast_column("audio", Audio(sampling_rate=16_000))

# Keep a random subset of the training split (the 75% partition of the split)
raw_datasets["train"] = raw_datasets["train"].train_test_split(test_size=0.75, seed=42)["test"]

# 3. Load Whisper Processor & Model
processor = WhisperProcessor.from_pretrained(MODEL_CHECKPOINT, language=LANGUAGE_WHISPER)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
model.to(device)

# 4. Preprocessing Function
#    - Extract log-Mel features from audio
#    - Tokenize the target transcription
def preprocess_batch(batch):
    # batch["audio"] is a list of {"array": np.ndarray @ 16 kHz, ...}
    audio_arrays = [example["array"] for example in batch["audio"]]
    # 4a. Feature extraction (log-Mel + normalization)
    inputs = processor.feature_extractor(
        audio_arrays,
        sampling_rate=16_000,
        return_tensors="pt"
    )
    # 4b. Tokenize (labels) using the Whisper tokenizer
    #     We could prefix with a target language ID token if necessary,
    #     but for FLEURS the default Whisper language-ID tokens should suffice.
    labels = processor.tokenizer(
        batch["transcription"],
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=MAX_TARGET_LENGTH
    )
    # ipdb.set_trace()
    # rename for trainer:
    inputs["input_features"] = inputs.pop("input_features")
    inputs["labels"] = labels.input_ids
    return inputs

# 5. Apply preprocessing to train/validation/test
#    - Remove all non-audio columns after mapping
train_dataset = raw_datasets["train"].map(
    preprocess_batch,
    remove_columns=raw_datasets["train"].column_names,
    batched=True,
    batch_size=16,  # adjust batch_size to your memory
)

# ipdb.set_trace()
eval_dataset = raw_datasets["validation"].map(
    preprocess_batch,
    remove_columns=raw_datasets["validation"].column_names,
    batched=True,
    batch_size=8,
)

test_dataset = raw_datasets["test"].map(
    preprocess_batch,
    remove_columns=raw_datasets["test"].column_names,
    batched=True,
    batch_size=8,
)

# 6. Data Collator
#    This will pad input_features and labels to the maximum length in the batch,
#    and replace padding token IDs in labels by -100 to ignore them in loss computation.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

# 7. Metrics: WER & CER (using Hugging Face Evaluate)
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

def compute_metrics(pred):
    """
    pred.predictions: raw token IDs from generate()
    pred.label_ids: token IDs used as labels
    """
    # 7a. decode predictions -> strings
    pred_ids = pred.predictions
    # ensure we skip special tokens
    pred_str = processor.batch_decode(pred_ids,
                                      skip_special_tokens=True)
    # 7b. decode references -> strings, replacing -100 with the padding token ID
    label_ids = pred.label_ids
    # replace -100 with pad_token_id so that the tokenizer does not crash
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    ref_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # lowercase & strip
    pred_str = [s.lower().strip() for s in pred_str]
    ref_str = [s.lower().strip() for s in ref_str]

    wer_score = wer_metric.compute(predictions=pred_str, references=ref_str)
    cer_score = cer_metric.compute(predictions=pred_str, references=ref_str)
    return {"wer": wer_score, "cer": cer_score}


"""
# 8. Training Arguments (earlier configuration, kept for reference)
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=4,   # reduce if you OOM; or increase if large GPU
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,   # to simulate a larger batch
    evaluation_strategy="steps",
    eval_steps=500,                  # evaluate every 500 steps
    logging_steps=250,
    save_steps=1000,
    num_train_epochs=3,
    learning_rate=1e-5,
    warmup_steps=500,
    fp16=True,                       # use mixed precision if supported
    predict_with_generate=True,      # for computing WER/CER we need generate()
    save_total_limit=2,
    push_to_hub=False,
)
"""
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=100,
    max_steps=800,
    gradient_checkpointing=True,
    fp16=True,
    eval_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=448,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True
)


# 9. Initialize Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,  # feature_extractor + tokenizer packed in processor
    compute_metrics=compute_metrics,
)

# 10. Fine-tune
if __name__ == "__main__":
    # 10a. Train
    trainer.train()

    # 10b. Evaluate on TEST split
    print("\n***** Evaluating on TEST split *****")
    test_metrics = trainer.predict(test_dataset, metric_key_prefix="test")
    print(f"Test WER: {test_metrics.metrics['test_wer']*100:.2f}%")
    print(f"Test CER: {test_metrics.metrics['test_cer']*100:.2f}%")
finetune.py
ADDED
@@ -0,0 +1,673 @@
#!/usr/bin/env python
# finetune_whisper_mix_datasets.py

"""
Fine-tune openai/whisper-large-v3 on mixed datasets from different languages:
- FLEURS Cebuano (ceb_ph)
- FLEURS Khmer (km_kh)
- Switchboard1 English
- WenetSpeech Chinese
- Eng-Indon-CS
- Eng-Malay-CS
Based on the Hugging Face blog: https://huggingface.co/blog/fine-tune-whisper

To run this script on multiple GPUs, you have several options:

1. **Automatic Multi-GPU (DataParallel-style):**
       python finetune.py

   The script will automatically detect and use all available GPUs.

2. **Distributed Training with torchrun (recommended for 2+ GPUs):**
       torchrun --nproc_per_node=2 finetune.py

   This uses DistributedDataParallel, which is more efficient.

3. **Distributed Training with accelerate (alternative):**
       accelerate launch --num_processes=2 finetune.py

   Requires: pip install accelerate

Note: With 2 GPUs, the effective batch size becomes:
    per_device_batch_size * num_gpus * gradient_accumulation_steps
    = 24 * 2 * 1 = 48 (compared to 32 with a single GPU)

CPU Core Limiting:
    The script automatically limits CPU usage to 20 cores using environment variables.
    You can also set these manually before running:
        export OMP_NUM_THREADS=20
        export MKL_NUM_THREADS=20
        export NUMEXPR_NUM_THREADS=20
        python finetune.py
"""

import os
import random
import io
import yaml
import argparse
from itertools import chain

# Load configuration from YAML file
def load_config(config_path):
    with open(config_path, 'r') as file:
        return yaml.safe_load(file)

# Parse command line arguments
parser = argparse.ArgumentParser(description='Fine-tune Whisper on mixed datasets')
parser.add_argument('--config', type=str, default='config.yaml',
                    help='Path to configuration YAML file')
args = parser.parse_args()

# Load configuration
config = load_config(args.config)

# Set environment variables from config
env_config = config['environment']
os.environ["OMP_NUM_THREADS"] = env_config['omp_num_threads']
os.environ["MKL_NUM_THREADS"] = env_config['mkl_num_threads']
os.environ["OPENBLAS_NUM_THREADS"] = env_config['openblas_num_threads']
os.environ["VECLIB_MAXIMUM_THREADS"] = env_config['veclib_maximum_threads']
os.environ["NUMEXPR_NUM_THREADS"] = env_config['numexpr_num_threads']
os.environ["TOKENIZERS_PARALLELISM"] = env_config['tokenizers_parallelism']
os.environ["TRANSFORMERS_NO_TF"] = env_config['transformers_no_tf']

import torch
from datasets import load_dataset, Audio, concatenate_datasets, Dataset
from torch.utils.data import Dataset as TorchDataset
from transformers import (
    WhisperProcessor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
import ipdb
import evaluate
import numpy as np

# Multi-GPU setup
if torch.cuda.device_count() > 1:
    print(f"Setting up for {torch.cuda.device_count()} GPUs")
    # Enable distributed training environment variables if not already set
    if "LOCAL_RANK" not in os.environ:
        os.environ["LOCAL_RANK"] = "0"
    if "WORLD_SIZE" not in os.environ:
        os.environ["WORLD_SIZE"] = str(torch.cuda.device_count())


from dataclasses import dataclass
from typing import Any, Dict, List, Union

class WhisperOnTheFlyDataset(TorchDataset):
    """Custom dataset that preprocesses audio on-the-fly during training"""

    def __init__(self, dataset, processors, main_processor, max_target_length, audio_config):
        self.dataset = dataset
        self.processors = processors
        self.main_processor = main_processor
        self.max_target_length = max_target_length
        self.sampling_rate = audio_config['sampling_rate']

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Process audio
        audio_sample = item["audio"]
        audio_data = self._process_audio(audio_sample)

        # Extract features with the main processor
        inputs = self.main_processor.feature_extractor(
            audio_data,
            sampling_rate=self.sampling_rate,
            return_tensors="pt"
        )

        # Pick the text field for this language
        lang = item["language"]
        if lang in ["cebuano", "khmer"]:
            text = item["transcription"]
        else:  # english, chinese, indonesian, malay
            text = item["text"]

        # Tokenize with the appropriate processor
        if lang == "cebuano":
            labels = self.processors["cebuano"].tokenizer(
                text,
                return_tensors="pt",
                padding=False,
                truncation=True,
                max_length=self.max_target_length
            )
        elif lang == "khmer":
            labels = self.processors["khmer"].tokenizer(
                text,
                return_tensors="pt",
                padding=False,
                truncation=True,
                max_length=self.max_target_length
            )
        elif lang == "english":
            labels = self.processors["english"].tokenizer(
                text,
                return_tensors="pt",
                padding=False
            )
        elif lang == "chinese":
            labels = self.processors["chinese"].tokenizer(
                text,
                return_tensors="pt",
                padding=False
            )
        elif lang == "indonesian":
            labels = self.processors["indonesian"].tokenizer(
                text,
                return_tensors="pt",
                padding=False
            )
        else:  # Malay
            labels = self.processors["malay"].tokenizer(
                text,
                return_tensors="pt",
                padding=False
            )

        return {
            "input_features": inputs.input_features.squeeze(0),
            "labels": labels.input_ids.squeeze(0),
            "language": lang
        }

    def _process_audio(self, audio_sample):
        """Process an audio sample into a numpy array"""
        import librosa

        if isinstance(audio_sample, dict):
            if "array" in audio_sample:
                return audio_sample["array"]
            elif "bytes" in audio_sample and audio_sample["bytes"] is not None:
                audio_array, _ = librosa.load(io.BytesIO(audio_sample["bytes"]), sr=self.sampling_rate)
                return audio_array
            elif "path" in audio_sample:
                audio_array, _ = librosa.load(audio_sample["path"], sr=self.sampling_rate)
                return audio_array
            else:
                raise ValueError(f"Unknown audio dict format: {audio_sample.keys()}")
        elif isinstance(audio_sample, str):
            audio_array, _ = librosa.load(audio_sample, sr=self.sampling_rate)
            return audio_array
        else:
            return audio_sample

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if the bos token was appended in the previous tokenization step,
        # cut it here since it is appended again later anyway
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

# Choose device (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Extract configuration values
MODEL_CHECKPOINT = config['model']['checkpoint']
OUTPUT_DIR = config['output']['output_dir']
MAX_TARGET_LENGTH = config['model']['max_target_length']

# CPU usage configuration for dataset preprocessing
MAX_CPU_CORES = config['environment']['max_cpu_cores']
TEST_CPU_CORES = config['environment']['test_cpu_cores']

# Language configurations for each dataset
DATASET_CONFIGS = config['languages']

print("Loading datasets...")

# Load datasets for each language dynamically based on configuration
datasets = {}
dataset_configs = config['datasets']
audio_config = config['audio']

# Get the list of enabled languages from both the languages and datasets config
enabled_languages = set(config['languages'].keys()) & set(config['datasets'].keys())
print(f"Enabled languages: {list(enabled_languages)}")

def load_fleurs_dataset(lang_name, lang_config, dataset_config):
    """Load a FLEURS dataset for a language"""
    print(f"Loading FLEURS {lang_name.title()}...")
    lang_datasets = load_dataset(
        dataset_config['source'],
        dataset_config['language_code'],
        split={k: v for k, v in dataset_config['splits'].items()},
        trust_remote_code=dataset_config['trust_remote_code']
    )
    # DON'T decode audio yet - keep it compressed until preprocessing
    for split in dataset_config['splits'].keys():
        lang_datasets[split] = lang_datasets[split].cast_column("audio", Audio(sampling_rate=audio_config['sampling_rate'], decode=False))

    # Use a subset of the training data if specified
    if 'train_subset_ratio' in lang_config:
        train_subset_ratio = lang_config['train_subset_ratio']
        lang_datasets["train"] = lang_datasets["train"].train_test_split(test_size=1-train_subset_ratio, seed=config['data_processing']['seed'])["train"]

    return lang_datasets

def load_simple_dataset(lang_name, dataset_config):
    """Load a simple dataset with train/validation/test splits"""
    print(f"Loading {lang_name.title()}...")
    lang_dataset = load_dataset(dataset_config['source'], split={k: v for k, v in dataset_config['splits'].items()})
    return lang_dataset

def load_english_dataset(lang_config, dataset_config):
    """Load the English dataset with a custom train/validation split"""
    print("Loading English...")
    swb_train = load_dataset(dataset_config['train_dataset'], split=dataset_config['train_split'], streaming=dataset_config['streaming'])
    swb_test = load_dataset(dataset_config['test_dataset'], split=dataset_config['test_split'], streaming=dataset_config['streaming'])
    # Split into train/validation
    validation_size = lang_config['validation_size']
    swb_val = swb_train.take(validation_size)
    swb_train = swb_train.skip(validation_size)
    return {
        "train": swb_train,
        "validation": swb_val,
        "test": swb_test
    }

def load_chinese_dataset(dataset_config):
    """Load the Chinese dataset with multiple test splits"""
    print("Loading Chinese...")
    wenet_train = load_dataset(dataset_config['train_dataset'], streaming=dataset_config['streaming'])
    wenet_valid = load_dataset(dataset_config['validation_dataset'], dataset_config['validation_config'], split="validation", streaming=dataset_config['streaming'])
    wenet_testnet = load_dataset(dataset_config['test_net_dataset'], dataset_config['test_net_config'], split="test", streaming=dataset_config['streaming'])
    wenet_testmeeting = load_dataset(dataset_config['test_meeting_dataset'], dataset_config['test_meeting_config'], split="test", streaming=dataset_config['streaming'])
    return {
        "train": wenet_train["train"],
        "validation": wenet_valid,
        "test_net": wenet_testnet,
        "test_meeting": wenet_testmeeting
    }

# Load datasets for each enabled language
for lang in enabled_languages:
    lang_config = config['languages'][lang]
    dataset_config = dataset_configs[lang]

    if lang in ['cebuano', 'khmer']:
        # FLEURS datasets
        datasets[lang] = load_fleurs_dataset(lang, lang_config, dataset_config)
    elif lang == 'english':
        # English with custom validation split
        datasets[lang] = load_english_dataset(lang_config, dataset_config)
    elif lang == 'chinese':
        # Chinese with multiple test splits
        datasets[lang] = load_chinese_dataset(dataset_config)
    elif lang in ['indonesian', 'malay']:
        # Simple datasets with standard splits
        datasets[lang] = load_simple_dataset(lang, dataset_config)
    else:
        print(f"Warning: Unknown language {lang}, treating as simple dataset")
        datasets[lang] = load_simple_dataset(lang, dataset_config)

print("Setting up processors...")

# Create processors for each enabled language
processors = {}
for lang in enabled_languages:
    lang_config = config['languages'][lang]
    processors[lang] = WhisperProcessor.from_pretrained(
        MODEL_CHECKPOINT,
        language=lang_config["whisper_language"]
    )

# Use the first available processor as the main one, preferring English if available
if "english" in processors:
    main_processor = processors["english"]
elif processors:
    main_processor = processors[list(processors.keys())[0]]
else:
    raise ValueError("No processors created. Check your language configuration.")
model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)

# Multi-GPU handling
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs for training")
    # The model will be automatically distributed by the Trainer
    model.to(device)
else:
    model.to(device)


print("Adding language labels to raw datasets...")

# Remove existing language columns and add our own consistent language labels for each enabled language
for lang in enabled_languages:
    lang_datasets = datasets[lang]

    # Handle different dataset structures
    if isinstance(lang_datasets, dict):
        # Dataset with explicit splits (train/validation/test)
        for split_name, split_dataset in lang_datasets.items():
            if split_dataset is not None:
                # Remove the existing language column if it exists
                columns_to_remove = [col for col in split_dataset.column_names if col.lower() in ["language", "lang"]]
                if columns_to_remove:
                    print(f"Removing existing language column(s) {columns_to_remove} from {lang} {split_name}")
                    datasets[lang][split_name] = split_dataset.remove_columns(columns_to_remove)

                # Add our consistent language label
                datasets[lang][split_name] = datasets[lang][split_name].add_column("language", [lang] * len(datasets[lang][split_name]))
    else:
        # Single dataset object - this shouldn't happen with the current structure, but handle gracefully
        print(f"Warning: Unexpected dataset structure for {lang}")
        continue


print("Combining raw datasets before preprocessing...")

# Ensure all datasets have compatible schemas before concatenation
def standardize_dataset_schema(dataset, dataset_name):
    """Standardize the dataset schema to ensure compatibility for concatenation"""
    print(f"Standardizing schema for {dataset_name}...")

    # Keep audio compressed until preprocessing - only set the sampling rate
    if "audio" in dataset.column_names:
        print(f"  Setting audio feature type to {audio_config['sampling_rate']}Hz (compressed) for {dataset_name}")
        dataset = dataset.cast_column("audio", Audio(sampling_rate=audio_config['sampling_rate'], decode=False))

    # Remove problematic columns that might have different types
    columns_to_remove = []
    for col in dataset.column_names:
        if col in config['data_processing']['columns_to_remove']:
            columns_to_remove.append(col)

    if columns_to_remove:
        print(f"  Removing incompatible columns: {columns_to_remove}")
        dataset = dataset.remove_columns(columns_to_remove)

    return dataset

# Standardize all training datasets dynamically
print("Standardizing training datasets...")
raw_train_datasets = []
for lang in enabled_languages:
    if "train" in datasets[lang]:
        std_dataset = standardize_dataset_schema(datasets[lang]["train"], f"{lang}_train")
        raw_train_datasets.append(std_dataset)

# Standardize validation datasets dynamically
print("Standardizing validation datasets...")
raw_val_datasets = []
for lang in enabled_languages:
    if "validation" in datasets[lang]:
        std_dataset = standardize_dataset_schema(datasets[lang]["validation"], f"{lang}_val")
        raw_val_datasets.append(std_dataset)

# Combine datasets if we have any
if raw_train_datasets:
    print("Combining training datasets...")
    combined_raw_train = concatenate_datasets(raw_train_datasets)
    combined_raw_train = combined_raw_train.shuffle(seed=config['data_processing']['seed'])
else:
    raise ValueError("No training datasets found. Check your configuration.")

if raw_val_datasets:
    print("Combining validation datasets...")
    combined_raw_val = concatenate_datasets(raw_val_datasets)
    combined_raw_val = combined_raw_val.shuffle(seed=config['data_processing']['seed'])
else:
    print("Warning: No validation datasets found. Training without validation.")
    combined_raw_val = None

print("Creating on-the-fly datasets (no preprocessing stored to disk)...")

# Create on-the-fly datasets instead of preprocessing and storing
combined_train_dataset = WhisperOnTheFlyDataset(
    combined_raw_train,
    processors,
    main_processor,
    MAX_TARGET_LENGTH,
    audio_config
)

# Only create a validation dataset if we have validation data
if combined_raw_val is not None:
    combined_val_dataset = WhisperOnTheFlyDataset(
        combined_raw_val,
        processors,
        main_processor,
        MAX_TARGET_LENGTH,
        audio_config
    )
else:
    combined_val_dataset = None

print("Creating on-the-fly test datasets...")

# Create on-the-fly test datasets dynamically
processed_datasets = {}

for lang in enabled_languages:
    processed_datasets[lang] = {}

    # Handle different test split structures for different languages
    if lang == "chinese":
        # Chinese has multiple test splits
        if "test_net" in datasets[lang]:
            processed_datasets[lang]["test_net"] = WhisperOnTheFlyDataset(
                datasets[lang]["test_net"],
                processors,
                main_processor,
                MAX_TARGET_LENGTH,
                audio_config
            )
        if "test_meeting" in datasets[lang]:
            processed_datasets[lang]["test_meeting"] = WhisperOnTheFlyDataset(
                datasets[lang]["test_meeting"],
                processors,
                main_processor,
                MAX_TARGET_LENGTH,
                audio_config
            )
    else:
        # Standard test split
        if "test" in datasets[lang]:
            processed_datasets[lang]["test"] = WhisperOnTheFlyDataset(
                datasets[lang]["test"],
                processors,
                main_processor,
                MAX_TARGET_LENGTH,
                audio_config
            )

# Data Collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=main_processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

# Metrics: WER & CER (using Hugging Face Evaluate)
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

def compute_metrics(pred):
    """
    Compute WER and CER metrics for predictions
    """
    pred_ids = pred.predictions
    pred_str = main_processor.batch_decode(pred_ids, skip_special_tokens=True)

    label_ids = pred.label_ids
    label_ids[label_ids == -100] = main_processor.tokenizer.pad_token_id
    ref_str = main_processor.batch_decode(label_ids, skip_special_tokens=True)

    # lowercase & strip
    pred_str = [s.lower().strip() for s in pred_str]
    ref_str = [s.lower().strip() for s in ref_str]

    wer_score = wer_metric.compute(predictions=pred_str, references=ref_str)
    cer_score = cer_metric.compute(predictions=pred_str, references=ref_str)
    return {"wer": wer_score, "cer": cer_score}

# Check for multi-GPU setup
num_gpus = torch.cuda.device_count()
print(f"Number of available GPUs: {num_gpus}")

# Get training configuration
training_config = config['training']

# Adjust batch size and gradient accumulation for multi-GPU
if num_gpus > 1:
    # With multiple GPUs, use the multi-GPU configuration
    gpu_config = training_config['multi_gpu']
    per_device_batch_size = gpu_config['per_device_train_batch_size']
    per_device_eval_batch_size = gpu_config['per_device_eval_batch_size']
    gradient_accumulation_steps = gpu_config['gradient_accumulation_steps']
    print(f"Multi-GPU training detected. Using {num_gpus} GPUs.")
else:
    # Single GPU configuration
    gpu_config = training_config['single_gpu']
    per_device_batch_size = gpu_config['per_device_train_batch_size']
    per_device_eval_batch_size = gpu_config['per_device_eval_batch_size']
    gradient_accumulation_steps = gpu_config['gradient_accumulation_steps']
    print("Single GPU training.")

# Training Arguments
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=per_device_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=training_config['learning_rate'],
    warmup_steps=training_config['warmup_steps'],
    max_steps=training_config['max_steps'],
    gradient_checkpointing=training_config['gradient_checkpointing'],
    fp16=training_config['fp16'],
    eval_strategy=training_config['eval_strategy'],
    per_device_eval_batch_size=per_device_eval_batch_size,
    predict_with_generate=training_config['predict_with_generate'],
    generation_max_length=training_config['generation_max_length'],
    save_steps=training_config['save_steps'],
    eval_steps=training_config['eval_steps'],
    logging_steps=training_config['logging_steps'],
    report_to=training_config['report_to'],
    load_best_model_at_end=training_config['load_best_model_at_end'],
    metric_for_best_model=training_config['metric_for_best_model'],
    greater_is_better=training_config['greater_is_better'],
    push_to_hub=training_config['push_to_hub'],
    save_total_limit=training_config['save_total_limit'],
    # Multi-GPU specific settings
    dataloader_drop_last=training_config['dataloader_drop_last'],
    ddp_find_unused_parameters=training_config['ddp_find_unused_parameters'],
)

# Initialize Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=combined_train_dataset,
    eval_dataset=combined_val_dataset,
    data_collator=data_collator,
    tokenizer=main_processor.feature_extractor,
    compute_metrics=compute_metrics,
)

def evaluate_on_test_sets():
    """Evaluate the model on all test sets from the enabled languages"""
    print("\n" + "="*60)
    print("EVALUATING ON ALL TEST SETS")
    print("="*60)

    results = {}

    for lang in enabled_languages:
        if lang in processed_datasets:
            lang_results = {}

            if lang == "chinese":
                # Chinese has multiple test splits
                if "test_net" in processed_datasets[lang]:
                    print(f"\n***** Evaluating on WenetSpeech Chinese TEST_NET *****")
                    chi_testnet_metrics = trainer.predict(processed_datasets[lang]["test_net"], metric_key_prefix=f"test_{lang}_net")
                    print(f"Chinese TEST_NET WER: {chi_testnet_metrics.metrics[f'test_{lang}_net_wer']*100:.2f}%")
                    print(f"Chinese TEST_NET CER: {chi_testnet_metrics.metrics[f'test_{lang}_net_cer']*100:.2f}%")
                    lang_results["test_net"] = chi_testnet_metrics.metrics

                if "test_meeting" in processed_datasets[lang]:
                    print(f"\n***** Evaluating on WenetSpeech Chinese TEST_MEETING *****")
                    chi_testmeet_metrics = trainer.predict(processed_datasets[lang]["test_meeting"], metric_key_prefix=f"test_{lang}_meeting")
                    print(f"Chinese TEST_MEETING WER: {chi_testmeet_metrics.metrics[f'test_{lang}_meeting_wer']*100:.2f}%")
                    print(f"Chinese TEST_MEETING CER: {chi_testmeet_metrics.metrics[f'test_{lang}_meeting_cer']*100:.2f}%")
                    lang_results["test_meeting"] = chi_testmeet_metrics.metrics
            else:
                # Standard test split
                if "test" in processed_datasets[lang]:
                    print(f"\n***** Evaluating on {lang.title()} test set *****")
                    test_metrics = trainer.predict(processed_datasets[lang]["test"], metric_key_prefix=f"test_{lang}")
                    print(f"{lang.title()} Test WER: {test_metrics.metrics[f'test_{lang}_wer']*100:.2f}%")
                    print(f"{lang.title()} Test CER: {test_metrics.metrics[f'test_{lang}_cer']*100:.2f}%")
                    lang_results["test"] = test_metrics.metrics

            results[lang] = lang_results

    # Summary
    print("\n" + "="*60)
    print("SUMMARY OF ALL TEST RESULTS")
    print("="*60)

    for lang in enabled_languages:
        if lang in results:
            if lang == "chinese":
                if "test_net" in results[lang]:
                    wer = results[lang]["test_net"][f"test_{lang}_net_wer"] * 100
                    cer = results[lang]["test_net"][f"test_{lang}_net_cer"] * 100
                    print(f"Chinese-NET: WER={wer:.2f}% | CER={cer:.2f}%")
                if "test_meeting" in results[lang]:
                    wer = results[lang]["test_meeting"][f"test_{lang}_meeting_wer"] * 100
                    cer = results[lang]["test_meeting"][f"test_{lang}_meeting_cer"] * 100
                    print(f"Chinese-MTG: WER={wer:.2f}% | CER={cer:.2f}%")
            else:
                if "test" in results[lang]:
                    wer = results[lang]["test"][f"test_{lang}_wer"] * 100
                    cer = results[lang]["test"][f"test_{lang}_cer"] * 100
                    print(f"{lang.title():12}: WER={wer:.2f}% | CER={cer:.2f}%")

    return results

if __name__ == "__main__":
    print(f"Total training samples: {len(combined_train_dataset)}")
    print(f"Total validation samples: {len(combined_val_dataset)}")
    print("Starting training...")

    # Fine-tune the model
    trainer.train()

    # Evaluate on all test sets
    evaluate_on_test_sets()
inference.py
ADDED
@@ -0,0 +1,89 @@
#!/usr/bin/env python

# pip install transformers datasets torch soundfile jiwer

from datasets import load_dataset, Audio
from transformers import pipeline, WhisperProcessor
from torch.utils.data import DataLoader
import torch
from jiwer import wer as jiwer_wer
from jiwer import cer as jiwer_cer
import ipdb

# 1. Load the FLEURS Khmer (km_kh) test set, cast to 16 kHz audio
ds = load_dataset("google/fleurs", "km_kh", split="test")
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))

from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline


device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# 2. Load the fine-tuned model and build the ASR pipeline
# model_id = "openai/whisper-large-v3"
model_id = "pengyizhou/whisper-fleurs-km_kh"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
whisper_model = "openai/whisper-large-v3"
processor = WhisperProcessor.from_pretrained(whisper_model, language="khmer")

asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    chunk_length_s=30,
    batch_size=64,
    max_new_tokens=440,
    device=device,
    no_repeat_ngram_size=3,   # Prevent repeating 3-grams
    repetition_penalty=1.0,   # Penalize repetitions (>1.0 reduces repetition)
    length_penalty=1.0,       # Control length preference
    num_beams=1,              # Greedy decoding; increase for beam search
    do_sample=False,          # Disable sampling for deterministic output
    early_stopping=False,     # Only relevant when num_beams > 1
    suppress_tokens=[],
)


# 3. Batch-wise transcription function
def transcribe_batch(batch):
    # `batch["audio"]` is a list of {"array": np.ndarray, ...}
    inputs = [ex["array"] for ex in batch["audio"]]
    outputs = asr(inputs)  # returns a list of dicts with "text"
    # lower-case and strip to normalize for CER
    preds = [out["text"].lower().strip() for out in outputs]
    return {"prediction": preds}

# 4. Map over the dataset in chunks of 64 examples at a time
result = ds.map(
    transcribe_batch,
    batched=True,
    batch_size=64,  # feed 64 audios per map call; the pipeline sub-batches internally
    remove_columns=ds.column_names
)

# ipdb.set_trace()
# 5. Compute corpus-level CER/WER with jiwer
# refs = "\n".join(t.lower().strip() for t in ds["transcription"])
# preds = "\n".join(t for t in result["prediction"])
# score = jiwer_cer(refs, preds)
refs = [t.lower().strip() for t in ds["transcription"]]
preds = [t for t in result["prediction"]]
score_cer = jiwer_cer(refs, preds)
score_wer = jiwer_wer(refs, preds)

print(f"CER on FLEURS km_kh: {score_cer*100:.2f}%")
print(f"WER on FLEURS km_kh: {score_wer*100:.2f}%")

with open("./km_kh_finetune.pred", "w") as pred_results:
    for pred in preds:
        pred_results.write("{}\n".format(pred))

with open("./km_kh.ref", "w") as ref_results:
    for ref in refs:
        ref_results.write("{}\n".format(ref))