# Echo-TTS inference utilities (Hugging Face Space).
| from dataclasses import dataclass | |
| from typing import Callable, List | |
| import torch | |
| import safetensors.torch as st | |
| from huggingface_hub import hf_hub_download | |
| from model import EchoDiT | |
| from autoencoder import build_ae, DAC | |
| import torchaudio | |
| from torchcodec.decoders import AudioDecoder | |
| # ============================================================ | |
| # Types | |
| # ============================================================ | |
# Signature of a sampling routine: given the model, speaker conditioning,
# text conditioning, and a seed, return the generated latent tensor.
# NOTE(review): voice_conversion_pipeline additionally passes init_latent=
# and force_speaker_kv= keyword arguments, which this Callable alias cannot
# express; a typing.Protocol with optional kwargs would describe the real
# contract — confirm intent.
SampleFn = Callable[
    [
        EchoDiT,
        torch.Tensor,  # speaker_latent
        torch.Tensor,  # speaker_mask
        torch.Tensor,  # text_ids
        torch.Tensor,  # text_mask
        int,  # seed
    ],
    torch.Tensor,
]
| # ============================================================ | |
| # Loading | |
| # ============================================================ | |
def load_model_from_hf(
    repo_id: str = "jordand/echo-tts-base",
    device: str = "cuda",
    dtype: torch.dtype | None = torch.bfloat16,
    compile: bool = False,
    token: str | None = None,
) -> EchoDiT:
    """Build EchoDiT on the meta device and load its weights from the HF Hub.

    Args:
        repo_id: Hugging Face repo containing ``pytorch_model.safetensors``.
        device: target device for the loaded weights.
        dtype: cast every checkpoint tensor to this dtype before moving;
            ``None`` keeps the checkpoint's stored dtype.
        compile: wrap the model and its ``get_kv_cache`` in ``torch.compile``.
            NOTE(review): this parameter shadows the builtin ``compile``.
        token: optional HF auth token for private repos.

    Returns:
        The model in eval mode, with weights materialized on ``device``.
    """
    # Construct under the meta device so no real memory is allocated;
    # parameters are materialized below via load_state_dict(assign=True).
    with torch.device("meta"):
        model = EchoDiT(
            latent_size=80,
            model_size=2048,
            num_layers=24,
            num_heads=16,
            intermediate_size=5888,
            norm_eps=1e-5,
            max_seq_len=640,
            text_vocab_size=256,
            text_model_size=1280,
            text_num_layers=14,
            text_num_heads=10,
            text_intermediate_size=3328,
            text_max_seq_len=768,
            speaker_patch_size=4,
            speaker_model_size=1280,
            speaker_num_layers=14,
            speaker_num_heads=10,
            speaker_intermediate_size=3328,
            speaker_max_patched_seq_len=640,
            timestep_embed_size=512,
            adaln_rank=256,
        )
    w_path = hf_hub_download(repo_id, "pytorch_model.safetensors", token=token)
    state = st.load_file(w_path, device="cpu")
    if dtype is not None:
        state = {k: v.to(dtype=dtype) for k, v in state.items()}
    state = {k: v.to(device=device) for k, v in state.items()}
    # assign=True swaps the meta parameters for the loaded tensors in place.
    # NOTE(review): strict=False means any parameter missing from the
    # checkpoint silently stays on the meta device — confirm the checkpoint
    # covers every parameter.
    model.load_state_dict(state, strict=False, assign=True)
    model = model.eval()
    if compile:
        model = torch.compile(model)
        model.get_kv_cache = torch.compile(model.get_kv_cache)
    return model
def load_fish_ae_from_hf(
    repo_id: str = "jordand/fish-s1-dac-min",
    device: str = "cuda",
    dtype: torch.dtype | None = torch.float32,
    compile: bool = False,
    token: str | None = None,
) -> DAC:
    """Build the Fish DAC autoencoder on the meta device and load HF weights.

    Args:
        repo_id: Hugging Face repo containing ``pytorch_model.safetensors``.
        device: target device for the loaded weights.
        dtype: cast every checkpoint tensor to this dtype; ``None`` keeps the
            checkpoint's stored dtype.
        compile: wrap the encoder and decoder submodules in ``torch.compile``.
            NOTE(review): shadows the builtin ``compile``.
        token: optional HF auth token for private repos.

    Returns:
        The autoencoder in eval mode on ``device``.
    """
    with torch.device("meta"):
        fish_ae = build_ae()
    w_path = hf_hub_download(repo_id, "pytorch_model.safetensors", token=token)
    state = st.load_file(w_path, device="cpu")
    if dtype is not None:
        state = {k: v.to(dtype=dtype) for k, v in state.items()}
    state = {k: v.to(device=device) for k, v in state.items()}
    fish_ae.load_state_dict(state, strict=False, assign=True)
    # NOTE(review): .to(device) is redundant for tensors already assigned on
    # `device` above, but would fail for any parameter left on meta by
    # strict=False — confirm the checkpoint is complete.
    fish_ae = fish_ae.eval().to(device)
    if compile:
        fish_ae.encoder = torch.compile(fish_ae.encoder)
        fish_ae.decoder = torch.compile(fish_ae.decoder)
    return fish_ae
| # ============================================================ | |
| # PCA | |
| # ============================================================ | |
@dataclass
class PCAState:
    """PCA projection state used to compress/expand autoencoder latents.

    Bug fix: the ``@dataclass`` decorator was missing, yet the class is
    constructed with keyword arguments (see load_pca_state_from_hf) and
    ``dataclass`` is imported at the top of the file — without the decorator
    that construction raises ``TypeError``.
    """

    # Projection matrix; used as `(x - mean) @ components.T` when encoding
    # and `x @ components` when decoding.
    pca_components: torch.Tensor
    # Per-feature mean subtracted before projection.
    pca_mean: torch.Tensor
    # Scalar applied to projected latents (divided out on decode).
    latent_scale: float
def load_pca_state_from_hf(
    repo_id: str = "jordand/echo-tts-base",
    device: str = "cuda",
    filename: str = "pca_state.safetensors",
    token: str | None = None,
) -> PCAState:
    """Download the PCA state file from the HF Hub and unpack it.

    Args:
        repo_id: Hugging Face repo holding the PCA state file.
        device: device to load the tensors onto.
        filename: name of the safetensors file inside the repo.
        token: optional HF auth token for private repos.

    Returns:
        A PCAState with components, mean, and scalar latent scale.
    """
    local_path = hf_hub_download(repo_id, filename, token=token)
    tensors = st.load_file(local_path, device=device)
    scale = float(tensors["latent_scale"].item())
    return PCAState(
        pca_components=tensors["pca_components"],
        pca_mean=tensors["pca_mean"],
        latent_scale=scale,
    )
| # ============================================================ | |
| # Audio loading (UNCHANGED, SAFE) | |
| # ============================================================ | |
def load_audio(path: str) -> torch.Tensor:
    """Decode up to the first 120 s of an audio file to mono 44.1 kHz.

    The waveform is downmixed by averaging channels, resampled, and scaled
    down only if its peak exceeds 1.0 (quiet audio is left untouched).

    Returns:
        Tensor of shape (1, T).
    """
    decoder = AudioDecoder(path)
    source_sr = decoder.metadata.sample_rate
    samples = decoder.get_samples_played_in_range(0, 120)
    mono = samples.data.mean(dim=0).unsqueeze(0)
    mono = torchaudio.functional.resample(mono, source_sr, 44_100)
    peak = torch.maximum(mono.abs().max(), torch.tensor(1.0))
    return mono / peak  # (1, T)
| # ============================================================ | |
| # Text helpers | |
| # ============================================================ | |
def tokenizer_encode(text: str, append_bos: bool = True, normalize: bool = True) -> torch.Tensor:
    """Encode text as a 1-D tensor of UTF-8 byte ids.

    Args:
        text: input string.
        append_bos: prepend the BOS id 0 to the sequence.
        normalize: fold fancy punctuation into ASCII ("…" -> "...",
            curly quotes -> straight), turn newlines into spaces, and map
            ":"/";" to ",".

    Returns:
        1-D int64 tensor of byte ids (plus a leading 0 when append_bos).
    """
    if normalize:
        # Single-pass equivalent of the original chained str.replace calls.
        table = str.maketrans({
            "…": "...",
            "“": '"',
            "”": '"',
            "’": "'",
            "\n": " ",
            ":": ",",
            ";": ",",
        })
        text = text.translate(table)
    ids = list(text.encode("utf-8"))
    if append_bos:
        ids = [0, *ids]
    return torch.tensor(ids)
def get_text_input_ids_and_mask(
    text_arr: List[str],
    max_length: int | None,
    device: str | None = None,
):
    """Tokenize a batch of strings into fixed-size id and mask tensors.

    Fixes vs original: with ``max_length=None`` an empty batch no longer
    raises ``ValueError`` from ``max()`` (falls back to length 0), and each
    text is encoded once instead of twice.

    Args:
        text_arr: batch of input strings.
        max_length: pad/truncate length; ``None`` uses the longest encoding.
        device: optional device to move the outputs to.

    Returns:
        (tokens, mask): int32 ids of shape (batch, max_length) and a bool
        validity mask of the same shape.
    """
    # Encode every text once up front (the original re-encoded inside the
    # fill loop after already encoding to compute the max length).
    encoded_arr = [tokenizer_encode(text) for text in text_arr]
    if max_length is None:
        # default=0 keeps an empty batch from raising on max().
        max_length = max((len(e) for e in encoded_arr), default=0)
    batch_size = len(text_arr)
    tokens = torch.zeros((batch_size, max_length), dtype=torch.int32)
    mask = torch.zeros((batch_size, max_length), dtype=torch.bool)
    for i, encoded in enumerate(encoded_arr):
        length = min(len(encoded), max_length)
        tokens[i, :length] = encoded[:length]
        mask[i, :length] = 1
    if device is not None:
        tokens = tokens.to(device)
        mask = mask.to(device)
    return tokens, mask
| # ============================================================ | |
| # Autoencoder | |
| # ============================================================ | |
def ae_encode(fish_ae: DAC, pca_state: PCAState, audio: torch.Tensor) -> torch.Tensor:
    """Encode audio with the AE and project the latents into scaled PCA space.

    Centers the (transposed) AE latents on the PCA mean, projects them onto
    the PCA components, and applies the global latent scale.
    """
    latent = fish_ae.encode_zq(audio).float()
    centered = latent.transpose(1, 2) - pca_state.pca_mean
    projected = centered @ pca_state.pca_components.T
    return projected * pca_state.latent_scale
def ae_decode(fish_ae: DAC, pca_state: PCAState, z_q: torch.Tensor) -> torch.Tensor:
    """Invert ae_encode: unscale, map back from PCA space, and decode to audio."""
    unscaled = z_q / pca_state.latent_scale
    restored = unscaled @ pca_state.pca_components + pca_state.pca_mean
    restored = restored.transpose(1, 2).to(fish_ae.dtype)
    return fish_ae.decode_zq(restored).float()
| # ============================================================ | |
| # Speaker & Content latents | |
| # ============================================================ | |
def get_speaker_latent_and_mask(
    fish_ae: DAC,
    pca_state: PCAState,
    audio: torch.Tensor,  # (1, T)
    max_speaker_latent_len: int = 2560,
):
    """Encode reference audio into a fixed-length speaker latent and mask.

    The audio is truncated to at most ``max_speaker_latent_len`` latent
    frames, encoded, and right-padded with zeros (mask False) to exactly
    ``max_speaker_latent_len`` frames.

    Returns:
        (latent, mask): latent of shape (1, max_speaker_latent_len, D) and a
        bool mask of shape (1, max_speaker_latent_len).
    """
    AE_DOWNSAMPLE = 2048  # audio samples per latent frame — TODO confirm
    clipped = audio[:, : max_speaker_latent_len * AE_DOWNSAMPLE]
    latent = ae_encode(fish_ae, pca_state, clipped.unsqueeze(1))
    valid_frames = clipped.shape[1] // AE_DOWNSAMPLE
    frame_idx = torch.arange(latent.shape[1], device=latent.device)
    mask = (frame_idx < valid_frames).unsqueeze(0)
    shortfall = max_speaker_latent_len - latent.shape[1]
    if shortfall > 0:
        latent = torch.nn.functional.pad(latent, (0, 0, 0, shortfall))
        mask = torch.nn.functional.pad(mask, (0, shortfall))
    return latent, mask
def get_content_latent(
    fish_ae: DAC,
    pca_state: PCAState,
    audio: torch.Tensor,  # (1, 1, T)
):
    """Encode source audio into content latents (thin wrapper over ae_encode)."""
    content = ae_encode(fish_ae, pca_state, audio)
    return content
| # ============================================================ | |
| # TTS pipeline (unchanged signature) | |
| # ============================================================ | |
def sample_pipeline(
    model: EchoDiT,
    fish_ae: DAC,
    pca_state: PCAState,
    sample_fn: SampleFn,
    text_prompt: str,
    speaker_audio: torch.Tensor | None,
    rng_seed: int,
):
    """End-to-end TTS: tokenize the prompt, build speaker conditioning, sample, decode.

    Args:
        model: the diffusion transformer.
        fish_ae: autoencoder used for speaker encoding and final decoding.
        pca_state: PCA projection for the latent space.
        sample_fn: sampling routine matching SampleFn.
        text_prompt: text to synthesize.
        speaker_audio: optional (1, T) reference audio; None means
            unconditional speaker (zero latent, all-False mask).
        rng_seed: seed forwarded to sample_fn.

    Returns:
        Decoded audio tensor from ae_decode.
    """
    device = model.device
    text_ids, text_mask = get_text_input_ids_and_mask([text_prompt], 768, device)
    if speaker_audio is not None:
        speaker_latent, speaker_mask = get_speaker_latent_and_mask(
            fish_ae, pca_state, speaker_audio
        )
        speaker_latent = speaker_latent.to(device)
        speaker_mask = speaker_mask.to(device)
    else:
        # No reference speaker: zero latent with an all-False validity mask.
        speaker_latent = torch.zeros((1, 2560, 80), device=device)
        speaker_mask = torch.zeros((1, 2560), device=device, dtype=torch.bool)
    latent = sample_fn(model, speaker_latent, speaker_mask, text_ids, text_mask, rng_seed)
    return ae_decode(fish_ae, pca_state, latent)
| # ============================================================ | |
| # ✅ Voice Conversion pipeline | |
| # ============================================================ | |
def voice_conversion_pipeline(
    model: EchoDiT,
    fish_ae: DAC,
    pca_state: PCAState,
    sample_fn: SampleFn,
    source_audio: torch.Tensor,  # (1, 1, T)
    target_speaker_audio: torch.Tensor,  # (1, T)
    rng_seed: int,
    force_speaker_kv: bool = False,
):
    """Voice conversion: re-synthesize source_audio in the target speaker's voice.

    Conditions the sampler on the target speaker's latent while initializing
    it with the source audio's content latent; no text is used (a single
    zero token with an all-False mask).

    Args:
        model: the diffusion transformer.
        fish_ae: autoencoder for encoding both audios and decoding the result.
        pca_state: PCA projection for the latent space.
        sample_fn: sampling routine; NOTE(review): must accept init_latent=
            and force_speaker_kv= kwargs, which the SampleFn alias does not
            express.
        source_audio: (1, 1, T) audio providing the content.
        target_speaker_audio: (1, T) audio providing the voice.
        rng_seed: seed forwarded to sample_fn.
        force_speaker_kv: forwarded to sample_fn unchanged.

    Returns:
        Decoded audio tensor from ae_decode.
    """
    device = model.device
    speaker_latent, speaker_mask = get_speaker_latent_and_mask(
        fish_ae, pca_state, target_speaker_audio
    )
    speaker_latent = speaker_latent.to(device)
    speaker_mask = speaker_mask.to(device)
    content_latent = get_content_latent(fish_ae, pca_state, source_audio).to(device)
    # Text conditioning is disabled: one zero token, mask all False.
    text_ids = torch.zeros((1, 1), dtype=torch.int32, device=device)
    text_mask = torch.zeros((1, 1), dtype=torch.bool, device=device)
    latent = sample_fn(
        model,
        speaker_latent,
        speaker_mask,
        text_ids,
        text_mask,
        rng_seed,
        init_latent=content_latent,
        force_speaker_kv=force_speaker_kv,
    )
    return ae_decode(fish_ae, pca_state, latent)