from io import BytesIO from urllib.request import urlopen import soundfile import torch from datasets import load_dataset, Audio import numpy as np from transformers import AutoModel, AutoProcessor, BatchFeature from tqdm import tqdm import json import os import time from datetime import datetime from whisper_normalizer.english import EnglishTextNormalizer from whisper_normalizer.basic import BasicTextNormalizer import sacrebleu from jiwer import cer, wer from torch.utils.data import Dataset, DataLoader import soundfile as sf import re normalizer = { "en_us" : EnglishTextNormalizer(), "ko_kr" : BasicTextNormalizer() } # 모델 및 프로세서 로드 model_id = "junnei/gemma-3-4b-it-speech" revision = "v1.0" model = AutoModel.from_pretrained( model_id, device_map="auto", revision = revision, trust_remote_code=True ).eval() processor = AutoProcessor.from_pretrained( model_id, revision = revision, trust_remote_code=True ) # 결과 저장 디렉토리 생성 results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}" os.makedirs(results_dir, exist_ok=True) INSTRUCTION = { "ast": "Translate the audio to {0}.", "asr": "Transcribe the audio clip into text.", } class CoVoSTDataset(Dataset): def __init__(self, processor, data_dir, ast=False, lang=("en_ko", "Korean")): self.data = load_dataset("junnei/covost2", lang[0], data_dir=data_dir, split='test', trust_remote_code=True ) original_size = len(self.data) self.data = self.data.cast_column("audio", Audio(decode=False)) def identify_corrupted_files(example): try: # 디코딩 시도 sf.read(example["audio"]["path"]) if example['translation'] == "" or example['sentence'] == "": return False return True except Exception: return False self.data = self.data.filter(identify_corrupted_files, num_proc=16) validated_size = len(self.data) self.data = self.data.cast_column("audio", Audio(sampling_rate = 16000, decode=True)) self.lang = lang[0] self.ast = ast print(f"- {self.lang}: {('AST' if self.ast else 'ASR')}") print(f"원본 데이터 개수: {original_size}") print(f"에러 데이터 개수: {original_size - validated_size}") print(f"필터링 비율: {validated_size/original_size:.2%}") self.processor = processor self.instruction = INSTRUCTION["ast"].format(lang[1]) if ast else INSTRUCTION["asr"] def __len__(self): return len(self.data) def __getitem__(self, idx): data = self.data[idx] user_message = { 'role': 'user', 'content': '' + self.instruction, } prompt = self.processor.tokenizer.apply_chat_template( [user_message], tokenize=False, add_generation_prompt=True, add_bos=True ) inputs = self.processor(text=prompt, audio=[data["audio"]["array"]], add_special_tokens=False, return_tensors='pt') sentence = data['sentence'].replace('"', '') answer = f"{data['translation'] if self.ast else sentence}" return { 'input_ids': inputs.input_ids, 'attention_mask': inputs.attention_mask, 'token_type_ids': inputs.token_type_ids, 'input_modes': inputs.input_modes, 'input_audio_embeds': inputs.input_audio_embeds, 'audio_embed_sizes': inputs.audio_embed_sizes, 'sentence': sentence, 'answer': answer, } def select(self, indices): self.data = self.data.select(indices) return self def pad_sequence(sequences, padding_side='right', padding_value=0): """ Pad a list of sequences to the same length. sequences: list of tensors in [seq_len, *] shape """ assert padding_side in ['right', 'left'] max_size = sequences[0].size() trailing_dims = max_size[1:] max_len = max(len(seq) for seq in sequences) batch_size = len(sequences) output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value) for i, seq in enumerate(sequences): length = seq.size(0) if padding_side == 'right': output.data[i, :length] = seq else: output.data[i, -length:] = seq return output def cat_with_pad(tensors, dim, padding_value=0): """ cat along dim, while pad to max for all other dims """ ndim = tensors[0].dim() assert all( t.dim() == ndim for t in tensors[1:] ), 'All tensors must have the same number of dimensions' out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] out_size[dim] = sum(t.shape[dim] for t in tensors) output = tensors[0].new_full(out_size, padding_value) index = 0 for t in tensors: # Create a slice list where every dimension except dim is full slice slices = [slice(0, t.shape[d]) for d in range(ndim)] # Update only the concat dimension slice slices[dim] = slice(index, index + t.shape[dim]) output[slices] = t index += t.shape[dim] return output def covost_collate_fn(batch): input_ids_list = [] input_audio_embeds_list = [] audio_embed_sizes_list = [] audio_attention_mask_list = [] input_modes_list = [] sentence_list = [] answer_list = [] for inputs in batch: input_ids_list.append(inputs['input_ids'][0]) input_audio_embeds_list.append(inputs['input_audio_embeds']) audio_embed_sizes_list.append(inputs['audio_embed_sizes']) audio_attention_mask_list.append( inputs['input_audio_embeds'].new_full((inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool) ) input_modes_list.append(inputs['input_modes']) sentence_list.append(inputs['sentence']) answer_list.append(inputs['answer']) try: input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0) audio_attention_mask = ( pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False) if len(audio_attention_mask_list) > 1 else None ) except Exception as e: print(e) print(input_ids_list) print(audio_attention_mask) raise attention_mask = (input_ids != 0).long() input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0) audio_embed_sizes = torch.cat(audio_embed_sizes_list) input_modes = torch.cat(input_modes_list) return BatchFeature( { 'input_ids': input_ids, 'attention_mask': attention_mask, 'input_audio_embeds': input_audio_embeds, 'audio_embed_sizes': audio_embed_sizes, 'audio_attention_mask': audio_attention_mask, 'input_modes': input_modes, 'sentence': sentence_list, 'answer': answer_list, } ) def save_results(results, task, source_lang, target_lang=None, sample_idx=None): """결과를 JSON 파일로 저장""" filename = f"{task}_{source_lang}" if target_lang: filename += f"_to_{target_lang}" if sample_idx is not None: filename += f"_sample_{sample_idx}" filepath = os.path.join(results_dir, f"{filename}.json") # 결과에 타임스탬프 추가 results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") with open(filepath, 'w', encoding='utf-8') as f: json.dump(results, f, ensure_ascii=False, indent=2) print(f"결과가 {filepath}에 저장되었습니다.") return filepath def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size = 32, is_asr=True): """ASR(자동 음성 인식) 성능 평가""" task_type = "asr" if is_asr else "translation" eval_lang = source_lang if is_asr else target_lang eval_normalizer = normalizer[eval_lang] sample_results = [] # 샘플 수 처리 if num_samples > 0 and num_samples < len(dataset): indices = np.random.choice(len(dataset), num_samples, replace=False) dataset = dataset.select(indices) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=covost_collate_fn) evaluated_samples = {} # 배치 단위로 처리 for batch_idx, batch in enumerate(tqdm(dataloader)): batch_sentences = batch.pop("sentence") batch_references = batch.pop("answer") # GPU로 이동 if torch.cuda.is_available(): batch = {k: v.to("cuda") for k, v in batch.items()} # 배치 추론 with torch.inference_mode(): generate_ids = model.generate(**batch, max_new_tokens=256, do_sample=False) input_lengths = batch['input_ids'].shape[1] generate_ids = generate_ids[:, input_lengths:] # 디코딩 batch_predictions = processor.batch_decode( generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False ) # 결과 저장 for i, (sentence, reference, prediction) in enumerate(zip(batch_sentences, batch_references, batch_predictions)): idx = batch_idx * batch_size + i sample_result = { "id": idx, "sentence": sentence, "reference": reference, "prediction": prediction } sample_results.append(sample_result) # 10배치마다 중간 결과 저장 if (batch_idx + 1) % 10 == 0: temp_results = [] # 모든 샘플에 대해 처리 for item in sample_results: sample_id = item["id"] # 이미 평가된 샘플은 평가 결과를 재사용 if sample_id in evaluated_samples: temp_item = item.copy() temp_item.update(evaluated_samples[sample_id]) temp_results.append(temp_item) else: # 아직 평가되지 않은 샘플은 새로 평가 temp_item = item.copy() try: ref = eval_normalizer(item["reference"]) pred = eval_normalizer(item["prediction"]) # BLEU, WER/CER 계산 utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2) utt_wer = round(wer(ref, pred) * 100, 2) metrics = { "bleu": utt_bleu, "cer": utt_cer, "wer": utt_wer } # 평가 결과 저장 evaluated_samples[sample_id] = metrics temp_item.update(metrics) except Exception as e: print(f"Error evaluating sample {sample_id}: {e}") # 오류 발생 시 기본값 설정 metrics = { "bleu": 0, "cer": 100, "wer": 100, "error": str(e) } evaluated_samples[sample_id] = metrics temp_item.update(metrics) temp_results.append(temp_item) partial_results = { "task": task_type, "source_lang": source_lang, "target_lang": target_lang, "num_samples": len(temp_results), "sample_results": temp_results } save_results(partial_results, task_type, source_lang, target_lang) for item in sample_results: ref = eval_normalizer(item["reference"]) pred = eval_normalizer(item["prediction"]) # BLEU, WER/CER 계산 utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2) utt_wer = round(wer(ref, pred) * 100, 2) item.update({ "bleu": utt_bleu, "cer": utt_cer, "wer": utt_wer }) avg_bleu = sum(item["bleu"] for item in sample_results) / len(sample_results) avg_cer = sum(item["cer"] for item in sample_results) / len(sample_results) avg_wer = sum(item["wer"] for item in sample_results) / len(sample_results) results = { "task": task_type, "source_lang": source_lang, "target_lang": target_lang, "num_samples": len(sample_results), "metrics": { "bleu": avg_bleu, "cer": avg_cer, "wer": avg_wer }, "sample_results": sample_results } # 최종 결과 저장 save_results(results, task_type, source_lang, target_lang) return results # 메인 실행 코드 if __name__ == "__main__": # 평가할 언어 목록 (소스 언어) source_languages = [ ("en_us", "English"), # 영어 (미국) #("ko_kr", "Korean"), ] # 번역 대상 언어 목록 (코드, 이름) target_languages = [ ("ko_kr", "Korean"), #("en_us", "English"), ] data_dir = { "en_us" : "/workspace/CommonVoice/EN", #"ko_kr" : "/workspace/CommonVoice/ko", } # 샘플 수 설정 (-1은 전체 데이터셋 사용) num_samples = -1 batch_size = 16 # 모든 소스 언어에 대해 ASR 평가 for source_lang, target_lang in zip(source_languages, target_languages): print(f"\n===== {source_lang[0]} ASR 평가 시작 =====") # 데이터셋 로드 covost = CoVoSTDataset(processor, data_dir[source_lang[0]], ast=False, lang=(f"{source_lang[0].split('_')[0]}_{target_lang[0].split('_')[0]}", f"{target_lang[1]}")) # ASR 평가 asr_results = evaluate_task(covost, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr = True) print(f"\n=== {source_lang[0]} ASR 결과 ===") print(f"BLEU: {asr_results.get('metrics', {}).get('bleu', 'N/A')}") print(f"WER: {asr_results.get('metrics', {}).get('wer', 'N/A')}") print(f"CER: {asr_results.get('metrics', {}).get('cer', 'N/A')}") try: print(f"\n===== {source_lang[0]} -> {target_lang[0]} 번역 평가 시작 =====") # 데이터셋 로드 covost = CoVoSTDataset(processor, data_dir[source_lang[0]], ast=True, lang=(f"{source_lang[0].split('_')[0]}_{target_lang[0].split('_')[0]}", f"{target_lang[1]}")) # 번역 평가 translation_results = evaluate_task(covost, source_lang[0], target_lang[0], num_samples, batch_size=batch_size, is_asr = False) print(f"\n=== {source_lang[0]} -> {target_lang[0]} 번역 결과 ===") print(f"BLEU: {translation_results.get('metrics', {}).get('bleu', 'N/A')}") print(f"WER: {translation_results.get('metrics', {}).get('wer', 'N/A')}") print(f"CER: {translation_results.get('metrics', {}).get('cer', 'N/A')}") except Exception as e: error_info = { "error": str(e), "source_lang": source_lang[0], "target_lang": target_lang[0], "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S") } error_file = os.path.join(results_dir, f"error_translation_{source_lang[0]}_to_{target_lang[0]}_global.json") with open(error_file, 'w') as f: json.dump(error_info, f, indent=2) print(f"{source_lang[0]} -> {target_lang[0]} 번역 평가 중 오류 발생: {str(e)}") continue print(f"\n모든 평가가 완료되었습니다. 결과는 {results_dir} 디렉토리에 저장되었습니다.")