#!/usr/bin/env python

# pip install transformers datasets torch soundfile jiwer

from datasets import load_dataset, Audio
from transformers import AutoModelForSpeechSeq2Seq, WhisperProcessor, pipeline
import torch
from jiwer import wer as jiwer_wer
from jiwer import cer as jiwer_cer

# 1. Load FLEURS Cebuano test set, cast to 16 kHz audio
ds = load_dataset("google/fleurs", "ceb_ph", split="test")
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
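# Each example then looks roughly like this (field names per the google/fleurs
# dataset card; extra fields may vary by datasets version):
#   ds[0]["audio"]          -> {"array": float32 waveform at 16 kHz, "sampling_rate": 16000, ...}
#   ds[0]["transcription"]  -> reference transcript string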

# 2. Load Whisper large-v3 and build the ASR pipeline
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
# Whisper has no Cebuano language token, so Tagalog is used as the closest
# supported language for decoding.
processor = WhisperProcessor.from_pretrained(model_id, language="tagalog")

asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    chunk_length_s=30,
    batch_size=64,
    max_new_tokens=440,
    device=device,
    num_beams=1,                   # greedy decoding (no beam search)
    do_sample=False,               # disable sampling for deterministic output
    early_stopping=False,          # only relevant when beam search is enabled
    suppress_tokens=[],            # disable Whisper's default token suppression list
)
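# Optional sanity check on a single utterance before the full pass (kept
# commented out). "tagalog" matches the processor above; treating Tagalog as
# the closest Whisper-supported language to Cebuano is an assumption.
# sample = ds[0]["audio"]["array"]
# print(asr(sample, generate_kwargs={"language": "tagalog"})["text"])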


# 3. Batch-wise transcription function
def transcribe_batch(batch):
    # `batch["audio"]` is a list of {"array": np.ndarray, ...}
    inputs = [ex["array"] for ex in batch["audio"]]
    # Force Tagalog decoding to stay consistent with the processor above
    outputs = asr(inputs, generate_kwargs={"language": "tagalog"})  # returns a list of dicts with "text"
    # lower-case and strip to normalize for CER
    preds = [ out["text"].lower().strip() for out in outputs ]
    return {"prediction": preds}

# 4. Map over the dataset in batches of 64 examples at a time
result = ds.map(
    transcribe_batch,
    batched=True,
    batch_size=64,              # matches the pipeline batch_size above
    remove_columns=ds.column_names
)
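# result["prediction"] is a list of hypothesis strings aligned index-by-index
# with ds["transcription"], so corpus-level scores can be computed directly.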

# 5. Compute corpus-level WER and CER with jiwer
refs = [t.lower().strip() for t in ds["transcription"]]
preds = [t for t in result["prediction"]]
score_cer = jiwer_cer(refs, preds)
score_wer = jiwer_wer(refs, preds)

print(f"CER on FLEURS ceb_ph: {score_cer*100:.2f}%")
print(f"WER on FLEURS ceb_ph: {score_wer*100:.2f}%")

with open("./ceb_ph_zeroshot.pred", "w") as pred_results:
    for pred in preds:
        pred_results.write("{}\n".format(pred))

with open("./ceb_ph.ref", "w") as ref_results:
    for ref in refs:
        ref_results.write("{}\n".format(ref))
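
# A minimal sketch (assuming the two files above were written by this script)
# for re-scoring later without re-running inference:
#
# from jiwer import wer, cer
# with open("./ceb_ph_zeroshot.pred") as f:
#     preds = [line.rstrip("\n") for line in f]
# with open("./ceb_ph.ref") as f:
#     refs = [line.rstrip("\n") for line in f]
# print(wer(refs, preds), cer(refs, preds))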