#!/usr/bin/env python
# pip install transformers datasets torch soundfile jiwer

from datasets import load_dataset, Audio
from transformers import AutoModelForSpeechSeq2Seq, WhisperProcessor, pipeline
import torch
from jiwer import wer as jiwer_wer
from jiwer import cer as jiwer_cer
import subprocess
import os
# import ipdb  # uncomment together with ipdb.set_trace() below for interactive debugging

# 1. Load the FLEURS Khmer (km_kh) test set and cast the audio column to 16 kHz
ds = load_dataset("google/fleurs", "km_kh", split="test", trust_remote_code=True)
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))

# 2. Load the fine-tuned model and build the ASR pipeline
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "./ft-lid-whisper-fleurs-km_kh-small"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

whisper_model = "openai/whisper-large-v3"
processor = WhisperProcessor.from_pretrained(whisper_model)

asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    chunk_length_s=30,
    batch_size=64,
    max_new_tokens=228,
    device=device,
    num_beams=1,           # greedy decoding (no beam search)
    do_sample=False,       # disable sampling for deterministic output
    early_stopping=False,  # only relevant when num_beams > 1
)

# 3. Batch-wise transcription function
def transcribe_batch(batch):
    # `batch["audio"]` is a list of {"array": np.ndarray, ...}
    inputs = [ex["array"] for ex in batch["audio"]]
    # The pipeline returns a list of dicts with a "text" field
    outputs = asr(inputs, generate_kwargs={"language": "khmer"})
    # Lower-case and strip to normalize before CER/WER scoring
    preds = [out["text"].lower().strip() for out in outputs]
    return {"prediction": preds}

# 4. Map over the dataset in chunks of 64 examples at a time
result = ds.map(
    transcribe_batch,
    batched=True,
    batch_size=64,  # 64 audios per map() call; the pipeline sub-batches internally
    remove_columns=ds.column_names,
)

# ipdb.set_trace()

# 5. Compute corpus-level CER and WER with jiwer
# (earlier attempt, scoring newline-joined strings)
# refs = "\n".join(t.lower().strip() for t in ds["transcription"])
# preds = "\n".join(t for t in result["prediction"])
# score = jiwer_cer(refs, preds)
ids = list(ds["id"])
refs = [t.lower().strip() for t in ds["transcription"]]
preds = list(result["prediction"])
score_cer = jiwer_cer(refs, preds)
score_wer = jiwer_wer(refs, preds)
print(f"CER on FLEURS km_kh: {score_cer * 100:.2f}%")
print(f"WER on FLEURS km_kh: {score_wer * 100:.2f}%")


def add_char_spaces(text):
    """Add spaces between each character for character-level (CER) evaluation."""
    return " ".join(list(text.strip()))


# 6. Write per-utterance "id transcript" files (characters space-separated) for detailed scoring
with open("./km_kh_finetune.pred", "w") as pred_results:
    for key, pred in zip(ids, preds):
        pred_results.write(f"{key} {add_char_spaces(pred)}\n")

with open("./km_kh.ref", "w") as ref_results:
    for key, ref in zip(ids, refs):
        ref_results.write(f"{key} {add_char_spaces(ref)}\n")

# 7. Generate a detailed per-utterance error report with compute-wer.py
print("Generating detailed WER analysis...")

# Check whether compute-wer.py exists
compute_wer_script = "./compute-wer.py"
if not os.path.exists(compute_wer_script):
    # Fall back to other locations (extend this list as needed)
    possible_locations = [
        "./compute-wer.py",
    ]
    for location in possible_locations:
        if os.path.exists(location):
            compute_wer_script = location
            break
    else:
        print(f"Warning: compute-wer.py not found. Tried: {[compute_wer_script] + possible_locations}")
        print("Skipping detailed WER analysis.")
        compute_wer_script = None

if compute_wer_script:
    try:
        # Run compute-wer.py with character-level analysis
        ref_file = "./km_kh.ref"
        hyp_file = "./km_kh_finetune.pred"
        wer_file = "./km_kh.wer"

        cmd = [
            "python", compute_wer_script,
            "--char=1",  # character-level analysis
            "--v=1",     # verbose output
            ref_file,
            hyp_file,
        ]

        print(f"Running: {' '.join(cmd)} > {wer_file}")

        # Run the command and redirect stdout to the .wer file
        with open(wer_file, "w") as wer_output:
            subprocess.run(
                cmd,
                stdout=wer_output,
                stderr=subprocess.PIPE,
                text=True,
                check=True,
            )

        print(f"Character-level error analysis saved to {wer_file}")

        # Optionally, print the first few lines of the report
        if os.path.exists(wer_file):
            print("\nFirst few lines of the analysis:")
            with open(wer_file, "r") as f:
                lines = f.readlines()
            for line in lines[:10]:  # show the first 10 lines
                print(f"  {line.rstrip()}")
            if len(lines) > 10:
                print(f"  ... ({len(lines) - 10} more lines)")

    except subprocess.CalledProcessError as e:
        print(f"Error running compute-wer.py: {e}")
        if e.stderr:
            print(f"Error details: {e.stderr}")
    except Exception as e:
        print(f"Unexpected error: {e}")

print("Inference and CER analysis completed!")