| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Sequence | |
| import numpy as np | |
| import torch | |
| def delay_frames(aligned: torch.Tensor, delays: Sequence[int], pad_id: int) -> torch.Tensor: | |
| channels, total = aligned.shape | |
| max_delay = max(delays) if delays else 0 | |
| out = aligned.new_full((channels, total + max_delay), pad_id) | |
| for idx, delay in enumerate(delays): | |
| out[idx, delay : delay + total] = aligned[idx] | |
| return out | |
| def undelay_frames(delayed: torch.Tensor, delays: Sequence[int], pad_id: int) -> torch.Tensor: | |
| channels, total = delayed.shape | |
| max_delay = max(delays) if delays else 0 | |
| target = max(0, total - max_delay) | |
| out = delayed.new_full((channels, target), pad_id) | |
| for idx, delay in enumerate(delays): | |
| out[idx] = delayed[idx, delay : delay + target] | |
| return out | |
| def mask_audio_logits(logits: torch.Tensor, pad_idx: int, bos_idx: int) -> torch.Tensor: | |
| if logits.shape[-1] == 0: | |
| return logits | |
| max_idx = logits.shape[-1] - 1 | |
| targets = [idx for idx in (pad_idx, bos_idx) if 0 <= idx <= max_idx] | |
| if not targets: | |
| return logits | |
| masked = logits.clone() | |
| neg_inf = torch.finfo(masked.dtype).min | |
| for idx in targets: | |
| masked[..., idx] = neg_inf | |
| return masked | |
| def fill_audio_channels( | |
| delays: Sequence[int], | |
| constants, | |
| step: int, | |
| step_tokens: torch.Tensor, | |
| audio_buf: torch.Tensor, | |
| ) -> None: | |
| for cb, delay in enumerate(delays): | |
| idx = step - delay | |
| in_bounds = idx >= 0 and step < audio_buf.shape[-1] | |
| if in_bounds: | |
| step_tokens[:, 2 + cb, 0] = audio_buf[:, cb, step] | |
| else: | |
| step_tokens[:, 2 + cb, 0] = constants.audio_bos | |
| def write_wav(path: str | Path, audio: np.ndarray, sample_rate: int) -> None: | |
| path = Path(path) | |
| path.parent.mkdir(parents=True, exist_ok=True) | |
| audio = np.clip(audio, -1.0, 1.0) | |
| pcm16 = (audio * 32767.0).astype(np.int16) | |
| import wave | |
| with wave.open(str(path), "wb") as handle: | |
| handle.setnchannels(1) | |
| handle.setsampwidth(2) | |
| handle.setframerate(sample_rate) | |
| handle.writeframes(pcm16.tobytes()) | |
| __all__ = [ | |
| "delay_frames", | |
| "undelay_frames", | |
| "mask_audio_logits", | |
| "fill_audio_channels", | |
| "write_wav", | |
| ] | |