# Dia2-2B / dia2/audio/grid.py
# Uploaded by NariLabs using huggingface_hub (commit 1315cad, verified).
# raw | history | blame | 2.32 kB
from __future__ import annotations
from pathlib import Path
from typing import Sequence
import numpy as np
import torch
def delay_frames(aligned: torch.Tensor, delays: Sequence[int], pad_id: int) -> torch.Tensor:
    """Shift each row of ``aligned`` to the right by its per-row delay.

    Row ``i`` of the input is copied into row ``i`` of the output starting at
    column ``delays[i]``. Every position not covered by a copied row holds
    ``pad_id``. Rows without a matching entry in ``delays`` are all padding.

    Args:
        aligned: 2-D tensor unpacked as ``(channels, frames)``.
        delays: Right shift applied to each row, indexed by row number.
        pad_id: Fill value for uncovered positions.

    Returns:
        A new tensor with shape ``(channels, frames + max(delays))``. The
        extra width is 0 when ``delays`` is empty.
    """
    num_channels, num_frames = aligned.shape
    longest = 0 if not delays else max(delays)
    shifted = aligned.new_full((num_channels, num_frames + longest), pad_id)
    for row, offset in enumerate(delays):
        stop = offset + num_frames
        shifted[row, offset:stop] = aligned[row]
    return shifted
def undelay_frames(delayed: torch.Tensor, delays: Sequence[int], pad_id: int) -> torch.Tensor:
    """Undo a per-row right shift, cropping every row to a common length.

    The output width is ``frames - max(delays)``, clamped at 0. Row ``i`` is
    read from ``delayed[i]`` starting at column ``delays[i]``. Rows without a
    matching entry in ``delays`` are filled with ``pad_id``.

    Args:
        delayed: 2-D tensor unpacked as ``(channels, frames)``.
        delays: Per-row offset to read from.
        pad_id: Fill value for rows that are not restored.

    Returns:
        A new tensor with shape ``(channels, max(0, frames - max(delays)))``.
    """
    num_channels, num_frames = delayed.shape
    longest = 0 if not delays else max(delays)
    keep = num_frames - longest
    if keep < 0:
        # Input shorter than the largest delay: nothing can be restored.
        keep = 0
    restored = delayed.new_full((num_channels, keep), pad_id)
    for row, offset in enumerate(delays):
        restored[row] = delayed[row, offset : offset + keep]
    return restored
def mask_audio_logits(logits: torch.Tensor, pad_idx: int, bos_idx: int) -> torch.Tensor:
    """Suppress the pad and BOS entries along the last axis of ``logits``.

    Indices outside ``[0, logits.shape[-1])`` are ignored. If no index
    survives that filter, which includes the case of an empty last axis, the
    input tensor itself is returned. Otherwise a clone is returned with the
    surviving positions set to the dtype's most negative finite value, and
    the input is left unmodified.

    Args:
        logits: Floating-point tensor; the last axis is indexed by token id.
        pad_idx: Token id to suppress.
        bos_idx: Token id to suppress.

    Returns:
        ``logits`` unchanged, or a masked copy.
    """
    vocab_size = logits.shape[-1]
    banned = [i for i in (pad_idx, bos_idx) if 0 <= i < vocab_size]
    if not banned:
        return logits
    result = logits.clone()
    # Most negative *finite* value rather than -inf (presumably so that
    # downstream softmax/sampling stays finite — confirm with callers).
    result[..., banned] = torch.finfo(result.dtype).min
    return result
def fill_audio_channels(
    delays: Sequence[int],
    constants,
    step: int,
    step_tokens: torch.Tensor,
    audio_buf: torch.Tensor,
) -> None:
    """Write the audio-codebook inputs for decoding step ``step`` in place.

    For each codebook ``cb`` (one per entry of ``delays``), the slice
    ``step_tokens[:, 2 + cb, 0]`` receives one of two values:

    * ``audio_buf[:, cb, step]`` when ``delay <= step < audio_buf.shape[-1]``;
    * ``constants.audio_bos`` otherwise.

    Rows 0 and 1 of ``step_tokens`` are not touched.

    Args:
        delays: Per-codebook delay; its length sets how many rows are written.
        constants: Object exposing an ``audio_bos`` token id.
        step: Current decoding position.
        step_tokens: Output tensor, presumably ``(batch, 2 + codebooks, 1)``
            (TODO confirm against callers).
        audio_buf: Token buffer indexed as ``[:, codebook, position]``,
            presumably ``(batch, codebooks, time)``.
    """
    for codebook, delay in enumerate(delays):
        row = 2 + codebook
        # The channel has started (step >= delay) and the step lies inside the buffer.
        if delay <= step < audio_buf.shape[-1]:
            step_tokens[:, row, 0] = audio_buf[:, codebook, step]
        else:
            step_tokens[:, row, 0] = constants.audio_bos
def write_wav(path: str | Path, audio: np.ndarray, sample_rate: int) -> None:
    """Write ``audio`` to ``path`` as a mono, 16-bit PCM WAV file.

    Parent directories are created as needed. Samples are interpreted as
    floats nominally in ``[-1.0, 1.0]``:

    * values outside that range are clipped;
    * NaNs are written as silence (0) and infinities as full scale;
    * each sample is rounded to the nearest int16 step (scale 32767).

    The samples are flattened in C order, so a multi-dimensional input is
    written as a single mono stream.

    Args:
        path: Destination file path.
        audio: Array-like of float samples.
        sample_rate: Frame rate stored in the WAV header.
    """
    import wave

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    samples = np.asarray(audio, dtype=np.float64)
    # NaN survives np.clip and would cast to an undefined int16 value.
    samples = np.nan_to_num(samples, nan=0.0, posinf=1.0, neginf=-1.0)
    samples = np.clip(samples, -1.0, 1.0)
    # Round rather than truncate toward zero: truncation biases magnitudes
    # downward and doubles the quantization dead zone around 0.
    # Native-endian int16 on purpose: the wave module byte-swaps itself on
    # big-endian hosts.
    pcm16 = np.round(samples * 32767.0).astype(np.int16)
    with wave.open(str(path), "wb") as handle:
        handle.setnchannels(1)
        handle.setsampwidth(2)
        handle.setframerate(sample_rate)
        handle.writeframes(pcm16.tobytes())
# Public names exported by ``from <module> import *``.
__all__ = [
    "delay_frames", "undelay_frames", "mask_audio_logits",
    "fill_audio_channels", "write_wav",
]