from __future__ import annotations

from pathlib import Path
from typing import Union

import numpy as np
import sphn
import torch
import torch.nn.functional as F

from ..audio import MimiCodec

PathLike = Union[str, Path]


def load_mono_audio(path: PathLike, target_sr: int) -> np.ndarray:
    """Read an audio file, convert it to mono float32, and resample to ``target_sr``."""
    path = str(path)
    try:
        audio, sr = sphn.read_wav(path)
    except Exception:
        # Fall back to soundfile when sphn is missing or cannot decode the file.
        import soundfile as sf

        audio, sr = sf.read(path, dtype="float32", always_2d=False)
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim == 2:
        # Downmix to mono. Readers disagree on layout ((channels, samples) for
        # sphn vs. (samples, channels) for soundfile), so treat the smaller
        # axis as the channel axis.
        channel_axis = 0 if audio.shape[0] < audio.shape[1] else 1
        audio = audio.mean(axis=channel_axis)
    if sr != target_sr:
        if hasattr(sphn, "resample_audio"):
            audio = sphn.resample_audio(audio, sr, target_sr).astype(np.float32)
        else:
            audio = _resample_linear(audio, sr, target_sr)
    return audio
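
# Usage sketch ("speech.wav" is a hypothetical local file; 24 kHz is Mimi's
# native sample rate):
#     pcm = load_mono_audio("speech.wav", target_sr=24_000)
#     pcm.dtype, pcm.ndim  # float32, 1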


def audio_to_tensor(audio: np.ndarray, device: torch.device) -> torch.Tensor:
    """Convert mono PCM samples into a [1, 1, T] tensor on ``device``."""
    tensor = torch.from_numpy(audio).to(device)
    # Promote [T] -> [1, T] -> [1, 1, T] (add batch and channel dimensions).
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 2:
        tensor = tensor.unsqueeze(0)
    return tensor
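
# Shape sketch: one second of 24 kHz audio becomes a [1, 1, 24000] tensor:
#     audio_to_tensor(np.zeros(24_000, dtype=np.float32), torch.device("cpu")).shape
#     # torch.Size([1, 1, 24000])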


def encode_audio_tokens(mimi: MimiCodec, audio: np.ndarray) -> torch.Tensor:
    """Encode PCM audio into Mimi codebook tokens of shape [C, T]."""
    waveform = audio_to_tensor(audio, mimi.device)
    with torch.inference_mode():
        codes, *_ = mimi.encode(waveform, return_dict=False)
    # Some codec versions wrap the codes in a tuple or list; unwrap if so.
    if isinstance(codes, (tuple, list)):
        codes = codes[0]

    # Drop the batch dimension: [1, C, T] -> [C, T].
    codes = codes[0].to(torch.long)
    return codes
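
# Usage sketch (assumes `mimi` is a loaded MimiCodec and `pcm` comes from
# load_mono_audio at the codec's sample rate):
#     codes = encode_audio_tokens(mimi, pcm)
#     codes.shape  # (C, T): C codebooks by T frames, far fewer frames than samples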


def _resample_linear(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Resample via linear interpolation: a crude, dependency-free fallback
    (no anti-aliasing filter, so downsampling may alias)."""
    if src_sr == dst_sr:
        return audio.astype(np.float32)
    length = audio.shape[0]
    new_length = max(1, int(round(length * dst_sr / src_sr)))
    tensor = torch.from_numpy(audio.astype(np.float32)).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        resampled = F.interpolate(tensor, size=new_length, mode="linear", align_corners=False)
    return resampled.squeeze(0).squeeze(0).cpu().numpy().astype(np.float32)
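
# Sketch: upsampling 16 kHz -> 24 kHz scales the length by 3/2:
#     _resample_linear(np.zeros(16_000, dtype=np.float32), 16_000, 24_000).shape
#     # (24000,)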


__all__ = ["load_mono_audio", "audio_to_tensor", "encode_audio_tokens"]