# dia2/audio/codec.py
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from transformers import MimiModel

DEFAULT_MIMI_MODEL_ID = "kyutai/mimi"


@dataclass(frozen=True)
class MimiConfig:
    """Options for loading the Mimi codec."""

    model_id: str = DEFAULT_MIMI_MODEL_ID
    dtype: Optional[torch.dtype] = None


class MimiCodec(nn.Module):
    """Thin wrapper around transformers' MimiModel for decoding audio tokens."""

    def __init__(self, model: MimiModel, device: torch.device) -> None:
        super().__init__()
        self.model = model
        self.device = device
        # Read the audio geometry from the model config, falling back to the
        # published Mimi defaults (24 kHz audio at 12.5 frames per second).
        cfg = getattr(model, "config", None)
        self.sample_rate = getattr(cfg, "sampling_rate", 24000)
        self.frame_rate = getattr(cfg, "frame_rate", 12.5)
        # Waveform samples covered by one codec frame, e.g. 24000 / 12.5 = 1920.
        self.samples_per_frame = int(round(self.sample_rate / self.frame_rate)) if self.frame_rate else 0

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = DEFAULT_MIMI_MODEL_ID,
        *,
        device: torch.device,
        dtype: Optional[torch.dtype] = None,
    ) -> "MimiCodec":
        model = MimiModel.from_pretrained(
            model_id,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        model = model.to(device)
        model.eval()
        return cls(model, device)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode audio tokens of shape (batch, num_quantizers, frames) to a waveform."""
        codes = codes.to(self.device)
        with torch.inference_mode():
            audio, _ = self.model.decode(codes, return_dict=False)
        # Keep the waveform within the valid PCM range.
        return torch.clamp(audio, -1.0, 1.0)

    def encode(self, audio: torch.Tensor, *, return_dict: bool = False):
        """Encode a waveform of shape (batch, channels, samples) into audio tokens."""
        audio = audio.to(self.device)
        with torch.inference_mode():
            return self.model.encode(audio, return_dict=return_dict)