| | import argparse |
| | import pyarabic.araby as araby |
| | from transformers import pipeline |
| | from transformers.models.whisper.english_normalizer import BasicTextNormalizer |
| | from datasets import load_dataset, Audio |
| | import evaluate |
| | from tqdm.auto import tqdm |
| |
|
# Word error rate metric loaded via `evaluate`; main() calls its compute()
# on the collected references and predictions.
wer_metric = evaluate.load("wer")
| |
|
| |
|
def is_target_text_in_range(ref):
    """Decide whether a reference transcript should be kept for scoring.

    Args:
        ref: Reference transcript string.

    Returns:
        False when the whitespace-stripped transcript is empty or equals the
        "ignore time segment in scoring" placeholder, True otherwise.
    """
    cleaned = ref.strip()
    return cleaned not in ("", "ignore time segment in scoring")
| |
|
| |
|
def get_text(sample):
    """Return the transcript stored in a dataset sample.

    The known transcript column names are probed in a fixed order and the
    value of the first one present in ``sample`` is returned.

    Args:
        sample: Mapping of column names to values for one example.

    Returns:
        The value of the first transcript column found in ``sample``.

    Raises:
        ValueError: If none of the known transcript columns is present. The
            message lists both the accepted names and the columns found.
    """
    # Order defines precedence when a sample carries several of these columns.
    text_columns = (
        "text",
        "sentence",
        "normalized_text",
        "transcript",
        "transcription",
    )
    for column in text_columns:
        if column in sample:
            return sample[column]

    expected = ", ".join(repr(column) for column in text_columns)
    found = ", ".join(str(key) for key in sample.keys())
    raise ValueError(
        f"Expected transcript column of either {expected}. Got sample of "
        f"{found}. Ensure a text column name is present in the dataset."
    )
| |
|
| |
|
# Shared text normalizer instance; main() applies it to both the prediction
# and the reference when --normalise is set.
whisper_norm = BasicTextNormalizer()
| |
|
| |
|
def normalise(batch):
    """Attach the sample's transcript under the ``norm_text`` key.

    Despite the name, no text normalisation happens here: the value returned
    by ``get_text`` is stored unchanged.

    Args:
        batch: A dataset example; it is mutated in place.

    Returns:
        The same ``batch`` object with ``norm_text`` set.
    """
    transcript = get_text(batch)
    batch["norm_text"] = transcript
    return batch
| |
|
| |
|
def remove_diacritics(batch):
    """Store the sample's transcript, stripped of diacritics, as ``norm_text``.

    Args:
        batch: A dataset example; it is mutated in place.

    Returns:
        The same ``batch`` object with ``norm_text`` set to the result of
        ``araby.strip_diacritics`` applied to the transcript.
    """
    transcript = get_text(batch)
    stripped = araby.strip_diacritics(transcript)
    batch["norm_text"] = stripped
    return batch
| |
|
| |
|
def data(dataset):
    """Yield pipeline inputs built from each dataset example.

    Each yielded dict holds the entries of the example's ``audio`` mapping
    plus the example's ``norm_text`` stored under the ``reference`` key.

    Args:
        dataset: Iterable of examples, each with an ``audio`` mapping and a
            ``norm_text`` entry.

    Yields:
        dict: The audio fields merged with ``{"reference": norm_text}``.
    """
    for item in dataset:
        yield {**item["audio"], "reference": item["norm_text"]}
| |
|
| |
|
def main(args):
    """Transcribe a dataset split with an ASR pipeline and print its WER.

    Args:
        args: Parsed command-line namespace. Reads ``model_id``, ``device``,
            ``batch_size``, ``language``, ``dataset``, ``config``, ``split``,
            ``streaming``, ``max_eval_samples``, ``remove_diacritics`` and
            ``normalise``.

    Raises:
        ValueError: If no sample is left to score after filtering.
    """
    whisper_asr = pipeline(
        "automatic-speech-recognition", model=args.model_id, device=args.device
    )

    # Set the decoder prompt for transcription in the requested language.
    whisper_asr.model.config.forced_decoder_ids = (
        whisper_asr.tokenizer.get_decoder_prompt_ids(
            language=args.language, task="transcribe"
        )
    )

    # NOTE(review): newer 🤗 Datasets releases appear to prefer `token=` over
    # `use_auth_token=` — confirm against the installed version before changing.
    dataset = load_dataset(
        args.dataset,
        args.config,
        split=args.split,
        streaming=args.streaming,
        use_auth_token=True,
    )

    # Truncate only when a limit was requested; the CLI default is None.
    if args.max_eval_samples is not None:
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            # Clamp so a limit larger than the split cannot index out of range.
            limit = min(args.max_eval_samples, len(dataset))
            dataset = dataset.select(range(limit))

    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = dataset.map(normalise)
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])

    predictions = []
    references = []

    # The progress bar is shown only for non-streaming runs, where
    # len(dataset) is taken; the context manager closes it even if
    # inference raises.
    total = None if args.streaming else len(dataset)
    with tqdm(total=total, disable=args.streaming) as pbar:
        for out in whisper_asr(data(dataset), batch_size=args.batch_size):
            pred = out["text"]
            # The "reference" field is indexed at 0 — presumably the pipeline
            # returns it wrapped in a sequence; TODO confirm.
            true = out["reference"][0]

            if args.remove_diacritics:
                pred = araby.strip_diacritics(pred)
                true = araby.strip_diacritics(true)

            if args.normalise:
                pred = whisper_norm(pred)
                true = whisper_norm(true)

            predictions.append(pred)
            references.append(true)
            pbar.update(1)

    # Fail with a clear message instead of handing empty lists to the metric.
    if not references:
        raise ValueError(
            "No samples left to evaluate after filtering; cannot compute WER."
        )

    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    print("WER:", wer)
| |
|
| |
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Each entry is (flag, keyword options for add_argument); the flags are
    # registered in list order.
    cli_options = [
        (
            "--model_id",
            {
                "type": str,
                "required": True,
                "help": "Model identifier. Should be loadable with 🤗 Transformers",
            },
        ),
        (
            "--dataset",
            {
                "type": str,
                "default": "mozilla-foundation/common_voice_11_0",
                "help": "Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
            },
        ),
        (
            "--config",
            {
                "type": str,
                "required": True,
                "help": "Config of the dataset. *E.g.* `'en'` for the English split of Common Voice",
            },
        ),
        (
            "--split",
            {
                "type": str,
                "default": "test",
                "help": "Split of the dataset. *E.g.* `'test'`",
            },
        ),
        (
            "--device",
            {
                "type": int,
                "default": -1,
                "help": "The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
            },
        ),
        (
            "--batch_size",
            {
                "type": int,
                "default": 16,
                "help": "Number of samples to go through each streamed batch.",
            },
        ),
        (
            "--max_eval_samples",
            {
                "type": int,
                "default": None,
                "help": "Number of samples to be evaluated. Put a lower number e.g. 64 for testing this script.",
            },
        ),
        (
            "--streaming",
            {
                "default": False,
                "action": "store_true",
                "help": "Choose whether you'd like to download the entire dataset or stream it during the evaluation.",
            },
        ),
        (
            "--language",
            {
                "type": str,
                "required": True,
                "help": "Two letter language code for the transcription language, e.g. use 'en' for English.",
            },
        ),
        (
            "--remove_diacritics",
            {
                "default": False,
                "action": "store_true",
                "help": "Choose whether you'd like remove_diacritics",
            },
        ),
        (
            "--normalise",
            {
                "default": False,
                "action": "store_true",
                "help": "Choose whether you'd like whisper norm",
            },
        ),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    main(args)
| |
|