pengyizhou committed
Commit 73a3144 · 1 Parent(s): 3ba76d4

update inference

README.md CHANGED
@@ -140,10 +140,17 @@ Key dependencies:
140
  - librosa (for audio processing)
141
  - evaluate (for metrics)
142
143
  ## Evaluation Results
144
  | Language | Metric | Error Rate |
145
  |-------------|:------:|-----------:|
146
- | Khmer | CER | 33.18% |
 
147
 
148
 
149
 
 
140
  - librosa (for audio processing)
141
  - evaluate (for metrics)
142
 
143
+ ## Zero-shot Results
144
+ | LID | Metric | Error Rate |
145
+ |-------------|:------:|-----------:|
146
+ | Khmer | CER | 86.77% |
147
+ | Auto | CER | 86.39% |
148
+
149
  ## Evaluation Results
150
  | Language | Metric | Error Rate |
151
  |-------------|:------:|-----------:|
152
+ | Khmer | CER | 55.66% |
153
+ | Auto | CER | 55.77% |
154
 
155
 
156
 
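For reference, the fine-tuned rows above come from the inference scripts added in this commit. Below is a minimal sketch of that evaluation path, assuming the local checkpoint directory named in `config.yaml` (`./ft-lid-whisper-fleurs-km_kh-small`) and the FLEURS `km_kh` test split; see `inference/inference-finetune-lid.py` for the full script.

```python
import torch
from datasets import load_dataset, Audio
from transformers import AutoModelForSpeechSeq2Seq, WhisperProcessor, pipeline
from jiwer import cer

device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Fine-tuned checkpoint directory from config.yaml (assumed to exist locally)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "./ft-lid-whisper-fleurs-km_kh-small",
    torch_dtype=dtype, low_cpu_mem_usage=True, use_safetensors=True,
).to(device)
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")

asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=dtype,
    chunk_length_s=30,
    batch_size=16,        # adjust to available GPU memory
    max_new_tokens=225,
    device=device,
)

ds = load_dataset("google/fleurs", "km_kh", split="test", trust_remote_code=True)
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))

def transcribe_batch(batch):
    # Force Khmer decoding; drop the "language" entry to reproduce the "Auto" rows
    outputs = asr([ex["array"] for ex in batch["audio"]],
                  generate_kwargs={"language": "khmer"})
    return {"prediction": [out["text"].lower().strip() for out in outputs]}

result = ds.map(transcribe_batch, batched=True, batch_size=16,
                remove_columns=ds.column_names)
refs = [t.lower().strip() for t in ds["transcription"]]
print(f"CER on FLEURS km_kh: {cer(refs, result['prediction']) * 100:.2f}%")
```

Pointing `from_pretrained` at `openai/whisper-large-v3` instead of the local checkpoint reproduces the zero-shot rows in the same way.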
config.yaml CHANGED
@@ -4,11 +4,11 @@
4
  # Model Configuration
5
  model:
6
  checkpoint: "openai/whisper-large-v3"
7
- max_target_length: 448
8
 
9
  # Output Configuration
10
  output:
11
- output_dir: "./whisper-fleurs-km_kh-small"
12
 
13
  # Environment Configuration
14
  environment:
@@ -60,6 +60,10 @@ training:
60
  per_device_eval_batch_size: 16
61
  gradient_accumulation_steps: 1
62
 
 
 
 
 
63
 
64
  # Optimization settings
65
  gradient_checkpointing: true
@@ -86,7 +90,8 @@ training:
86
  - "tensorboard"
87
 
88
  # Hub settings
89
- push_to_hub: false
 
90
 
91
  # Multi-GPU specific settings
92
  dataloader_drop_last: true
 
4
  # Model Configuration
5
  model:
6
  checkpoint: "openai/whisper-large-v3"
7
+ max_target_length: 446
8
 
9
  # Output Configuration
10
  output:
11
+ output_dir: "./ft-lid-whisper-fleurs-km_kh-small"
12
 
13
  # Environment Configuration
14
  environment:
 
60
  per_device_eval_batch_size: 16
61
  gradient_accumulation_steps: 1
62
 
63
+ multi_gpu:
64
+ per_device_train_batch_size: 4
65
+ per_device_eval_batch_size: 4
66
+ gradient_accumulation_steps: 1
67
 
68
  # Optimization settings
69
  gradient_checkpointing: true
 
90
  - "tensorboard"
91
 
92
  # Hub settings
93
+ push_to_hub: true
94
+ hub_private_repo: false # Not pushing to a private repo for Khmer
95
 
96
  # Multi-GPU specific settings
97
  dataloader_drop_last: true
finetune.py CHANGED
@@ -47,6 +47,7 @@ import io
47
  import yaml
48
  import argparse
49
  from itertools import chain
 
50
 
51
  # Load configuration from YAML file
52
  def load_config(config_path):
@@ -132,6 +133,19 @@ class WhisperOnTheFlyDataset(TorchDataset):
132
  else: # english, chinese
133
  text = item["text"]
134
 
135
  # Tokenize with appropriate processor
136
  if lang == "cebuano":
137
  labels = self.processors["cebuano"].tokenizer(
@@ -177,7 +191,8 @@ class WhisperOnTheFlyDataset(TorchDataset):
177
  return {
178
  "input_features": inputs.input_features.squeeze(0),
179
  "labels": labels.input_ids.squeeze(0),
180
- "language": lang
 
181
  }
182
 
183
  def _process_audio(self, audio_sample):
@@ -216,16 +231,53 @@ class DataCollatorSpeechSeq2SeqWithPadding:
216
  label_features = [{"input_ids": feature["labels"]} for feature in features]
217
  # pad the labels to max length
218
  labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
 
 
 
219
 
220
- # replace padding with -100 to ignore loss correctly
221
- labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
222
 
223
  # if bos token is appended in previous tokenization step,
224
  # cut bos token here as it's appended later anyway
225
- if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
226
- labels = labels[:, 1:]
227
 
228
- batch["labels"] = labels
 
229
 
230
  return batch
231
 
@@ -300,9 +352,9 @@ def load_chinese_dataset(dataset_config):
300
  """Load Chinese dataset with multiple test splits"""
301
  print("Loading Chinese...")
302
  wenet_train = load_dataset(dataset_config['train_dataset'], streaming=dataset_config['streaming'])
303
- wenet_valid = load_dataset(dataset_config['validation_dataset'], dataset_config['validation_config'], split="validation", streaming=dataset_config['streaming'])
304
- wenet_testnet = load_dataset(dataset_config['test_net_dataset'], dataset_config['test_net_config'], split="test", streaming=dataset_config['streaming'])
305
- wenet_testmeeting = load_dataset(dataset_config['test_meeting_dataset'], dataset_config['test_meeting_config'], split="test", streaming=dataset_config['streaming'])
306
  return {
307
  "train": wenet_train["train"],
308
  "validation": wenet_valid,
@@ -352,12 +404,16 @@ else:
352
  model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
353
 
354
  # Multi-GPU handling
 
 
 
 
355
  if torch.cuda.device_count() > 1:
356
  print(f"Using {torch.cuda.device_count()} GPUs for training")
357
  # The model will be automatically distributed by the Trainer
358
- model.to(device)
359
  else:
360
- model.to(device)
361
 
362
 
363
 
@@ -519,10 +575,13 @@ def compute_metrics(pred):
519
  Compute WER and CER metrics for predictions
520
  """
521
  pred_ids = pred.predictions
 
522
  pred_str = main_processor.batch_decode(pred_ids, skip_special_tokens=True)
523
 
524
  label_ids = pred.label_ids
 
525
  label_ids[label_ids == -100] = main_processor.tokenizer.pad_token_id
 
526
  ref_str = main_processor.batch_decode(label_ids, skip_special_tokens=True)
527
 
528
  # lowercase & strip
@@ -531,7 +590,19 @@ def compute_metrics(pred):
531
 
532
  wer_score = wer_metric.compute(predictions=pred_str, references=ref_str)
533
  cer_score = cer_metric.compute(predictions=pred_str, references=ref_str)
534
- return {"wer": wer_score, "cer": cer_score}
535
 
536
  # Check for multi-GPU setup
537
  num_gpus = torch.cuda.device_count()
@@ -578,6 +649,7 @@ training_args = Seq2SeqTrainingArguments(
578
  metric_for_best_model=training_config['metric_for_best_model'],
579
  greater_is_better=training_config['greater_is_better'],
580
  push_to_hub=training_config['push_to_hub'],
 
581
  save_total_limit=training_config['save_total_limit'],
582
  # Multi-GPU specific settings
583
  dataloader_drop_last=training_config['dataloader_drop_last'],
@@ -603,22 +675,50 @@ def evaluate_on_test_sets():
603
 
604
  results = {}
605
 
606
  for lang in enabled_languages:
607
  if lang in processed_datasets:
608
  lang_results = {}
609
 
610
  if lang == "chinese":
611
  # Chinese has multiple test splits
612
  if "test_net" in processed_datasets[lang]:
613
  print(f"\n***** Evaluating on WenetSpeech Chinese TEST_NET *****")
614
- chi_testnet_metrics = trainer.predict(processed_datasets[lang]["test_net"], metric_key_prefix=f"test_{lang}_net")
 
 
 
 
615
  print(f"Chinese TEST_NET WER: {chi_testnet_metrics.metrics[f'test_{lang}_net_wer']*100:.2f}%")
616
  print(f"Chinese TEST_NET CER: {chi_testnet_metrics.metrics[f'test_{lang}_net_cer']*100:.2f}%")
617
  lang_results["test_net"] = chi_testnet_metrics.metrics
618
 
619
  if "test_meeting" in processed_datasets[lang]:
620
  print(f"\n***** Evaluating on WenetSpeech Chinese TEST_MEETING *****")
621
- chi_testmeet_metrics = trainer.predict(processed_datasets[lang]["test_meeting"], metric_key_prefix=f"test_{lang}_meeting")
 
 
 
 
622
  print(f"Chinese TEST_MEETING WER: {chi_testmeet_metrics.metrics[f'test_{lang}_meeting_wer']*100:.2f}%")
623
  print(f"Chinese TEST_MEETING CER: {chi_testmeet_metrics.metrics[f'test_{lang}_meeting_cer']*100:.2f}%")
624
  lang_results["test_meeting"] = chi_testmeet_metrics.metrics
@@ -626,7 +726,11 @@ def evaluate_on_test_sets():
626
  # Standard test split
627
  if "test" in processed_datasets[lang]:
628
  print(f"\n***** Evaluating on {lang.title()} test set *****")
629
- test_metrics = trainer.predict(processed_datasets[lang]["test"], metric_key_prefix=f"test_{lang}")
 
 
 
 
630
  print(f"{lang.title()} Test WER: {test_metrics.metrics[f'test_{lang}_wer']*100:.2f}%")
631
  print(f"{lang.title()} Test CER: {test_metrics.metrics[f'test_{lang}_cer']*100:.2f}%")
632
  lang_results["test"] = test_metrics.metrics
@@ -666,7 +770,7 @@ if __name__ == "__main__":
666
  trainer.train()
667
 
668
  # Evaluate on all test sets
669
- evaluate_on_test_sets()
670
 
671
 
672
 
 
47
  import yaml
48
  import argparse
49
  from itertools import chain
50
+ import torch.distributed as dist
51
 
52
  # Load configuration from YAML file
53
  def load_config(config_path):
 
133
  else: # english, chinese
134
  text = item["text"]
135
 
136
+ # Map language to Whisper language token ID
137
+ lang_id_map = {
138
+ "english": 50259, # <|en|>
139
+ "chinese": 50260, # <|zh|>
140
+ "indonesian": 50275, # <|id|>
141
+ "malay": 50282, # <|ms|>
142
+ "khmer": 50323, # <|km|>
143
+ "cebuano": 50348, # <|tl|> (using Tagalog as fallback for Cebuano)
144
+ }
145
+
146
+ # Get language token ID
147
+ lang_token_id = lang_id_map.get(lang)
148
+
149
  # Tokenize with appropriate processor
150
  if lang == "cebuano":
151
  labels = self.processors["cebuano"].tokenizer(
 
191
  return {
192
  "input_features": inputs.input_features.squeeze(0),
193
  "labels": labels.input_ids.squeeze(0),
194
+ "language": lang,
195
+ "language_token_id": lang_token_id
196
  }
197
 
198
  def _process_audio(self, audio_sample):
 
231
  label_features = [{"input_ids": feature["labels"]} for feature in features]
232
  # pad the labels to max length
233
  labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
234
+
235
+ # Get original labels before modification
236
+ labels = labels_batch["input_ids"]
237
 
238
+ # Task ID is fixed for transcription (50360)
239
+ task_token_id = 50360 # transcribe task
240
+
241
+ # Create a tensor to store new labels with language and task tokens prepended
242
+ batch_size = labels.size(0)
243
+ seq_length = labels.size(1)
244
+ # Add 2 tokens (lang token and task token) at the beginning
245
+ new_labels = torch.full((batch_size, seq_length + 2), self.processor.tokenizer.pad_token_id, dtype=labels.dtype, device=labels.device)
246
+
247
+ # Add the language token and task token at the beginning for each sample
248
+ for i, feature in enumerate(features):
249
+ # The SOT token (50258) is not prepended here; it is kept only as the commented-out line below
250
+ # new_labels[i, 0] = 50258 # SOT token
251
+
252
+ # Add language token as the first token if available
253
+ if "language_token_id" in feature and feature["language_token_id"] is not None:
254
+ new_labels[i, 0] = feature["language_token_id"]
255
+
256
+ # Add task token as the second token
257
+ new_labels[i, 1] = task_token_id
258
+
259
+ # Copy the original label tokens after the special tokens
260
+ token_length = min(seq_length, labels.size(1))
261
+ new_labels[i, 2:2+token_length] = labels[i, :token_length]
262
+
263
+ # Create new attention mask for padded sequences
264
+ new_attention_mask = torch.zeros_like(new_labels, dtype=torch.long)
265
+ for i in range(batch_size):
266
+ # Find the last non-padding token in the original sequence
267
+ orig_seq_len = (labels[i] != self.processor.tokenizer.pad_token_id).sum().item()
268
+ # Set attention mask to 1 for all tokens up to the end of the sequence + 2 special tokens
269
+ new_attention_mask[i, :orig_seq_len+2] = 1
270
+
271
+ # Replace padding with -100 to ignore loss correctly
272
+ new_labels = new_labels.masked_fill(new_attention_mask.ne(1), -100)
273
 
274
  # if bos token is appended in previous tokenization step,
275
  # cut bos token here as it's append later anyways
276
+ if (new_labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
277
+ new_labels = new_labels[:, 1:]
278
 
279
+ batch["labels"] = new_labels
280
+ batch["attention_mask"] = new_attention_mask
281
 
282
  return batch
283
 
 
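To make the collator change concrete: after padding, each label row now starts with the Whisper language token followed by the transcribe task token (50360), and every padding position is replaced with -100 so it is ignored by the loss. The toy sketch below illustrates that layout with plain torch; the token IDs are the hardcoded values from `lang_id_map` above, and the pad ID 50257 is an assumption for the illustration (it can be checked against `processor.tokenizer.pad_token_id`).

```python
import torch

pad_id = 50257   # assumed Whisper pad/eot id, for illustration only
task_id = 50360  # <|transcribe|>
lang_id = 50323  # <|km|>

# Toy padded labels for a batch of two Khmer utterances
labels = torch.tensor([[101, 102, 103, pad_id],
                       [201, 202, pad_id, pad_id]])

bsz, seq = labels.shape
new_labels = torch.full((bsz, seq + 2), pad_id, dtype=labels.dtype)
new_labels[:, 0] = lang_id   # language token first
new_labels[:, 1] = task_id   # task token second
new_labels[:, 2:] = labels   # original tokens follow the two special tokens

# Attention is valid up to (original length + 2 prepended tokens); the rest becomes -100
mask = torch.zeros_like(new_labels)
for i in range(bsz):
    orig_len = int((labels[i] != pad_id).sum())
    mask[i, :orig_len + 2] = 1
new_labels = new_labels.masked_fill(mask.ne(1), -100)

print(new_labels)
# tensor([[50323, 50360,   101,   102,   103,  -100],
#         [50323, 50360,   201,   202,  -100,  -100]])
```

As in the collator, the SOT token is not prepended here; the trailing check against `decoder_start_token_id` removes a leading SOT only if the tokenizer already inserted one.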
352
  """Load Chinese dataset with multiple test splits"""
353
  print("Loading Chinese...")
354
  wenet_train = load_dataset(dataset_config['train_dataset'], streaming=dataset_config['streaming'])
355
+ wenet_valid = load_dataset(dataset_config['validation_dataset'], dataset_config['validation_config'], split="validation", streaming=dataset_config['streaming'], trust_remote_code=dataset_config['trust_remote_code'])
356
+ wenet_testnet = load_dataset(dataset_config['test_net_dataset'], dataset_config['test_net_config'], split="test", streaming=dataset_config['streaming'], trust_remote_code=dataset_config['trust_remote_code'])
357
+ wenet_testmeeting = load_dataset(dataset_config['test_meeting_dataset'], dataset_config['test_meeting_config'], split="test", streaming=dataset_config['streaming'], trust_remote_code=dataset_config['trust_remote_code'])
358
  return {
359
  "train": wenet_train["train"],
360
  "validation": wenet_valid,
 
404
  model = WhisperForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
405
 
406
  # Multi-GPU handling
407
+ local_rank = int(os.environ["LOCAL_RANK"])
408
+ torch.cuda.set_device(local_rank)
409
+ print(f"Using GPU {local_rank} for training")
410
+ dist.init_process_group(backend="nccl")
411
  if torch.cuda.device_count() > 1:
412
  print(f"Using {torch.cuda.device_count()} GPUs for training")
413
  # The model will be automatically distributed by the Trainer
414
+ model.to(torch.device("cuda", local_rank))
415
  else:
416
+ model.to(torch.device("cuda", local_rank))
417
 
418
 
419
 
 
575
  Compute WER and CER metrics for predictions
576
  """
577
  pred_ids = pred.predictions
578
+ # Decode predictions, skipping special tokens
579
  pred_str = main_processor.batch_decode(pred_ids, skip_special_tokens=True)
580
 
581
  label_ids = pred.label_ids
582
+ # Replace -100 with pad token ID for decoding
583
  label_ids[label_ids == -100] = main_processor.tokenizer.pad_token_id
584
+ # Decode reference texts, also skipping special tokens
585
  ref_str = main_processor.batch_decode(label_ids, skip_special_tokens=True)
586
 
587
  # lowercase & strip
 
590
 
591
  wer_score = wer_metric.compute(predictions=pred_str, references=ref_str)
592
  cer_score = cer_metric.compute(predictions=pred_str, references=ref_str)
593
+
594
+ # Combine metrics
595
+ metrics = {"wer": wer_score, "cer": cer_score}
596
+
597
+ # Log example predictions
598
+ if len(pred_str) > 0:
599
+ num_examples = min(3, len(pred_str))
600
+ for i in range(num_examples):
601
+ print(f"Example {i}:")
602
+ print(f" Reference: {ref_str[i]}")
603
+ print(f" Prediction: {pred_str[i]}")
604
+
605
+ return metrics
606
 
607
  # Check for multi-GPU setup
608
  num_gpus = torch.cuda.device_count()
 
649
  metric_for_best_model=training_config['metric_for_best_model'],
650
  greater_is_better=training_config['greater_is_better'],
651
  push_to_hub=training_config['push_to_hub'],
652
+ hub_private_repo=training_config['hub_private_repo'], # Whether the Hub repo is private (false here, so the Khmer model is public)
653
  save_total_limit=training_config['save_total_limit'],
654
  # Multi-GPU specific settings
655
  dataloader_drop_last=training_config['dataloader_drop_last'],
 
675
 
676
  results = {}
677
 
678
+ # Define language-specific generation parameters
679
+ lang_id_map = {
680
+ "english": 50259, # <|en|>
681
+ "chinese": 50260, # <|zh|>
682
+ "indonesian": 50275, # <|id|>
683
+ "malay": 50282, # <|ms|>
684
+ "khmer": 50323, # <|km|>
685
+ "cebuano": 50348, # <|tl|> (using Tagalog as fallback for Cebuano)
686
+ }
687
+
688
  for lang in enabled_languages:
689
  if lang in processed_datasets:
690
  lang_results = {}
691
 
692
+ # Set language-specific generation parameters
693
+ lang_token_id = lang_id_map.get(lang)
694
+ task_token_id = 50360 # transcribe task
695
+
696
+ # Define forced decoder IDs for generation if language is supported
697
+ forced_decoder_ids = None
698
+ if lang_token_id:
699
+ forced_decoder_ids = [[1, lang_token_id], [2, task_token_id]]
700
+ print(f"Using forced_decoder_ids for {lang}: {forced_decoder_ids}")
701
+
702
  if lang == "chinese":
703
  # Chinese has multiple test splits
704
  if "test_net" in processed_datasets[lang]:
705
  print(f"\n***** Evaluating on WenetSpeech Chinese TEST_NET *****")
706
+ chi_testnet_metrics = trainer.predict(
707
+ processed_datasets[lang]["test_net"],
708
+ metric_key_prefix=f"test_{lang}_net",
709
+ forced_decoder_ids=forced_decoder_ids
710
+ )
711
  print(f"Chinese TEST_NET WER: {chi_testnet_metrics.metrics[f'test_{lang}_net_wer']*100:.2f}%")
712
  print(f"Chinese TEST_NET CER: {chi_testnet_metrics.metrics[f'test_{lang}_net_cer']*100:.2f}%")
713
  lang_results["test_net"] = chi_testnet_metrics.metrics
714
 
715
  if "test_meeting" in processed_datasets[lang]:
716
  print(f"\n***** Evaluating on WenetSpeech Chinese TEST_MEETING *****")
717
+ chi_testmeet_metrics = trainer.predict(
718
+ processed_datasets[lang]["test_meeting"],
719
+ metric_key_prefix=f"test_{lang}_meeting",
720
+ forced_decoder_ids=forced_decoder_ids
721
+ )
722
  print(f"Chinese TEST_MEETING WER: {chi_testmeet_metrics.metrics[f'test_{lang}_meeting_wer']*100:.2f}%")
723
  print(f"Chinese TEST_MEETING CER: {chi_testmeet_metrics.metrics[f'test_{lang}_meeting_cer']*100:.2f}%")
724
  lang_results["test_meeting"] = chi_testmeet_metrics.metrics
 
726
  # Standard test split
727
  if "test" in processed_datasets[lang]:
728
  print(f"\n***** Evaluating on {lang.title()} test set *****")
729
+ test_metrics = trainer.predict(
730
+ processed_datasets[lang]["test"],
731
+ metric_key_prefix=f"test_{lang}",
732
+ forced_decoder_ids=forced_decoder_ids
733
+ )
734
  print(f"{lang.title()} Test WER: {test_metrics.metrics[f'test_{lang}_wer']*100:.2f}%")
735
  print(f"{lang.title()} Test CER: {test_metrics.metrics[f'test_{lang}_cer']*100:.2f}%")
736
  lang_results["test"] = test_metrics.metrics
 
770
  trainer.train()
771
 
772
  # Evaluate on all test sets
773
+ # evaluate_on_test_sets()
774
 
775
 
776
 
inference.py CHANGED
@@ -22,14 +22,14 @@ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
22
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
23
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
24
 
25
- model_id = "./whisper-fleurs-km_kh-small"
26
 
27
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
28
  model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
29
  )
30
  model.to(device)
31
  whisper_model = "openai/whisper-large-v3"
32
- processor = WhisperProcessor.from_pretrained(whisper_model, language="khmer")
33
 
34
  asr = pipeline(
35
  "automatic-speech-recognition",
@@ -44,7 +44,6 @@ asr = pipeline(
44
  num_beams=1, # Greedy decoding; raise num_beams to enable beam search
45
  do_sample=False, # Disable sampling for deterministic output
46
  early_stopping=False, # Early stopping only applies when beam search is used
47
- suppress_tokens=[],
48
  )
49
 
50
 
@@ -52,7 +51,7 @@ asr = pipeline(
52
  def transcribe_batch(batch):
53
  # `batch["audio"]` is a list of {"array": np.ndarray, ...}
54
  inputs = [ ex["array"] for ex in batch["audio"] ]
55
- outputs = asr(inputs) # returns a list of dicts with "text"
56
  # lower-case and strip to normalize for CER
57
  preds = [ out["text"].lower().strip() for out in outputs ]
58
  return {"prediction": preds}
 
22
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
23
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
24
 
25
+ model_id = "./ft-lid-whisper-fleurs-km_kh-small"
26
 
27
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
28
  model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
29
  )
30
  model.to(device)
31
  whisper_model = "openai/whisper-large-v3"
32
+ processor = WhisperProcessor.from_pretrained(whisper_model)
33
 
34
  asr = pipeline(
35
  "automatic-speech-recognition",
 
44
  num_beams=1, # Greedy decoding; raise num_beams to enable beam search
45
  do_sample=False, # Disable sampling for deterministic output
46
  early_stopping=False, # Early stopping only applies when beam search is used
 
47
  )
48
 
49
 
 
51
  def transcribe_batch(batch):
52
  # `batch["audio"]` is a list of {"array": np.ndarray, ...}
53
  inputs = [ ex["array"] for ex in batch["audio"] ]
54
+ outputs = asr(inputs, generate_kwargs={"language": "khmer"}) # returns a list of dicts with "text"
55
  # lower-case and strip to normalize for CER
56
  preds = [ out["text"].lower().strip() for out in outputs ]
57
  return {"prediction": preds}
inference/compute-wer.py ADDED
@@ -0,0 +1,565 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import re, sys, unicodedata
5
+ import codecs
6
+
7
+ remove_tag = True
8
+ spacelist = [' ', '\t', '\r', '\n']
9
+ puncts = [
10
+ '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
11
+ '《', '》', '(', ')', '(', ')', '[', ']', '【', '】', '{', '}', '〔', '〕',
12
+ '⟨', '⟩', '《', '》'
13
+ ]
14
+
15
+
16
+ def characterize(string):
17
+ res = []
18
+ i = 0
19
+ while i < len(string):
20
+ char = string[i]
21
+ if char in puncts:
22
+ i += 1
23
+ continue
24
+ cat1 = unicodedata.category(char)
25
+ #https://unicodebook.readthedocs.io/unicode.html#unicode-categories
26
+ if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
27
+ i += 1
28
+ continue
29
+ if cat1 == 'Lo': # letter-other
30
+ res.append(char)
31
+ i += 1
32
+ else:
33
+ # some input looks like: <unk><noise>, we want to separate it to two words.
34
+ sep = ' '
35
+ if char == '<': sep = '>'
36
+ j = i + 1
37
+ while j < len(string):
38
+ c = string[j]
39
+ if ord(c) >= 128 or (c in spacelist) or (c == sep):
40
+ break
41
+ j += 1
42
+ if j < len(string) and string[j] == '>':
43
+ j += 1
44
+ res.append(string[i:j])
45
+ i = j
46
+ return res
47
+
48
+
49
+ def stripoff_tags(x):
50
+ if not x: return ''
51
+ chars = []
52
+ i = 0
53
+ T = len(x)
54
+ while i < T:
55
+ if x[i] == '<':
56
+ while i < T and x[i] != '>':
57
+ i += 1
58
+ i += 1
59
+ else:
60
+ chars.append(x[i])
61
+ i += 1
62
+ return ''.join(chars)
63
+
64
+
65
+ def normalize(sentence, ignore_words, cs, split=None):
66
+ """ sentence, ignore_words are both in unicode
67
+ """
68
+ new_sentence = []
69
+ for token in sentence:
70
+ x = token
71
+ if not cs:
72
+ x = x.upper()
73
+ if x in ignore_words:
74
+ continue
75
+ if remove_tag:
76
+ x = stripoff_tags(x)
77
+ x = re.sub(r'[.,!?;:()\[\]{}<>""„""«»‹›\/\\|@#$%^&*_=+~`-]', '', x)
78
+ # Skip tokens containing any digits
79
+ if re.search(r'\d', x):
80
+ continue
81
+ if not x:
82
+ continue
83
+ if split and x in split:
84
+ new_sentence += split[x]
85
+ else:
86
+ new_sentence.append(x)
87
+ return new_sentence
88
+
89
+
90
+ class Calculator:
91
+
92
+ def __init__(self):
93
+ self.data = {}
94
+ self.space = []
95
+ self.cost = {}
96
+ self.cost['cor'] = 0
97
+ self.cost['sub'] = 1
98
+ self.cost['del'] = 1
99
+ self.cost['ins'] = 1
100
+
101
+ def calculate(self, lab, rec):
102
+ # Initialization
103
+ lab.insert(0, '')
104
+ rec.insert(0, '')
105
+ while len(self.space) < len(lab):
106
+ self.space.append([])
107
+ for row in self.space:
108
+ for element in row:
109
+ element['dist'] = 0
110
+ element['error'] = 'non'
111
+ while len(row) < len(rec):
112
+ row.append({'dist': 0, 'error': 'non'})
113
+ for i in range(len(lab)):
114
+ self.space[i][0]['dist'] = i
115
+ self.space[i][0]['error'] = 'del'
116
+ for j in range(len(rec)):
117
+ self.space[0][j]['dist'] = j
118
+ self.space[0][j]['error'] = 'ins'
119
+ self.space[0][0]['error'] = 'non'
120
+ for token in lab:
121
+ if token not in self.data and len(token) > 0:
122
+ self.data[token] = {
123
+ 'all': 0,
124
+ 'cor': 0,
125
+ 'sub': 0,
126
+ 'ins': 0,
127
+ 'del': 0
128
+ }
129
+ for token in rec:
130
+ if token not in self.data and len(token) > 0:
131
+ self.data[token] = {
132
+ 'all': 0,
133
+ 'cor': 0,
134
+ 'sub': 0,
135
+ 'ins': 0,
136
+ 'del': 0
137
+ }
138
+ # Computing edit distance
139
+ for i, lab_token in enumerate(lab):
140
+ for j, rec_token in enumerate(rec):
141
+ if i == 0 or j == 0:
142
+ continue
143
+ min_dist = sys.maxsize
144
+ min_error = 'none'
145
+ dist = self.space[i - 1][j]['dist'] + self.cost['del']
146
+ error = 'del'
147
+ if dist < min_dist:
148
+ min_dist = dist
149
+ min_error = error
150
+ dist = self.space[i][j - 1]['dist'] + self.cost['ins']
151
+ error = 'ins'
152
+ if dist < min_dist:
153
+ min_dist = dist
154
+ min_error = error
155
+ if lab_token == rec_token:
156
+ dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
157
+ error = 'cor'
158
+ else:
159
+ dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
160
+ error = 'sub'
161
+ if dist < min_dist:
162
+ min_dist = dist
163
+ min_error = error
164
+ self.space[i][j]['dist'] = min_dist
165
+ self.space[i][j]['error'] = min_error
166
+ # Tracing back
167
+ result = {
168
+ 'lab': [],
169
+ 'rec': [],
170
+ 'all': 0,
171
+ 'cor': 0,
172
+ 'sub': 0,
173
+ 'ins': 0,
174
+ 'del': 0
175
+ }
176
+ i = len(lab) - 1
177
+ j = len(rec) - 1
178
+ while True:
179
+ if self.space[i][j]['error'] == 'cor': # correct
180
+ if len(lab[i]) > 0:
181
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
182
+ self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
183
+ result['all'] = result['all'] + 1
184
+ result['cor'] = result['cor'] + 1
185
+ result['lab'].insert(0, lab[i])
186
+ result['rec'].insert(0, rec[j])
187
+ i = i - 1
188
+ j = j - 1
189
+ elif self.space[i][j]['error'] == 'sub': # substitution
190
+ if len(lab[i]) > 0:
191
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
192
+ self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
193
+ result['all'] = result['all'] + 1
194
+ result['sub'] = result['sub'] + 1
195
+ result['lab'].insert(0, lab[i])
196
+ result['rec'].insert(0, rec[j])
197
+ i = i - 1
198
+ j = j - 1
199
+ elif self.space[i][j]['error'] == 'del': # deletion
200
+ if len(lab[i]) > 0:
201
+ self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
202
+ self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
203
+ result['all'] = result['all'] + 1
204
+ result['del'] = result['del'] + 1
205
+ result['lab'].insert(0, lab[i])
206
+ result['rec'].insert(0, "")
207
+ i = i - 1
208
+ elif self.space[i][j]['error'] == 'ins': # insertion
209
+ if len(rec[j]) > 0:
210
+ self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
211
+ result['ins'] = result['ins'] + 1
212
+ result['lab'].insert(0, "")
213
+ result['rec'].insert(0, rec[j])
214
+ j = j - 1
215
+ elif self.space[i][j]['error'] == 'non': # starting point
216
+ break
217
+ else: # shouldn't reach here
218
+ print(
219
+ 'this should not happen , i = {i} , j = {j} , error = {error}'
220
+ .format(i=i, j=j, error=self.space[i][j]['error']))
221
+ return result
222
+
223
+ def overall(self):
224
+ result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
225
+ for token in self.data:
226
+ result['all'] = result['all'] + self.data[token]['all']
227
+ result['cor'] = result['cor'] + self.data[token]['cor']
228
+ result['sub'] = result['sub'] + self.data[token]['sub']
229
+ result['ins'] = result['ins'] + self.data[token]['ins']
230
+ result['del'] = result['del'] + self.data[token]['del']
231
+ return result
232
+
233
+ def cluster(self, data):
234
+ result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
235
+ for token in data:
236
+ if token in self.data:
237
+ result['all'] = result['all'] + self.data[token]['all']
238
+ result['cor'] = result['cor'] + self.data[token]['cor']
239
+ result['sub'] = result['sub'] + self.data[token]['sub']
240
+ result['ins'] = result['ins'] + self.data[token]['ins']
241
+ result['del'] = result['del'] + self.data[token]['del']
242
+ return result
243
+
244
+ def keys(self):
245
+ return list(self.data.keys())
246
+
247
+
248
+ def width(string):
249
+ return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
250
+
251
+
252
+ def default_cluster(word):
253
+
254
+ # unicode_names = [unicodedata.name(char) for char in word]
255
+ unicode_names = []
256
+ for char in word:
257
+ try:
258
+ unicode_names.append(unicodedata.name(char))
259
+ except ValueError:
260
+ unicode_names.append("UNK")
261
+ for i in reversed(range(len(unicode_names))):
262
+ if unicode_names[i].startswith('DIGIT'): # 1
263
+ unicode_names[i] = 'Number' # 'DIGIT'
264
+ elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH')
265
+ or unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
266
+ # 明 / 郎
267
+ unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
268
+ elif (unicode_names[i].startswith('LATIN CAPITAL LETTER')
269
+ or unicode_names[i].startswith('LATIN SMALL LETTER')):
270
+ # A / a
271
+ unicode_names[i] = 'English' # 'LATIN LETTER'
272
+ elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め
273
+ unicode_names[i] = 'Japanese' # 'GANA LETTER'
274
+ elif (unicode_names[i].startswith('AMPERSAND')
275
+ or unicode_names[i].startswith('APOSTROPHE')
276
+ or unicode_names[i].startswith('COMMERCIAL AT')
277
+ or unicode_names[i].startswith('DEGREE CELSIUS')
278
+ or unicode_names[i].startswith('EQUALS SIGN')
279
+ or unicode_names[i].startswith('FULL STOP')
280
+ or unicode_names[i].startswith('HYPHEN-MINUS')
281
+ or unicode_names[i].startswith('LOW LINE')
282
+ or unicode_names[i].startswith('NUMBER SIGN')
283
+ or unicode_names[i].startswith('PLUS SIGN')
284
+ or unicode_names[i].startswith('SEMICOLON')):
285
+ # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
286
+ del unicode_names[i]
287
+ else:
288
+ return 'Other'
289
+ if len(unicode_names) == 0:
290
+ return 'Other'
291
+ if len(unicode_names) == 1:
292
+ return unicode_names[0]
293
+ for i in range(len(unicode_names) - 1):
294
+ if unicode_names[i] != unicode_names[i + 1]:
295
+ return 'Other'
296
+ return unicode_names[0]
297
+
298
+
299
+ def usage():
300
+ print(
301
+ "compute-wer.py : compute word error rate (WER) and align recognition results and references."
302
+ )
303
+ print(
304
+ " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
305
+ )
306
+
307
+
308
+ if __name__ == '__main__':
309
+ if len(sys.argv) == 1:
310
+ usage()
311
+ sys.exit(0)
312
+ calculator = Calculator()
313
+ cluster_file = ''
314
+ ignore_words = set()
315
+ tochar = False
316
+ verbose = 1
317
+ padding_symbol = ' '
318
+ case_sensitive = False
319
+ max_words_per_line = sys.maxsize
320
+ split = None
321
+ while len(sys.argv) > 3:
322
+ a = '--maxw='
323
+ if sys.argv[1].startswith(a):
324
+ b = sys.argv[1][len(a):]
325
+ del sys.argv[1]
326
+ max_words_per_line = int(b)
327
+ continue
328
+ a = '--rt='
329
+ if sys.argv[1].startswith(a):
330
+ b = sys.argv[1][len(a):].lower()
331
+ del sys.argv[1]
332
+ remove_tag = (b == 'true') or (b != '0')
333
+ continue
334
+ a = '--cs='
335
+ if sys.argv[1].startswith(a):
336
+ b = sys.argv[1][len(a):].lower()
337
+ del sys.argv[1]
338
+ case_sensitive = (b == 'true') or (b != '0')
339
+ continue
340
+ a = '--cluster='
341
+ if sys.argv[1].startswith(a):
342
+ cluster_file = sys.argv[1][len(a):]
343
+ del sys.argv[1]
344
+ continue
345
+ a = '--splitfile='
346
+ if sys.argv[1].startswith(a):
347
+ split_file = sys.argv[1][len(a):]
348
+ del sys.argv[1]
349
+ split = dict()
350
+ with codecs.open(split_file, 'r', 'utf-8') as fh:
351
+ for line in fh: # line in unicode
352
+ words = line.strip().split()
353
+ if len(words) >= 2:
354
+ split[words[0]] = words[1:]
355
+ continue
356
+ a = '--ig='
357
+ if sys.argv[1].startswith(a):
358
+ ignore_file = sys.argv[1][len(a):]
359
+ del sys.argv[1]
360
+ with codecs.open(ignore_file, 'r', 'utf-8') as fh:
361
+ for line in fh: # line in unicode
362
+ line = line.strip()
363
+ if len(line) > 0:
364
+ ignore_words.add(line)
365
+ continue
366
+ a = '--char='
367
+ if sys.argv[1].startswith(a):
368
+ b = sys.argv[1][len(a):].lower()
369
+ del sys.argv[1]
370
+ tochar = (b == 'true') or (b != '0')
371
+ continue
372
+ a = '--v='
373
+ if sys.argv[1].startswith(a):
374
+ b = sys.argv[1][len(a):].lower()
375
+ del sys.argv[1]
376
+ verbose = 0
377
+ try:
378
+ verbose = int(b)
379
+ except:
380
+ if b == 'true' or b != '0':
381
+ verbose = 1
382
+ continue
383
+ a = '--padding-symbol='
384
+ if sys.argv[1].startswith(a):
385
+ b = sys.argv[1][len(a):].lower()
386
+ del sys.argv[1]
387
+ if b == 'space':
388
+ padding_symbol = ' '
389
+ elif b == 'underline':
390
+ padding_symbol = '_'
391
+ continue
392
+ if True or sys.argv[1].startswith('-'):
393
+ #ignore invalid switch
394
+ del sys.argv[1]
395
+ continue
396
+
397
+ if not case_sensitive:
398
+ ig = set([w.upper() for w in ignore_words])
399
+ ignore_words = ig
400
+
401
+ default_clusters = {}
402
+ default_words = {}
403
+
404
+ ref_file = sys.argv[1]
405
+ hyp_file = sys.argv[2]
406
+ rec_set = {}
407
+ if split and not case_sensitive:
408
+ newsplit = dict()
409
+ for w in split:
410
+ words = split[w]
411
+ for i in range(len(words)):
412
+ words[i] = words[i].upper()
413
+ newsplit[w.upper()] = words
414
+ split = newsplit
415
+
416
+ with codecs.open(hyp_file, 'r', 'utf-8') as fh:
417
+ for line in fh:
418
+ if tochar:
419
+ array = characterize(line)
420
+ else:
421
+ array = line.strip().split()
422
+ if len(array) == 0: continue
423
+ fid = array[0]
424
+ rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
425
+ split)
426
+
427
+ # compute error rate on the interaction of reference file and hyp file
428
+ for line in open(ref_file, 'r', encoding='utf-8'):
429
+ if tochar:
430
+ array = characterize(line)
431
+ else:
432
+ array = line.rstrip('\n').split()
433
+ if len(array) == 0: continue
434
+ fid = array[0]
435
+ if fid not in rec_set:
436
+ continue
437
+ lab = normalize(array[1:], ignore_words, case_sensitive, split)
438
+ rec = rec_set[fid]
439
+ if verbose:
440
+ print('\nutt: %s' % fid)
441
+
442
+ for word in rec + lab:
443
+ if word not in default_words:
444
+ default_cluster_name = default_cluster(word)
445
+ if default_cluster_name not in default_clusters:
446
+ default_clusters[default_cluster_name] = {}
447
+ if word not in default_clusters[default_cluster_name]:
448
+ default_clusters[default_cluster_name][word] = 1
449
+ default_words[word] = default_cluster_name
450
+
451
+ result = calculator.calculate(lab, rec)
452
+ if verbose:
453
+ if result['all'] != 0:
454
+ wer = float(result['ins'] + result['sub'] +
455
+ result['del']) * 100.0 / result['all']
456
+ else:
457
+ wer = 0.0
458
+ print('WER: %4.2f %%' % wer, end=' ')
459
+ print('N=%d C=%d S=%d D=%d I=%d' %
460
+ (result['all'], result['cor'], result['sub'], result['del'],
461
+ result['ins']))
462
+ space = {}
463
+ space['lab'] = []
464
+ space['rec'] = []
465
+ for idx in range(len(result['lab'])):
466
+ len_lab = width(result['lab'][idx])
467
+ len_rec = width(result['rec'][idx])
468
+ length = max(len_lab, len_rec)
469
+ space['lab'].append(length - len_lab)
470
+ space['rec'].append(length - len_rec)
471
+ upper_lab = len(result['lab'])
472
+ upper_rec = len(result['rec'])
473
+ lab1, rec1 = 0, 0
474
+ while lab1 < upper_lab or rec1 < upper_rec:
475
+ if verbose > 1:
476
+ print('lab(%s):' % fid.encode('utf-8'), end=' ')
477
+ else:
478
+ print('lab:', end=' ')
479
+ lab2 = min(upper_lab, lab1 + max_words_per_line)
480
+ for idx in range(lab1, lab2):
481
+ token = result['lab'][idx]
482
+ print('{token}'.format(token=token), end='')
483
+ for n in range(space['lab'][idx]):
484
+ print(padding_symbol, end='')
485
+ print(' ', end='')
486
+ print()
487
+ if verbose > 1:
488
+ print('rec(%s):' % fid.encode('utf-8'), end=' ')
489
+ else:
490
+ print('rec:', end=' ')
491
+ rec2 = min(upper_rec, rec1 + max_words_per_line)
492
+ for idx in range(rec1, rec2):
493
+ token = result['rec'][idx]
494
+ print('{token}'.format(token=token), end='')
495
+ for n in range(space['rec'][idx]):
496
+ print(padding_symbol, end='')
497
+ print(' ', end='')
498
+ print('\n', end='\n')
499
+ lab1 = lab2
500
+ rec1 = rec2
501
+
502
+ if verbose:
503
+ print(
504
+ '==========================================================================='
505
+ )
506
+ print()
507
+
508
+ result = calculator.overall()
509
+ if result['all'] != 0:
510
+ wer = float(result['ins'] + result['sub'] +
511
+ result['del']) * 100.0 / result['all']
512
+ else:
513
+ wer = 0.0
514
+ print('Overall -> %4.2f %%' % wer, end=' ')
515
+ print('N=%d C=%d S=%d D=%d I=%d' %
516
+ (result['all'], result['cor'], result['sub'], result['del'],
517
+ result['ins']))
518
+ if not verbose:
519
+ print()
520
+
521
+ if verbose:
522
+ for cluster_id in default_clusters:
523
+ result = calculator.cluster(
524
+ [k for k in default_clusters[cluster_id]])
525
+ if result['all'] != 0:
526
+ wer = float(result['ins'] + result['sub'] +
527
+ result['del']) * 100.0 / result['all']
528
+ else:
529
+ wer = 0.0
530
+ print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
531
+ print('N=%d C=%d S=%d D=%d I=%d' %
532
+ (result['all'], result['cor'], result['sub'], result['del'],
533
+ result['ins']))
534
+ if len(cluster_file) > 0: # compute separated WERs for word clusters
535
+ cluster_id = ''
536
+ cluster = []
537
+ for line in open(cluster_file, 'r', encoding='utf-8'):
538
+ for token in line.rstrip('\n').split():  # file is opened in text mode, so no decode is needed
539
+ # end of cluster reached, like </Keyword>
540
+ if token[0:2] == '</' and token[len(token)-1] == '>' and \
541
+ token.lstrip('</').rstrip('>') == cluster_id :
542
+ result = calculator.cluster(cluster)
543
+ if result['all'] != 0:
544
+ wer = float(result['ins'] + result['sub'] +
545
+ result['del']) * 100.0 / result['all']
546
+ else:
547
+ wer = 0.0
548
+ print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
549
+ print('N=%d C=%d S=%d D=%d I=%d' %
550
+ (result['all'], result['cor'], result['sub'],
551
+ result['del'], result['ins']))
552
+ cluster_id = ''
553
+ cluster = []
554
+ # begin of cluster reached, like <Keyword>
555
+ elif token[0] == '<' and token[len(token)-1] == '>' and \
556
+ cluster_id == '' :
557
+ cluster_id = token.lstrip('<').rstrip('>')
558
+ cluster = []
559
+ # general terms, like WEATHER / CAR / ...
560
+ else:
561
+ cluster.append(token)
562
+ print()
563
+ print(
564
+ '==========================================================================='
565
+ )
inference/inference-finetune-lid.py ADDED
@@ -0,0 +1,168 @@
1
+ #!/usr/bin/env python
2
+
3
+ # pip install transformers datasets torch soundfile jiwer
4
+
5
+ from datasets import load_dataset, Audio
6
+ from transformers import pipeline, WhisperProcessor
7
+ from torch.utils.data import DataLoader
8
+ import torch
9
+ from jiwer import wer as jiwer_wer
10
+ from jiwer import cer as jiwer_cer
11
+ import ipdb
12
+ import subprocess
13
+ import os
14
+
15
+ # 1. Load the FLEURS Khmer (km_kh) test set, cast to 16 kHz audio
16
+ ds = load_dataset("google/fleurs", "km_kh", split="test", trust_remote_code=True)
17
+ ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
18
+
19
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
20
+
21
+
22
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
23
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
24
+
25
+ model_id = "pengyizhou/whisper-fleurs-km_kh-small"
26
+
27
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
28
+ model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
29
+ )
30
+ model.to(device)
31
+ whisper_model = "openai/whisper-large-v3"
32
+ processor = WhisperProcessor.from_pretrained(whisper_model)
33
+
34
+ asr = pipeline(
35
+ "automatic-speech-recognition",
36
+ model=model,
37
+ tokenizer=processor.tokenizer,
38
+ feature_extractor=processor.feature_extractor,
39
+ torch_dtype=torch_dtype,
40
+ chunk_length_s=30,
41
+ batch_size=64,
42
+ max_new_tokens=225,
43
+ device=device,
44
+ num_beams=1, # Greedy decoding; raise num_beams to enable beam search
45
+ )
46
+
47
+ generate_kwargs = {
48
+ "condition_on_prev_tokens": False,
49
+ "compression_ratio_threshold": 1.35, # zlib compression ratio threshold (in token space)
50
+ "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
51
+ "logprob_threshold": -1.0,
52
+ "language": "khmer", # Specify the language for transcription
53
+ }
54
+
55
+
56
+ # 3. Batch‐wise transcription function
57
+ def transcribe_batch(batch):
58
+ # `batch["audio"]` is a list of {"array": np.ndarray, ...}
59
+ inputs = [ ex["array"] for ex in batch["audio"] ]
60
+ outputs = asr(inputs, generate_kwargs=generate_kwargs) # returns a list of dicts with "text"
61
+ # lower-case and strip to normalize for CER
62
+ preds = [ out["text"].lower().strip() for out in outputs ]
63
+ return {"prediction": preds}
64
+
65
+ # 4. Map over the dataset in batches of 64 examples at a time
66
+ result = ds.map(
67
+ transcribe_batch,
68
+ batched=True,
69
+ batch_size=64, # feed 64 audios per map batch; the ASR pipeline sub-batches internally
70
+ remove_columns=ds.column_names
71
+ )
72
+
73
+ # ipdb.set_trace()
74
+ # 5. Compute corpus-level CER with jiwer
75
+ # refs = "\n".join(t.lower().strip() for t in ds["transcription"])
76
+ # preds = "\n".join(t for t in result["prediction"])
77
+ # score = jiwer_cer(refs, preds)
78
+ ids = [key for key in ds["id"]]
79
+ refs = [t.lower().strip() for t in ds["transcription"]]
80
+ preds = [t for t in result["prediction"]]
81
+ score_cer = jiwer_cer(refs, preds)
82
+ score_wer = jiwer_wer(refs, preds)
83
+
84
+ print(f"CER on FLEURS km_kh: {score_cer*100:.2f}%")
85
+ print(f"WER on FLEURS km_kh: {score_wer*100:.2f}%")
86
+
87
+ # Function to add spaces between characters for CER calculation
88
+ def add_char_spaces(text):
89
+ """Add spaces between each character for character-level evaluation"""
90
+ return ' '.join(list(text.strip()))
91
+
92
+ with open("./km_kh_finetune.pred", "w") as pred_results:
93
+ for key, pred in zip(ids, preds):
94
+ pred_with_spaces = add_char_spaces(pred)
95
+ pred_results.write("{} {}\n".format(key, pred_with_spaces))
96
+
97
+ with open("./km_kh.ref", "w") as ref_results:
98
+ for key, ref in zip(ids, refs):
99
+ ref_with_spaces = add_char_spaces(ref)
100
+ ref_results.write("{} {}\n".format(key, ref_with_spaces))
101
+
102
+ # Generate WER file using compute-wer.py
103
+ print("Generating detailed WER analysis...")
104
+
105
+ # Check if compute-wer.py exists
106
+ compute_wer_script = "./compute-wer.py"
107
+ if not os.path.exists(compute_wer_script):
108
+ # Try to find it in parent directories or common locations
109
+ possible_locations = [
110
+ "./compute-wer.py",
111
+ ]
112
+ for location in possible_locations:
113
+ if os.path.exists(location):
114
+ compute_wer_script = location
115
+ break
116
+ else:
117
+ print(f"Warning: compute-wer.py not found. Tried: {[compute_wer_script] + possible_locations}")
118
+ print("Skipping detailed WER analysis.")
119
+ compute_wer_script = None
120
+
121
+ if compute_wer_script:
122
+ try:
123
+ # Run compute-wer.py with character-level analysis
124
+ ref_file = "./km_kh.ref"
125
+ hyp_file = "./km_kh_finetune.pred"
126
+ wer_file = "./km_kh_finetune.wer"
127
+
128
+ cmd = [
129
+ "python", compute_wer_script,
130
+ "--char=1", # Character-level analysis
131
+ "--v=1", # Verbose output
132
+ ref_file,
133
+ hyp_file
134
+ ]
135
+
136
+ print(f"Running: {' '.join(cmd)} > {wer_file}")
137
+
138
+ # Run the command and redirect output to wer file
139
+ with open(wer_file, "w") as wer_output:
140
+ result = subprocess.run(
141
+ cmd,
142
+ stdout=wer_output,
143
+ stderr=subprocess.PIPE,
144
+ text=True,
145
+ check=True
146
+ )
147
+
148
+ print(f"CER analysis saved to {wer_file}")
149
+
150
+ # Optionally, print the first few lines of the WER file
151
+ if os.path.exists(wer_file):
152
+ print("\nFirst few lines of WER analysis:")
153
+ with open(wer_file, "r") as f:
154
+ lines = f.readlines()
155
+ for i, line in enumerate(lines[:10]): # Show first 10 lines
156
+ print(f" {line.rstrip()}")
157
+ if len(lines) > 10:
158
+ print(f" ... ({len(lines) - 10} more lines)")
159
+
160
+ except subprocess.CalledProcessError as e:
161
+ print(f"Error running compute-wer.py: {e}")
162
+ if e.stderr:
163
+ print(f"Error details: {e.stderr}")
164
+ except Exception as e:
165
+ print(f"Unexpected error: {e}")
166
+
167
+ print("Inference and CER analysis completed!")
168
+
inference/inference-finetune-nolid.py ADDED
@@ -0,0 +1,167 @@
1
+ #!/usr/bin/env python
2
+
3
+ # pip install transformers datasets torch soundfile jiwer
4
+
5
+ from datasets import load_dataset, Audio
6
+ from transformers import pipeline, WhisperProcessor
7
+ from torch.utils.data import DataLoader
8
+ import torch
9
+ from jiwer import wer as jiwer_wer
10
+ from jiwer import cer as jiwer_cer
11
+ import ipdb
12
+ import subprocess
13
+ import os
14
+
15
+ # 1. Load the FLEURS Khmer (km_kh) test set, cast to 16 kHz audio
16
+ ds = load_dataset("google/fleurs", "km_kh", split="test", trust_remote_code=True)
17
+ ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
18
+
19
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
20
+
21
+
22
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
23
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
24
+
25
+ model_id = "pengyizhou/whisper-fleurs-km_kh-small"
26
+
27
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
28
+ model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
29
+ )
30
+ model.to(device)
31
+ whisper_model = "openai/whisper-large-v3"
32
+ processor = WhisperProcessor.from_pretrained(whisper_model)
33
+
34
+ asr = pipeline(
35
+ "automatic-speech-recognition",
36
+ model=model,
37
+ tokenizer=processor.tokenizer,
38
+ feature_extractor=processor.feature_extractor,
39
+ torch_dtype=torch_dtype,
40
+ chunk_length_s=30,
41
+ batch_size=64,
42
+ max_new_tokens=225,
43
+ device=device,
44
+ num_beams=1, # Greedy decoding; raise num_beams to enable beam search
45
+ )
46
+
47
+ generate_kwargs = {
48
+ "condition_on_prev_tokens": False,
49
+ "compression_ratio_threshold": 1.35, # zlib compression ratio threshold (in token space)
50
+ "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
51
+ "logprob_threshold": -1.0,
52
+ }
53
+
54
+
55
+ # 3. Batch‐wise transcription function
56
+ def transcribe_batch(batch):
57
+ # `batch["audio"]` is a list of {"array": np.ndarray, ...}
58
+ inputs = [ ex["array"] for ex in batch["audio"] ]
59
+ outputs = asr(inputs, generate_kwargs=generate_kwargs) # returns a list of dicts with "text"
60
+ # lower-case and strip to normalize for CER
61
+ preds = [ out["text"].lower().strip() for out in outputs ]
62
+ return {"prediction": preds}
63
+
64
+ # 4. Map over the dataset in batches of 64 examples at a time
65
+ result = ds.map(
66
+ transcribe_batch,
67
+ batched=True,
68
+ batch_size=64, # feed 64 audios per map batch; the ASR pipeline sub-batches internally
69
+ remove_columns=ds.column_names
70
+ )
71
+
72
+ # ipdb.set_trace()
73
+ # 5. Compute corpus-level CER with jiwer
74
+ # refs = "\n".join(t.lower().strip() for t in ds["transcription"])
75
+ # preds = "\n".join(t for t in result["prediction"])
76
+ # score = jiwer_cer(refs, preds)
77
+ ids = [key for key in ds["id"]]
78
+ refs = [t.lower().strip() for t in ds["transcription"]]
79
+ preds = [t for t in result["prediction"]]
80
+ score_cer = jiwer_cer(refs, preds)
81
+ score_wer = jiwer_wer(refs, preds)
82
+
83
+ print(f"CER on FLEURS km_kh: {score_cer*100:.2f}%")
84
+ print(f"WER on FLEURS km_kh: {score_wer*100:.2f}%")
85
+
86
+ # Function to add spaces between characters for CER calculation
87
+ def add_char_spaces(text):
88
+ """Add spaces between each character for character-level evaluation"""
89
+ return ' '.join(list(text.strip()))
90
+
91
+ with open("./km_kh_finetune_nolid.pred", "w") as pred_results:
92
+ for key, pred in zip(ids, preds):
93
+ pred_with_spaces = add_char_spaces(pred)
94
+ pred_results.write("{} {}\n".format(key, pred_with_spaces))
95
+
96
+ with open("./km_kh.ref", "w") as ref_results:
97
+ for key, ref in zip(ids, refs):
98
+ ref_with_spaces = add_char_spaces(ref)
99
+ ref_results.write("{} {}\n".format(key, ref_with_spaces))
100
+
101
+ # Generate WER file using compute-wer.py
102
+ print("Generating detailed WER analysis...")
103
+
104
+ # Check if compute-wer.py exists
105
+ compute_wer_script = "./compute-wer.py"
106
+ if not os.path.exists(compute_wer_script):
107
+ # Try to find it in parent directories or common locations
108
+ possible_locations = [
109
+ "./compute-wer.py",
110
+ ]
111
+ for location in possible_locations:
112
+ if os.path.exists(location):
113
+ compute_wer_script = location
114
+ break
115
+ else:
116
+ print(f"Warning: compute-wer.py not found. Tried: {[compute_wer_script] + possible_locations}")
117
+ print("Skipping detailed WER analysis.")
118
+ compute_wer_script = None
119
+
120
+ if compute_wer_script:
121
+ try:
122
+ # Run compute-wer.py with character-level analysis
123
+ ref_file = "./km_kh.ref"
124
+ hyp_file = "./km_kh_finetune_nolid.pred"
125
+ wer_file = "./km_kh_finetune_nolid.wer"
126
+
127
+ cmd = [
128
+ "python", compute_wer_script,
129
+ "--char=1", # Character-level analysis
130
+ "--v=1", # Verbose output
131
+ ref_file,
132
+ hyp_file
133
+ ]
134
+
135
+ print(f"Running: {' '.join(cmd)} > {wer_file}")
136
+
137
+ # Run the command and redirect output to wer file
138
+ with open(wer_file, "w") as wer_output:
139
+ result = subprocess.run(
140
+ cmd,
141
+ stdout=wer_output,
142
+ stderr=subprocess.PIPE,
143
+ text=True,
144
+ check=True
145
+ )
146
+
147
+ print(f"CER analysis saved to {wer_file}")
148
+
149
+ # Optionally, print the first few lines of the WER file
150
+ if os.path.exists(wer_file):
151
+ print("\nFirst few lines of WER analysis:")
152
+ with open(wer_file, "r") as f:
153
+ lines = f.readlines()
154
+ for i, line in enumerate(lines[:10]): # Show first 10 lines
155
+ print(f" {line.rstrip()}")
156
+ if len(lines) > 10:
157
+ print(f" ... ({len(lines) - 10} more lines)")
158
+
159
+ except subprocess.CalledProcessError as e:
160
+ print(f"Error running compute-wer.py: {e}")
161
+ if e.stderr:
162
+ print(f"Error details: {e.stderr}")
163
+ except Exception as e:
164
+ print(f"Unexpected error: {e}")
165
+
166
+ print("Inference and CER analysis completed!")
167
+
inference/inference-ft.sh ADDED
@@ -0,0 +1,4 @@
1
+ #!/bin/bash
2
+
3
+ python ./inference-finetune-lid.py
4
+ python ./inference-finetune-nolid.py
inference/inference-zeroshot-nolid.py ADDED
@@ -0,0 +1,164 @@
1
+ #!/usr/bin/env python
2
+
3
+ # pip install transformers datasets torch soundfile jiwer
4
+
5
+ from datasets import load_dataset, Audio
6
+ from transformers import pipeline, WhisperProcessor
7
+ from torch.utils.data import DataLoader
8
+ import torch
9
+ from jiwer import wer as jiwer_wer
10
+ from jiwer import cer as jiwer_cer
11
+ import ipdb
12
+ import subprocess
13
+ import os
14
+
15
+ # 1. Load the FLEURS Khmer (km_kh) test set, cast to 16 kHz audio
16
+ ds = load_dataset("google/fleurs", "km_kh", split="test")
17
+ ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
18
+
19
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
20
+
21
+
22
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
23
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
24
+
25
+ model_id = "openai/whisper-large-v3"
26
+
27
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
28
+ model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
29
+ )
30
+ model.to(device)
31
+ whisper_model = "openai/whisper-large-v3"
32
+ processor = WhisperProcessor.from_pretrained(whisper_model, language="khmer")
33
+
34
+ asr = pipeline(
35
+ "automatic-speech-recognition",
36
+ model=model,
37
+ tokenizer=processor.tokenizer,
38
+ feature_extractor=processor.feature_extractor,
39
+ torch_dtype=torch_dtype,
40
+ chunk_length_s=30,
41
+ batch_size=64,
42
+ max_new_tokens=225,
43
+ device=device,
44
+ num_beams=1, # Greedy decoding; raise num_beams to enable beam search
45
+ )
46
+ generate_kwargs = {
47
+ "condition_on_prev_tokens": False,
48
+ "compression_ratio_threshold": 1.35, # zlib compression ratio threshold (in token space)
49
+ "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
50
+ "logprob_threshold": -1.0,
51
+ }
52
+
53
+ # 3. Batch‐wise transcription function
54
+ def transcribe_batch(batch):
55
+ # `batch["audio"]` is a list of {"array": np.ndarray, ...}
56
+ inputs = [ ex["array"] for ex in batch["audio"] ]
57
+ outputs = asr(inputs, generate_kwargs=generate_kwargs) # returns a list of dicts with "text"
58
+ # lower-case and strip to normalize for CER
59
+ preds = [ out["text"].lower().strip() for out in outputs ]
60
+ return {"prediction": preds}
61
+
62
+ # 4. Map over the dataset in batches of 64 examples at a time
63
+ result = ds.map(
64
+ transcribe_batch,
65
+ batched=True,
66
+ batch_size=64, # feed 64 audios per map batch; the ASR pipeline sub-batches internally
67
+ remove_columns=ds.column_names
68
+ )
69
+
70
+ # ipdb.set_trace()
71
+ # 5. Compute corpus-level CER with jiwer
72
+ # refs = "\n".join(t.lower().strip() for t in ds["transcription"])
73
+ # preds = "\n".join(t for t in result["prediction"])
74
+ # score = jiwer_cer(refs, preds)
75
+ ids = [key for key in ds["id"]]
76
+ refs = [t.lower().strip() for t in ds["transcription"]]
77
+ preds = [t for t in result["prediction"]]
78
+ score_cer = jiwer_cer(refs, preds)
79
+ score_wer = jiwer_wer(refs, preds)
80
+
81
+ print(f"CER on FLEURS km_kh: {score_cer*100:.2f}%")
82
+ print(f"WER on FLEURS km_kh: {score_wer*100:.2f}%")
83
+ # Function to add spaces between characters for CER calculation
84
+ def add_char_spaces(text):
85
+ """Add spaces between each character for character-level evaluation"""
86
+ return ' '.join(list(text.strip()))
87
+
88
+ with open("./km_kh_zs_nolid.pred", "w") as pred_results:
89
+ for key, pred in zip(ids, preds):
90
+ pred_with_spaces = add_char_spaces(pred)
91
+ pred_results.write("{} {}\n".format(key, pred_with_spaces))
92
+
93
+ with open("./km_kh.ref", "w") as ref_results:
94
+ for key, ref in zip(ids, refs):
95
+ ref_with_spaces = add_char_spaces(ref)
96
+ ref_results.write("{} {}\n".format(key, ref_with_spaces))
97
+
98
+ # Generate WER file using compute-wer.py
99
+ print("Generating detailed WER analysis...")
100
+
101
+ # Check if compute-wer.py exists
102
+ compute_wer_script = "./compute-wer.py"
103
+ if not os.path.exists(compute_wer_script):
104
+ # Try to find it in parent directories or common locations
105
+ possible_locations = [
106
+ "./compute-wer.py",
107
+ ]
108
+ for location in possible_locations:
109
+ if os.path.exists(location):
110
+ compute_wer_script = location
111
+ break
112
+ else:
113
+ print(f"Warning: compute-wer.py not found. Tried: {[compute_wer_script] + possible_locations}")
114
+ print("Skipping detailed WER analysis.")
115
+ compute_wer_script = None
116
+
117
+ if compute_wer_script:
+     try:
+         # Run compute-wer.py with character-level analysis
+         ref_file = "./km_kh.ref"
+         hyp_file = "./km_kh_zs_nolid.pred"
+         wer_file = "./km_kh_zs_nolid.wer"
+
+         cmd = [
+             "python", compute_wer_script,
+             "--char=1",  # Character-level analysis
+             "--v=1",     # Verbose output
+             ref_file,
+             hyp_file
+         ]
+
+         print(f"Running: {' '.join(cmd)} > {wer_file}")
+
+         # Run the command and redirect output to the wer file
+         with open(wer_file, "w") as wer_output:
+             result = subprocess.run(
+                 cmd,
+                 stdout=wer_output,
+                 stderr=subprocess.PIPE,
+                 text=True,
+                 check=True
+             )
+
+         print(f"CER analysis saved to {wer_file}")
+
+         # Optionally, print the first few lines of the WER file
+         if os.path.exists(wer_file):
+             print("\nFirst few lines of WER analysis:")
+             with open(wer_file, "r") as f:
+                 lines = f.readlines()
+                 for i, line in enumerate(lines[:10]):  # Show first 10 lines
+                     print(f"  {line.rstrip()}")
+                 if len(lines) > 10:
+                     print(f"  ... ({len(lines) - 10} more lines)")
+
+     except subprocess.CalledProcessError as e:
+         print(f"Error running compute-wer.py: {e}")
+         if e.stderr:
+             print(f"Error details: {e.stderr}")
+     except Exception as e:
+         print(f"Unexpected error: {e}")
+
+ print("Inference and CER analysis completed!")
+
inference/inference-zeroshot.py ADDED
@@ -0,0 +1,165 @@
+ #!/usr/bin/env python
+
+ # pip install transformers datasets torch soundfile jiwer
+
+ from datasets import load_dataset, Audio
+ from transformers import pipeline, WhisperProcessor
+ from torch.utils.data import DataLoader
+ import torch
+ from jiwer import wer as jiwer_wer
+ from jiwer import cer as jiwer_cer
+ import ipdb
+ import subprocess
+ import os
+
+ # 1. Load the FLEURS Khmer (km_kh) test set, cast to 16 kHz audio
+ ds = load_dataset("google/fleurs", "km_kh", split="test")
+ ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
+
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+ model_id = "openai/whisper-large-v3"
+
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
+     model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+ )
+ model.to(device)
+ whisper_model = "openai/whisper-large-v3"
+ processor = WhisperProcessor.from_pretrained(whisper_model, language="khmer")
+
+ asr = pipeline(
+     "automatic-speech-recognition",
+     model=model,
+     tokenizer=processor.tokenizer,
+     feature_extractor=processor.feature_extractor,
+     torch_dtype=torch_dtype,
+     chunk_length_s=30,
+     batch_size=64,
+     max_new_tokens=225,
+     device=device,
+     num_beams=1,  # greedy decoding; increase num_beams to trade speed for quality
+ )
+
+ generate_kwargs = {
+     "condition_on_prev_tokens": False,
+     "compression_ratio_threshold": 1.35,  # zlib compression ratio threshold (in token space)
+     "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),  # temperature fallback schedule
+     "logprob_threshold": -1.0,
+     "language": "khmer",  # Specify the language for transcription
+ }
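+ # Note: "language": "khmer" fixes the decoder's language token, so Whisper's
+ # automatic language detection is skipped; this is the fixed-LID variant of the
+ # zero-shot run.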
+ # 3. Batch-wise transcription function
+ def transcribe_batch(batch):
+     # `batch["audio"]` is a list of {"array": np.ndarray, ...}
+     inputs = [ex["array"] for ex in batch["audio"]]
+     outputs = asr(inputs, generate_kwargs=generate_kwargs)  # returns a list of dicts with "text"
+     # lower-case and strip to normalize for CER
+     preds = [out["text"].lower().strip() for out in outputs]
+     return {"prediction": preds}
+
+ # 4. Map over the dataset in batches of 64 examples at a time
+ result = ds.map(
+     transcribe_batch,
+     batched=True,
+     batch_size=64,  # feed 64 audios per call; the pipeline batches them internally
+     remove_columns=ds.column_names
+ )
+
71
+ # ipdb.set_trace()
72
+ # 5. Compute corpus-level CER with jiwer
73
+ # refs = "\n".join(t.lower().strip() for t in ds["transcription"])
74
+ # preds = "\n".join(t for t in result["prediction"])
75
+ # score = jiwer_cer(refs, preds)
76
+ ids = [key for key in ds["id"]]
77
+ refs = [t.lower().strip() for t in ds["transcription"]]
78
+ preds = [t for t in result["prediction"]]
79
+ score_cer = jiwer_cer(refs, preds)
80
+ score_wer = jiwer_wer(refs, preds)
81
+
82
+ print(f"CER on FLEURS km_kh: {score_cer*100:.2f}%")
83
+ print(f"WER on FLEURS km_kh: {score_wer*100:.2f}%")
84
+
85
+ # Function to add spaces between characters for CER calculation
86
+ def add_char_spaces(text):
87
+ """Add spaces between each character for character-level evaluation"""
88
+ return ' '.join(list(text.strip()))
89
+
90
+ with open("./km_kh_zs_lid.pred", "w") as pred_results:
91
+ for key, pred in zip(ids, preds):
92
+ pred_with_spaces = add_char_spaces(pred)
93
+ pred_results.write("{} {}\n".format(key, pred_with_spaces))
94
+
95
+ with open("./km_kh.ref", "w") as ref_results:
96
+ for key, ref in zip(ids, refs):
97
+ ref_with_spaces = add_char_spaces(ref)
98
+ ref_results.write("{} {}\n".format(key, ref_with_spaces))
99
+
100
+ # Generate WER file using compute-wer.py
101
+ print("Generating detailed WER analysis...")
102
+
103
+ # Check if compute-wer.py exists
104
+ compute_wer_script = "./compute-wer.py"
105
+ if not os.path.exists(compute_wer_script):
106
+ # Try to find it in parent directories or common locations
107
+ possible_locations = [
108
+ "./compute-wer.py",
109
+ ]
110
+ for location in possible_locations:
111
+ if os.path.exists(location):
112
+ compute_wer_script = location
113
+ break
114
+ else:
115
+ print(f"Warning: compute-wer.py not found. Tried: {[compute_wer_script] + possible_locations}")
116
+ print("Skipping detailed WER analysis.")
117
+ compute_wer_script = None
118
+
119
+ if compute_wer_script:
120
+ try:
121
+ # Run compute-wer.py with character-level analysis
122
+ ref_file = "./km_kh.ref"
123
+ hyp_file = "./km_kh_zs_lid.pred"
124
+ wer_file = "./km_kh_zs_lid.wer"
125
+
126
+ cmd = [
127
+ "python", compute_wer_script,
128
+ "--char=1", # Character-level analysis
129
+ "--v=1", # Verbose output
130
+ ref_file,
131
+ hyp_file
132
+ ]
133
+
134
+ print(f"Running: {' '.join(cmd)} > {wer_file}")
135
+
136
+ # Run the command and redirect output to wer file
137
+ with open(wer_file, "w") as wer_output:
138
+ result = subprocess.run(
139
+ cmd,
140
+ stdout=wer_output,
141
+ stderr=subprocess.PIPE,
142
+ text=True,
143
+ check=True
144
+ )
145
+
146
+ print(f"CER analysis saved to {wer_file}")
147
+
148
+ # Optionally, print the first few lines of the WER file
149
+ if os.path.exists(wer_file):
150
+ print("\nFirst few lines of WER analysis:")
151
+ with open(wer_file, "r") as f:
152
+ lines = f.readlines()
153
+ for i, line in enumerate(lines[:10]): # Show first 10 lines
154
+ print(f" {line.rstrip()}")
155
+ if len(lines) > 10:
156
+ print(f" ... ({len(lines) - 10} more lines)")
157
+
158
+ except subprocess.CalledProcessError as e:
159
+ print(f"Error running compute-wer.py: {e}")
160
+ if e.stderr:
161
+ print(f"Error details: {e.stderr}")
162
+ except Exception as e:
163
+ print(f"Unexpected error: {e}")
164
+
165
+ print("Inference and CER analysis completed!")
inference/inference-zs.sh ADDED
@@ -0,0 +1,5 @@
+ #!/bin/bash
+
+ python ./inference-zeroshot.py        # zero-shot, language fixed to Khmer
+ python ./inference-zeroshot-nolid.py  # zero-shot, no language specified (auto-detect)
+ python ./inference.py
inference/km_kh.ref ADDED
The diff for this file is too large to render. See raw diff
 
inference/km_kh_finetune.pred ADDED
The diff for this file is too large to render. See raw diff
 
inference/km_kh_finetune.wer ADDED
The diff for this file is too large to render. See raw diff
 
inference/km_kh_finetune_nolid.pred ADDED
The diff for this file is too large to render. See raw diff
 
inference/km_kh_finetune_nolid.wer ADDED
The diff for this file is too large to render. See raw diff
 
inference/km_kh_zs_lid.pred ADDED
The diff for this file is too large to render. See raw diff
 
inference/km_kh_zs_lid.wer ADDED
The diff for this file is too large to render. See raw diff
 
inference/km_kh_zs_nolid.pred ADDED
The diff for this file is too large to render. See raw diff
 
inference/km_kh_zs_nolid.wer ADDED
The diff for this file is too large to render. See raw diff