from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from transformers import MimiModel

DEFAULT_MIMI_MODEL_ID = "kyutai/mimi"


@dataclass(frozen=True)
class MimiConfig:
    model_id: str = DEFAULT_MIMI_MODEL_ID
    dtype: Optional[torch.dtype] = None


class MimiCodec(nn.Module):
    """Thin wrapper around transformers' MimiModel for decoding audio tokens."""

    def __init__(self, model: MimiModel, device: torch.device) -> None:
        super().__init__()
        self.model = model
        self.device = device
        cfg = getattr(model, "config", None)
        # Fall back to Mimi's published defaults (24 kHz audio, 12.5 Hz frames)
        # if the config does not expose these fields.
        self.sample_rate = getattr(cfg, "sampling_rate", 24000)
        self.frame_rate = getattr(cfg, "frame_rate", 12.5)
        self.samples_per_frame = (
            int(round(self.sample_rate / self.frame_rate)) if self.frame_rate else 0
        )

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = DEFAULT_MIMI_MODEL_ID,
        *,
        device: torch.device,
        dtype: Optional[torch.dtype] = None,
    ) -> "MimiCodec":
        model = MimiModel.from_pretrained(
            model_id,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        model = model.to(device)
        model.eval()
        return cls(model, device)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode audio codes of shape (batch, num_quantizers, frames) to waveforms."""
        codes = codes.to(self.device)
        with torch.inference_mode():
            # Use the dict output and read `audio_values` directly: the tuple
            # form drops None fields (e.g. past key values), so its arity is
            # not stable and `audio, _ = ...` can fail to unpack.
            out = self.model.decode(codes, return_dict=True)
            audio = out.audio_values
        return torch.clamp(audio, -1.0, 1.0)

    def encode(self, audio: torch.Tensor, *, return_dict: bool = False):
        audio = audio.to(self.device)
        with torch.inference_mode():
            return self.model.encode(audio, return_dict=return_dict)
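

if __name__ == "__main__":
    # A minimal usage sketch, not part of the wrapper itself: round-trip one
    # second of silence through the codec. Shapes and attribute names
    # (`audio_codes`, `audio_values`) follow transformers' Mimi output
    # classes as documented; adjust if your transformers version differs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    codec = MimiCodec.from_pretrained(device=device)

    # Mimi expects waveforms shaped (batch, channels, samples); encode one
    # second of mono audio at the codec's native sample rate.
    silence = torch.zeros(1, 1, codec.sample_rate)
    encoded = codec.encode(silence, return_dict=True)
    codes = encoded.audio_codes  # (batch, num_quantizers, frames)

    # Decode back to a waveform clamped to [-1, 1].
    waveform = codec.decode(codes)
    print(codes.shape, waveform.shape)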