#!/usr/bin/env python
# pip install transformers datasets torch soundfile jiwer
import torch
from datasets import load_dataset, Audio
from transformers import AutoModelForSpeechSeq2Seq, WhisperProcessor, pipeline
from jiwer import wer as jiwer_wer
from jiwer import cer as jiwer_cer

# 1. Load the FLEURS Cebuano test set and cast the audio to 16 kHz
ds = load_dataset("google/fleurs", "ceb_ph", split="test")
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))

# 2. Load Whisper large-v3 and build the ASR pipeline
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

# Whisper does not support Cebuano, so we decode zero-shot via Tagalog,
# the closest supported language.
processor = WhisperProcessor.from_pretrained(model_id, language="tagalog")

asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    chunk_length_s=30,  # chunk long audio into 30 s windows
    batch_size=64,      # pipeline-internal batch size
    device=device,
)

# Deterministic greedy decoding (num_beams=1, no sampling)
generate_kwargs = {
    "language": "tagalog",
    "max_new_tokens": 440,
    "num_beams": 1,
    "do_sample": False,
    "suppress_tokens": [],  # do not suppress any tokens
}

# 3. Batch-wise transcription function
def transcribe_batch(batch):
    # `batch["audio"]` is a list of {"array": np.ndarray, ...}
    inputs = [ex["array"] for ex in batch["audio"]]
    # The pipeline returns a list of dicts with a "text" field
    outputs = asr(inputs, generate_kwargs=generate_kwargs)
    # Lower-case and strip to normalize for WER/CER
    preds = [out["text"].lower().strip() for out in outputs]
    return {"prediction": preds}

# 4. Map over the dataset 64 examples at a time
result = ds.map(
    transcribe_batch,
    batched=True,
    batch_size=64,
    remove_columns=ds.column_names,
)

# 5. Compute corpus-level CER and WER with jiwer
refs = [t.lower().strip() for t in ds["transcription"]]
preds = list(result["prediction"])
score_cer = jiwer_cer(refs, preds)
score_wer = jiwer_wer(refs, preds)
print(f"CER on FLEURS ceb_ph: {score_cer * 100:.2f}%")
print(f"WER on FLEURS ceb_ph: {score_wer * 100:.2f}%")

# 6. Write predictions and references to disk
with open("./ceb_ph_zeroshot.pred", "w") as pred_results:
    for pred in preds:
        pred_results.write(f"{pred}\n")

with open("./ceb_ph.ref", "w") as ref_results:
    for ref in refs:
        ref_results.write(f"{ref}\n")
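# ------------------------------------------------------------------
# Optional per-utterance error analysis: a minimal sketch, not part of
# the original evaluation. jiwer's cer also accepts single strings, so
# each (reference, prediction) pair can be scored on its own and the
# worst utterances dumped for inspection. The output filename
# "./ceb_ph_zeroshot.per_utt" is an assumed name for illustration.
scored = sorted(
    ((jiwer_cer(r, p), r, p) for r, p in zip(refs, preds)),
    reverse=True,  # highest-CER (worst) utterances first
)
with open("./ceb_ph_zeroshot.per_utt", "w") as per_utt_results:
    for utt_cer, ref, pred in scored:
        per_utt_results.write(f"{utt_cer:.3f}\tREF: {ref}\tHYP: {pred}\n")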