# Dia2-2B/dia2/runtime/audio_io.py
# Uploaded by NariLabs using huggingface_hub (commit 1315cad, verified), 2.33 kB.
from __future__ import annotations
from pathlib import Path
from typing import Union
import numpy as np
import sphn
import torch
import torch.nn.functional as F
from ..audio import MimiCodec
# Type alias for path arguments: either a plain string or a pathlib.Path.
PathLike = Union[str, Path]
def load_mono_audio(path: PathLike, target_sr: int) -> np.ndarray:
    """Read an audio file, convert to mono float32, and resample to target_sr.

    Decoding is attempted with ``sphn.read_wav`` first. If that raises for any
    reason, ``soundfile`` is imported locally and used instead.

    Args:
        path: Location of the audio file (string or ``pathlib.Path``).
        target_sr: Sample rate of the returned signal.

    Returns:
        A float32 array of mono samples at ``target_sr``.

    Raises:
        Exception: Whatever the ``soundfile`` fallback raises when it cannot
            read the file either.
    """
    path = str(path)
    # Axis holding the channels of a 2-D result; None means "not known yet".
    channel_axis = None
    try:
        audio, sr = sphn.read_wav(path)
    except Exception:
        # Deliberate best-effort: any sphn failure falls through to soundfile.
        import soundfile as sf  # Local fallback

        audio, sr = sf.read(path, dtype="float32", always_2d=False)
        # soundfile lays multi-channel data out as (frames, channels).
        channel_axis = 1
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim == 2:
        if channel_axis is None:
            # NOTE(review): the layout of sphn's 2-D output is not visible
            # from here (sphn.read documents channels-first) - confirm.
            # Treat the shorter axis as channels so either layout is averaged
            # across channels rather than across time.
            channel_axis = 0 if audio.shape[0] < audio.shape[1] else 1
        audio = audio.mean(axis=channel_axis)
    if sr != target_sr:
        if hasattr(sphn, "resample_audio"):
            audio = sphn.resample_audio(audio, sr, target_sr).astype(np.float32)
        else:
            # Installed sphn lacks resample_audio: use the local linear resampler.
            audio = _resample_linear(audio, sr, target_sr)
    return audio
def audio_to_tensor(audio: np.ndarray, device: torch.device) -> torch.Tensor:
    """Move PCM samples onto ``device`` and pad leading axes up to rank 3.

    A 1-D array of T samples becomes a ``[1, 1, T]`` tensor and a 2-D array
    gains a single leading axis. Arrays of any other rank are returned with
    their shape unchanged.
    """
    samples = torch.from_numpy(audio).to(device)
    # Number of leading singleton axes to prepend, keyed by the current rank.
    missing_axes = {1: 2, 2: 1}.get(samples.dim(), 0)
    for _ in range(missing_axes):
        samples = samples.unsqueeze(0)
    return samples
def encode_audio_tokens(mimi: MimiCodec, audio: np.ndarray) -> torch.Tensor:
    """Run PCM audio through the Mimi codec and return its tokens as [C, T].

    The samples are first moved to ``mimi.device`` via ``audio_to_tensor``.
    The returned tensor is the first batch entry of the codes, cast to
    ``torch.long``.
    """
    batch = audio_to_tensor(audio, mimi.device)
    with torch.inference_mode():
        first, *_ = mimi.encode(batch, return_dict=False)
        # The first output may itself be a tuple/list; unwrap one level.
        tokens = first[0] if isinstance(first, (tuple, list)) else first
    # Mimi.encode returns [B, num_codebooks, T]; select batch 0.
    return tokens[0].to(torch.long)
def _resample_linear(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Resample a 1-D signal from ``src_sr`` to ``dst_sr`` by linear interpolation.

    No low-pass filtering is applied before interpolating, so this is a simple
    fallback rather than a band-limited resampler.

    Args:
        audio: One-dimensional array of samples.
        src_sr: Sample rate of ``audio``.
        dst_sr: Desired sample rate.

    Returns:
        A new float32 array. Its length is ``round(len(audio) * dst_sr / src_sr)``
        (at least 1) for non-empty input, and 0 for empty input.
    """
    samples = audio.astype(np.float32)
    length = samples.shape[0]
    # Empty input has nothing to interpolate; F.interpolate would raise on it.
    if src_sr == dst_sr or length == 0:
        return samples
    new_length = max(1, int(round(length * dst_sr / src_sr)))
    # F.interpolate in "linear" mode expects a [batch, channels, time] tensor.
    tensor = torch.from_numpy(samples).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        resampled = F.interpolate(tensor, size=new_length, mode="linear", align_corners=False)
    return resampled.squeeze(0).squeeze(0).cpu().numpy().astype(np.float32)
# Names exported by a star-import of this module; the `_resample_linear`
# helper is not listed.
__all__ = ["load_mono_audio", "audio_to_tensor", "encode_audio_tokens"]