|
|
from __future__ import annotations

import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Optional

import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from ..audio import DEFAULT_MIMI_MODEL_ID, MimiCodec
from ..config import DiaConfig, load_config
from ..core.model import Dia2Model
from ..core.precision import Precision, resolve_precision
from .state_machine import StateMachine, TokenIds
|
|
|
|
|
|
|
|
@dataclass
class RuntimeContext:
    """Bundle of the objects the generation loop needs, assembled by ``build_runtime``.

    A plain (mutable) dataclass; fields are positional in the order declared below,
    so the order must not change.
    """

    # Parsed model configuration.
    config: DiaConfig
    # Dia2 model with weights loaded; ``build_runtime`` moves it to ``device``.
    model: Dia2Model
    # Precision resolved from the user's dtype preference; the model is built with it.
    precision: Precision
    tokenizer: PreTrainedTokenizerBase
    # Mimi audio codec; ``build_runtime`` loads it on ``device``.
    mimi: MimiCodec
    device: torch.device
    # Decoding state machine configured with ``constants``.
    machine: StateMachine
    # Bound ``model.transformer.forward_step`` / ``model.depformer.forward_step``.
    # Their exact call signatures live in the model code, hence ``Callable[..., Any]``.
    transformer_step: Callable[..., Any]
    depformer_step: Callable[..., Any]
    # Special text/audio token ids (pad, bos, speaker tags, ...).
    constants: TokenIds
    # Audio delay pattern from ``config.data.delay_pattern``, both as a Python list
    # and as a long tensor (``build_runtime`` places the tensor on ``device``).
    audio_delays: list[int]
    audio_delay_tensor: torch.Tensor
    # Codec frame rate; ``build_runtime`` falls back to 75.0 if the codec lacks one.
    frame_rate: float
|
|
|
|
|
|
|
|
def _enable_tf32(device_obj: torch.device) -> None:
    """Allow TF32 math for CUDA matmuls and cuDNN convolutions.

    Does nothing unless ``device_obj.type == "cuda"``. The flags touched are
    process-wide ``torch.backends`` settings, so they affect every model in the
    process, not only the one being built.

    When a backend exposes the newer ``fp32_precision`` setting it is set to
    ``"tf32"``. The legacy ``allow_tf32`` flag is always set too, with warnings
    matching "Please use the new API settings" suppressed.
    """
    if device_obj.type != "cuda":
        return

    matmul_backend = torch.backends.cuda.matmul
    # ``torch.backends.cudnn.conv`` belongs to the newer fine-grained TF32 API and
    # is absent on older PyTorch builds; look it up defensively so those builds
    # reach the legacy ``allow_tf32`` path instead of raising AttributeError.
    conv_backend = getattr(torch.backends.cudnn, "conv", None)

    if hasattr(matmul_backend, "fp32_precision"):
        matmul_backend.fp32_precision = "tf32"
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="Please use the new API settings",
        )
        torch.backends.cuda.matmul.allow_tf32 = True

    if conv_backend is not None and hasattr(conv_backend, "fp32_precision"):
        conv_backend.fp32_precision = "tf32"
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="Please use the new API settings",
        )
        torch.backends.cudnn.allow_tf32 = True


def _bos_token_id(tokenizer: PreTrainedTokenizerBase) -> int:
    """Return the tokenizer's BOS id, or ``1`` when it has none.

    Only a missing attribute or ``None`` triggers the fallback; a genuine id of
    ``0`` is kept rather than being mistaken for "missing".
    """
    bos_id = getattr(tokenizer, "bos_token_id", None)
    return 1 if bos_id is None else bos_id


def _special_token_id(
    tokenizer: PreTrainedTokenizerBase,
    vocab: dict[str, int],
    token: str,
    fallback: int,
) -> int:
    """Return the id of *token* if it is in *vocab*, otherwise *fallback*."""
    if token in vocab:
        return tokenizer.convert_tokens_to_ids(token)
    return fallback


def build_runtime(
    *,
    config_path: str | Path,
    weights_path: str | Path,
    tokenizer_id: Optional[str],
    repo_id: Optional[str],
    mimi_id: Optional[str],
    device: str,
    dtype_pref: str,
) -> tuple[RuntimeContext, str, str]:
    """Load config, weights, tokenizer and Mimi codec into a ``RuntimeContext``.

    Args:
        config_path: Model config file, read with ``load_config``.
        weights_path: safetensors checkpoint loaded into ``Dia2Model``.
        tokenizer_id: Explicit tokenizer reference; takes priority over
            ``config.assets.tokenizer`` and then ``repo_id``.
        repo_id: Last-resort tokenizer reference.
        mimi_id: Explicit Mimi codec reference; falls back to
            ``config.assets.mimi`` and then ``DEFAULT_MIMI_MODEL_ID``.
        device: Torch device string (e.g. ``"cuda"`` or ``"cpu"``). On CUDA,
            TF32 is enabled process-wide as a side effect.
        dtype_pref: Precision preference passed to ``resolve_precision``.

    Returns:
        ``(runtime, tokenizer_ref, mimi_ref)``: the assembled context plus the
        tokenizer and Mimi references that were actually used.

    Raises:
        ValueError: If no tokenizer reference can be resolved.
    """
    device_obj = torch.device(device)
    _enable_tf32(device_obj)

    precision = resolve_precision(dtype_pref, device_obj)
    config = load_config(config_path)

    model = Dia2Model(config, precision)
    state = load_file(str(weights_path))
    model.load_state_dict(state)
    # With nn.Module's default ``load_state_dict`` (assign=False) the values are
    # copied into the model's own tensors, so release the checkpoint dict before
    # moving the model to keep peak memory down.
    del state
    model = model.to(device_obj)

    tokenizer_ref = tokenizer_id or config.assets.tokenizer or repo_id
    # ``not`` also rejects an empty string, which would otherwise reach from_pretrained.
    if not tokenizer_ref:
        raise ValueError("Tokenizer id is missing. Provide --tokenizer or add assets.tokenizer to the config.")
    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # tokenizer repository; only point tokenizer_ref at repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_ref,
        use_fast=False,
        trust_remote_code=True,
    )

    mimi_ref = mimi_id or config.assets.mimi or DEFAULT_MIMI_MODEL_ID
    mimi = MimiCodec.from_pretrained(mimi_ref, device=device_obj)

    data_cfg = config.data
    # get_vocab() may build a fresh dict on each call, so fetch it once.
    vocab = tokenizer.get_vocab()
    constants = TokenIds(
        card=data_cfg.text_vocab_size,
        new_word=data_cfg.text_new_word_token_id,
        pad=data_cfg.text_pad_token_id,
        bos=_bos_token_id(tokenizer),
        zero=data_cfg.text_zero_token_id,
        # Speaker tags fall back to the new-word token when the tokenizer lacks them.
        spk1=_special_token_id(tokenizer, vocab, "[S1]", data_cfg.text_new_word_token_id),
        spk2=_special_token_id(tokenizer, vocab, "[S2]", data_cfg.text_new_word_token_id),
        audio_pad=data_cfg.audio_pad_token_id,
        audio_bos=data_cfg.audio_bos_token_id,
    )

    machine = StateMachine(
        token_ids=constants,
        second_stream_ahead=data_cfg.second_stream_ahead,
        max_padding=6,
        initial_padding=0,
    )

    audio_delays = list(data_cfg.delay_pattern)
    # With an explicit dtype, torch.tensor([]) already yields an empty 1-D tensor,
    # so an empty delay pattern needs no special case.
    audio_delay_tensor = torch.tensor(audio_delays, dtype=torch.long, device=device_obj)
    # NOTE(review): 75.0 is only a fallback for codecs without a ``frame_rate``
    # attribute; confirm it matches the codec actually in use.
    frame_rate = getattr(mimi, "frame_rate", 75.0)

    runtime = RuntimeContext(
        config=config,
        precision=precision,
        model=model,
        tokenizer=tokenizer,
        mimi=mimi,
        device=device_obj,
        machine=machine,
        constants=constants,
        audio_delays=audio_delays,
        audio_delay_tensor=audio_delay_tensor,
        frame_rate=frame_rate,
        transformer_step=model.transformer.forward_step,
        depformer_step=model.depformer.forward_step,
    )
    return runtime, tokenizer_ref, mimi_ref
|
|
|
|
|
|
|
|
# Public names exported by a star-import of this module.
__all__ = ["RuntimeContext", "build_runtime"]
|
|
|