# echo-tts-preview / inference.py
# Origin: Hugging Face repo file, commit 4ad0649 ("Update inference.py", GaboChoropan).
from dataclasses import dataclass
from typing import Callable, List
import torch
import safetensors.torch as st
from huggingface_hub import hf_hub_download
from model import EchoDiT
from autoencoder import build_ae, DAC
import torchaudio
from torchcodec.decoders import AudioDecoder
# ============================================================
# Types
# ============================================================
# Signature of a pluggable sampling strategy used by the pipelines below.
# NOTE(review): voice_conversion_pipeline additionally passes the keyword
# arguments `init_latent` and `force_speaker_kv`, which this alias does not
# capture — confirm that concrete sample_fn implementations accept them.
SampleFn = Callable[
    [
        EchoDiT,
        torch.Tensor,  # speaker_latent
        torch.Tensor,  # speaker_mask
        torch.Tensor,  # text_ids
        torch.Tensor,  # text_mask
        int,  # seed
    ],
    torch.Tensor,  # sampled latent sequence
]
# ============================================================
# Loading
# ============================================================
def load_model_from_hf(
    repo_id: str = "jordand/echo-tts-base",
    device: str = "cuda",
    dtype: torch.dtype | None = torch.bfloat16,
    compile: bool = False,
    token: str | None = None,
) -> EchoDiT:
    """Build EchoDiT on the meta device and fill it with Hub weights.

    The checkpoint is downloaded from ``repo_id``, optionally cast to
    ``dtype``, moved to ``device``, and assigned into the meta-initialized
    module via ``load_state_dict(..., assign=True)``. With ``compile=True``
    the model and its KV-cache builder are wrapped in ``torch.compile``.
    Returns the model in eval mode.
    """
    # Construct the architecture without allocating real storage; actual
    # tensors arrive through assign=True below.
    with torch.device("meta"):
        model = EchoDiT(
            latent_size=80,
            model_size=2048,
            num_layers=24,
            num_heads=16,
            intermediate_size=5888,
            norm_eps=1e-5,
            max_seq_len=640,
            text_vocab_size=256,
            text_model_size=1280,
            text_num_layers=14,
            text_num_heads=10,
            text_intermediate_size=3328,
            text_max_seq_len=768,
            speaker_patch_size=4,
            speaker_model_size=1280,
            speaker_num_layers=14,
            speaker_num_heads=10,
            speaker_intermediate_size=3328,
            speaker_max_patched_seq_len=640,
            timestep_embed_size=512,
            adaln_rank=256,
        )

    weights_path = hf_hub_download(repo_id, "pytorch_model.safetensors", token=token)
    state_dict = st.load_file(weights_path, device="cpu")

    def _place(t: torch.Tensor) -> torch.Tensor:
        # Cast (when requested) and move in a single .to() call.
        return t.to(device=device) if dtype is None else t.to(device=device, dtype=dtype)

    state_dict = {name: _place(tensor) for name, tensor in state_dict.items()}
    # assign=True swaps the meta parameters for the loaded tensors in place.
    model.load_state_dict(state_dict, strict=False, assign=True)
    model = model.eval()

    if compile:
        model = torch.compile(model)
        model.get_kv_cache = torch.compile(model.get_kv_cache)
    return model
def load_fish_ae_from_hf(
    repo_id: str = "jordand/fish-s1-dac-min",
    device: str = "cuda",
    dtype: torch.dtype | None = torch.float32,
    compile: bool = False,
    token: str | None = None,
) -> DAC:
    """Load the Fish DAC autoencoder from the Hub into a meta-built module.

    Mirrors load_model_from_hf: download, optional dtype cast, device move,
    assign-load. Optionally torch.compiles the encoder/decoder submodules.
    """
    with torch.device("meta"):
        fish_ae = build_ae()

    checkpoint = st.load_file(
        hf_hub_download(repo_id, "pytorch_model.safetensors", token=token),
        device="cpu",
    )
    if dtype is not None:
        checkpoint = {name: t.to(dtype=dtype) for name, t in checkpoint.items()}
    checkpoint = {name: t.to(device=device) for name, t in checkpoint.items()}

    fish_ae.load_state_dict(checkpoint, strict=False, assign=True)
    # Module-level .to(device) after assign-loading the checkpoint tensors.
    fish_ae = fish_ae.eval().to(device)

    if compile:
        fish_ae.encoder = torch.compile(fish_ae.encoder)
        fish_ae.decoder = torch.compile(fish_ae.decoder)
    return fish_ae
# ============================================================
# PCA
# ============================================================
@dataclass
class PCAState:
    """PCA projection applied to DAC latents (used by ae_encode / ae_decode)."""

    # Projection matrix; used as `x @ pca_components.T` on encode, so rows are
    # principal axes — presumably (n_components, latent_dim); verify on load.
    pca_components: torch.Tensor
    # Mean subtracted before projection (added back on decode).
    pca_mean: torch.Tensor
    # Scalar applied after projection; divided out on decode.
    latent_scale: float
def load_pca_state_from_hf(
    repo_id: str = "jordand/echo-tts-base",
    device: str = "cuda",
    filename: str = "pca_state.safetensors",
    token: str | None = None,
) -> PCAState:
    """Fetch the PCA projection tensors from the Hub and bundle them."""
    local_path = hf_hub_download(repo_id, filename, token=token)
    tensors = st.load_file(local_path, device=device)
    scale = tensors["latent_scale"].item()  # stored as a 0-d tensor
    return PCAState(
        pca_components=tensors["pca_components"],
        pca_mean=tensors["pca_mean"],
        latent_scale=float(scale),
    )
# ============================================================
# Audio loading
# ============================================================
def load_audio(path: str) -> torch.Tensor:
    """Decode up to 120 s of audio, mono-mix, resample to 44.1 kHz, peak-limit.

    Returns a (1, T) float tensor. Waveforms already within [-1, 1] are left
    untouched: the divisor is max(peak, 1.0), so only louder signals shrink.
    """
    decoder = AudioDecoder(path)
    source_rate = decoder.metadata.sample_rate
    samples = decoder.get_samples_played_in_range(0, 120).data
    mono = samples.mean(dim=0, keepdim=True)  # collapse channels → (1, T)
    mono = torchaudio.functional.resample(mono, source_rate, 44_100)
    peak = torch.maximum(mono.abs().max(), torch.tensor(1.0))
    return mono / peak  # (1, T)
# ============================================================
# Text helpers
# ============================================================
# Single-pass normalization: smart punctuation → ASCII, newlines → spaces,
# colons/semicolons → commas. Equivalent to the chained .replace() calls
# (no replacement output contains a character that a later rule rewrites).
_TEXT_NORM_TABLE = str.maketrans(
    {
        "…": "...",
        "“": '"',
        "”": '"',
        "’": "'",
        "\n": " ",
        ":": ",",
        ";": ",",
    }
)


def tokenizer_encode(text: str, append_bos: bool = True, normalize: bool = True) -> torch.Tensor:
    """Byte-level tokenize ``text``; BOS is byte 0 prepended when ``append_bos``."""
    if normalize:
        text = text.translate(_TEXT_NORM_TABLE)
    prefix = [0] if append_bos else []
    return torch.tensor(prefix + list(text.encode("utf-8")))
def get_text_input_ids_and_mask(
    text_arr: List[str],
    max_length: int | None,
    device: str | None = None,
):
    """Tokenize a batch of strings into padded id/mask tensors.

    Args:
        text_arr: batch of raw strings (tokenized with BOS + normalization).
        max_length: fixed sequence length; None → length of the longest
            encoded text. Longer encodings are truncated.
        device: optional device for the returned tensors.

    Returns:
        (tokens, mask): int32 token ids (B, L), zero-padded, and a bool mask
        (B, L) that is True on real tokens.
    """
    # Encode each text exactly once; the encodings are reused both for the
    # max-length computation and for filling the batch (previously each text
    # was encoded twice when max_length was None).
    encoded_arr = [tokenizer_encode(text) for text in text_arr]
    if max_length is None:
        # default=0 keeps an empty batch from raising ValueError.
        max_length = max((len(e) for e in encoded_arr), default=0)

    batch_size = len(encoded_arr)
    tokens = torch.zeros((batch_size, max_length), dtype=torch.int32)
    mask = torch.zeros((batch_size, max_length), dtype=torch.bool)
    for i, encoded in enumerate(encoded_arr):
        length = min(len(encoded), max_length)
        tokens[i, :length] = encoded[:length]
        mask[i, :length] = 1

    if device is not None:
        tokens = tokens.to(device)
        mask = mask.to(device)
    return tokens, mask
# ============================================================
# Autoencoder
# ============================================================
@torch.inference_mode()
def ae_encode(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
    """Encode audio to DAC latents, then project into scaled PCA space."""
    latents = fish_ae.encode_zq(audio).float().transpose(1, 2)
    projected = (latents - pca_state.pca_mean) @ pca_state.pca_components.T
    return projected * pca_state.latent_scale
@torch.inference_mode()
def ae_decode(fish_ae: DAC, pca_state: PCAState, z_q: torch.Tensor) -> torch.Tensor:
    """Invert ae_encode: un-scale, back-project from PCA space, DAC-decode."""
    unprojected = (z_q / pca_state.latent_scale) @ pca_state.pca_components
    restored = (unprojected + pca_state.pca_mean).transpose(1, 2)
    return fish_ae.decode_zq(restored.to(fish_ae.dtype)).float()
# ============================================================
# Speaker & Content latents
# ============================================================
@torch.inference_mode()
def get_speaker_latent_and_mask(
    fish_ae: DAC,
    pca_state: PCAState,
    audio: torch.Tensor,  # (1, T)
    max_speaker_latent_len: int = 2560,
):
    """Encode reference audio into a fixed-length speaker latent plus mask.

    Audio is truncated so it yields at most ``max_speaker_latent_len`` latent
    frames; shorter encodings are zero-padded to that length, with the mask
    marking which frames are real.
    """
    AE_DOWNSAMPLE = 2048  # audio samples per latent frame
    audio = audio[:, : max_speaker_latent_len * AE_DOWNSAMPLE]
    z = ae_encode(fish_ae, pca_state, audio.unsqueeze(1))

    # Floor division: a trailing partial frame is not counted as valid.
    valid_frames = audio.shape[1] // AE_DOWNSAMPLE
    frame_idx = torch.arange(z.shape[1], device=z.device)
    mask = (frame_idx < valid_frames).unsqueeze(0)

    shortfall = max_speaker_latent_len - z.shape[1]
    if shortfall > 0:
        z = torch.nn.functional.pad(z, (0, 0, 0, shortfall))
        mask = torch.nn.functional.pad(mask, (0, shortfall))
    return z, mask
@torch.inference_mode()
def get_content_latent(
    fish_ae: DAC,
    pca_state: PCAState,
    audio: torch.Tensor,  # (1, 1, T)
):
    """Encode source audio into content latents (thin wrapper over ae_encode)."""
    content = ae_encode(fish_ae, pca_state, audio)
    return content
# ============================================================
# TTS pipeline (unchanged signature)
# ============================================================
@torch.inference_mode()
def sample_pipeline(
    model: EchoDiT,
    fish_ae: DAC,
    pca_state: PCAState,
    sample_fn: SampleFn,
    text_prompt: str,
    speaker_audio: torch.Tensor | None,
    rng_seed: int,
):
    """End-to-end TTS: text prompt + optional speaker reference → waveform.

    With ``speaker_audio=None`` the speaker conditioning is an all-zero
    latent with an all-False mask.
    """
    device = model.device
    text_ids, text_mask = get_text_input_ids_and_mask([text_prompt], 768, device)

    if speaker_audio is not None:
        speaker_latent, speaker_mask = get_speaker_latent_and_mask(
            fish_ae, pca_state, speaker_audio
        )
        speaker_latent = speaker_latent.to(device)
        speaker_mask = speaker_mask.to(device)
    else:
        speaker_latent = torch.zeros((1, 2560, 80), device=device)
        speaker_mask = torch.zeros((1, 2560), device=device, dtype=torch.bool)

    latent = sample_fn(
        model, speaker_latent, speaker_mask, text_ids, text_mask, rng_seed
    )
    return ae_decode(fish_ae, pca_state, latent)
# ============================================================
# Voice conversion pipeline
# ============================================================
@torch.inference_mode()
def voice_conversion_pipeline(
    model: EchoDiT,
    fish_ae: DAC,
    pca_state: PCAState,
    sample_fn: SampleFn,
    source_audio: torch.Tensor,  # (1, 1, T)
    target_speaker_audio: torch.Tensor,  # (1, T)
    rng_seed: int,
    force_speaker_kv: bool = False,
):
    """Convert ``source_audio`` to the voice of ``target_speaker_audio``.

    Text conditioning is a single zero token with a False mask (no text).
    NOTE: ``sample_fn`` must accept the extra keyword arguments
    ``init_latent`` and ``force_speaker_kv`` beyond the SampleFn alias.
    """
    device = model.device

    speaker_latent, speaker_mask = get_speaker_latent_and_mask(
        fish_ae, pca_state, target_speaker_audio
    )
    speaker_latent = speaker_latent.to(device)
    speaker_mask = speaker_mask.to(device)

    content_latent = get_content_latent(fish_ae, pca_state, source_audio).to(device)

    # Empty text conditioning: one padded token, fully masked out.
    text_ids = torch.zeros((1, 1), dtype=torch.int32, device=device)
    text_mask = torch.zeros((1, 1), dtype=torch.bool, device=device)

    latent = sample_fn(
        model,
        speaker_latent,
        speaker_mask,
        text_ids,
        text_mask,
        rng_seed,
        init_latent=content_latent,
        force_speaker_kv=force_speaker_kv,
    )
    return ae_decode(fish_ae, pca_state, latent)