| | import argparse |
| | from pathlib import Path |
| |
|
| | import torch |
| | import torchaudio |
| | from transformers import ( |
| | AutoConfig, |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | WhisperFeatureExtractor, |
| | ) |
| |
|
| |
|
| | WHISPER_FEAT_CFG = { |
| | "chunk_length": 30, |
| | "feature_extractor_type": "WhisperFeatureExtractor", |
| | "feature_size": 128, |
| | "hop_length": 160, |
| | "n_fft": 400, |
| | "n_samples": 480000, |
| | "nb_max_frames": 3000, |
| | "padding_side": "right", |
| | "padding_value": 0.0, |
| | "processor_class": "WhisperProcessor", |
| | "return_attention_mask": False, |
| | "sampling_rate": 16000, |
| | } |
| |
|
| | def get_audio_token_length(seconds, merge_factor=2): |
| | def get_T_after_cnn(L_in, dilation=1): |
| | for padding, kernel_size, stride in eval("[(1,3,1)] + [(1,3,2)] "): |
| | L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1 |
| | L_out = 1 + L_out // stride |
| | L_in = L_out |
| | return L_out |
| |
|
| | mel_len = int(seconds * 100) |
| | audio_len_after_cnn = get_T_after_cnn(mel_len) |
| | audio_token_num = (audio_len_after_cnn - merge_factor) // merge_factor + 1 |
| |
|
| | |
| | audio_token_num = min(audio_token_num, 1500 // merge_factor) |
| |
|
| | return audio_token_num |
| |
|
| | def build_prompt( |
| | audio_path: Path, |
| | tokenizer, |
| | feature_extractor: WhisperFeatureExtractor, |
| | merge_factor: int, |
| | chunk_seconds: int = 30, |
| | ) -> dict: |
| | audio_path = Path(audio_path) |
| | wav, sr = torchaudio.load(str(audio_path)) |
| | wav = wav[:1, :] |
| | if sr != feature_extractor.sampling_rate: |
| | wav = torchaudio.transforms.Resample(sr, feature_extractor.sampling_rate)(wav) |
| |
|
| | tokens = [] |
| | tokens += tokenizer.encode("<|user|>") |
| | tokens += tokenizer.encode("\n") |
| |
|
| | audios = [] |
| | audio_offsets = [] |
| | audio_length = [] |
| | chunk_size = chunk_seconds * feature_extractor.sampling_rate |
| | for start in range(0, wav.shape[1], chunk_size): |
| | chunk = wav[:, start : start + chunk_size] |
| | mel = feature_extractor( |
| | chunk.numpy(), |
| | sampling_rate=feature_extractor.sampling_rate, |
| | return_tensors="pt", |
| | padding="max_length", |
| | )["input_features"] |
| | audios.append(mel) |
| | seconds = chunk.shape[1] / feature_extractor.sampling_rate |
| | num_tokens = get_audio_token_length(seconds, merge_factor) |
| | tokens += tokenizer.encode("<|begin_of_audio|>") |
| | audio_offsets.append(len(tokens)) |
| | tokens += [0] * num_tokens |
| | tokens += tokenizer.encode("<|end_of_audio|>") |
| | audio_length.append(num_tokens) |
| |
|
| | if not audios: |
| | raise ValueError("音频内容为空或加载失败。") |
| |
|
| | tokens += tokenizer.encode("<|user|>") |
| | tokens += tokenizer.encode("\nPlease transcribe this audio into text") |
| |
|
| | tokens += tokenizer.encode("<|assistant|>") |
| | tokens += tokenizer.encode("\n") |
| |
|
| | batch = { |
| | "input_ids": torch.tensor([tokens], dtype=torch.long), |
| | "audios": torch.cat(audios, dim=0), |
| | "audio_offsets": [audio_offsets], |
| | "audio_length": [audio_length], |
| | "attention_mask": torch.ones(1, len(tokens), dtype=torch.long), |
| | } |
| | return batch |
| |
|
| |
|
| | def prepare_inputs(batch: dict, device: torch.device) -> tuple[dict, int]: |
| | tokens = batch["input_ids"].to(device) |
| | attention_mask = batch["attention_mask"].to(device) |
| | audios = batch["audios"].to(device) |
| | model_inputs = { |
| | "inputs": tokens, |
| | "attention_mask": attention_mask, |
| | "audios": audios.to(torch.bfloat16), |
| | "audio_offsets": batch["audio_offsets"], |
| | "audio_length": batch["audio_length"], |
| | } |
| | return model_inputs, tokens.size(1) |
| |
|
| |
|
| | def transcribe( |
| | checkpoint_dir: Path, |
| | audio_path: Path, |
| | tokenizer_path: str | None, |
| | max_new_tokens: int, |
| | device: str, |
| | ): |
| | tokenizer_source = tokenizer_path if tokenizer_path else checkpoint_dir |
| | tokenizer = AutoTokenizer.from_pretrained(tokenizer_source) |
| | feature_extractor = WhisperFeatureExtractor(**WHISPER_FEAT_CFG) |
| |
|
| | config = AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=True) |
| | model = AutoModelForCausalLM.from_pretrained( |
| | checkpoint_dir, |
| | config=config, |
| | torch_dtype=torch.bfloat16, |
| | trust_remote_code=True, |
| | ).to(device) |
| | model.eval() |
| |
|
| | batch = build_prompt( |
| | audio_path, |
| | tokenizer, |
| | feature_extractor, |
| | merge_factor=config.merge_factor, |
| | ) |
| |
|
| | model_inputs, prompt_len = prepare_inputs(batch, device) |
| |
|
| | with torch.inference_mode(): |
| | generated = model.generate( |
| | **model_inputs, |
| | max_new_tokens=max_new_tokens, |
| | do_sample=False, |
| | ) |
| | transcript_ids = generated[0, prompt_len:].cpu().tolist() |
| | transcript = tokenizer.decode(transcript_ids, skip_special_tokens=True).strip() |
| | print("----------") |
| | print(transcript or "[Empty transcription]") |
| |
|
| |
|
| | def main(): |
| | parser = argparse.ArgumentParser(description="Minimal ASR transcription demo.") |
| | parser.add_argument("--checkpoint_dir", type=str, default=str(Path(__file__).parent)) |
| | parser.add_argument("--audio", type=str, required=True, help="Path to audio file.") |
| | parser.add_argument( |
| | "--tokenizer_path", |
| | type=str, |
| | default=None, |
| | help="Tokenizer directory (defaults to checkpoint dir when omitted).", |
| | ) |
| | parser.add_argument("--max_new_tokens", type=int, default=128) |
| | parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") |
| | args = parser.parse_args() |
| |
|
| | transcribe( |
| | checkpoint_dir=Path(args.checkpoint_dir), |
| | audio_path=Path(args.audio), |
| | tokenizer_path=args.tokenizer_path, |
| | max_new_tokens=args.max_new_tokens, |
| | device=args.device, |
| | ) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|