Spaces:
Running
Running
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| from typing import Dict, Iterator, List, Optional, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import onnxruntime as ort | |
| import phonemizer | |
| import soundfile as sf | |
| from huggingface_hub import hf_hub_download | |
| # --------------------------- | |
| # Utility: tokenization + cleaning | |
| # --------------------------- | |
# A token is either a run of word characters or one non-word, non-space symbol.
_TOKENIZER_RE = re.compile(r"\w+|[^\w\s]")


def basic_english_tokenize(text: str) -> List[str]:
    """Split ``text`` into word runs and individual punctuation marks.

    Whitespace only separates tokens and never appears in the result.

    Args:
        text: String to tokenize.

    Returns:
        The tokens in their order of appearance; an empty list for empty
        or whitespace-only input.
    """
    return [match.group(0) for match in _TOKENIZER_RE.finditer(text)]
class TextCleaner:
    """Translate characters into integer ids from a fixed symbol table.

    The table is the pad symbol followed by punctuation, ASCII letters and
    IPA symbols; a character's id is its position in that sequence.
    """

    def __init__(self) -> None:
        pad = "$"
        punctuation = ';:,.!?¡¿—…"«»"" '
        ascii_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        ipa_letters = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
        inventory = pad + punctuation + ascii_letters + ipa_letters

        lookup: Dict[str, int] = {}
        for position, symbol in enumerate(inventory):
            # Should a symbol occur more than once, its last position wins.
            lookup[symbol] = position
        self._dict = lookup

    def __call__(self, text: str) -> List[int]:
        """Return the ids of the characters of ``text`` found in the table.

        Characters missing from the table are skipped rather than raising,
        mirroring the original behavior.
        """
        lookup = self._dict
        ids: List[int] = []
        for ch in text:
            idx = lookup.get(ch)
            if idx is not None:
                ids.append(idx)
        return ids
| # --------------------------- | |
| # Core model | |
| # --------------------------- | |
class KittenTTS_1_Onnx:
    """ONNX-based KittenTTS inference.

    Matches the original interface:
        - generate(text, voice, speed) -> np.ndarray
        - generate_to_file(...)

    Input that does not fit the model's sequence capacity is split into
    chunks that are synthesized one by one and concatenated.
    """

    # Original voice set kept for compatibility.
    _DEFAULT_VOICES = [
        "expr-voice-2-m",
        "expr-voice-2-f",
        "expr-voice-3-m",
        "expr-voice-3-f",
        "expr-voice-4-m",
        "expr-voice-4-f",
        "expr-voice-5-m",
        "expr-voice-5-f",
    ]

    # Capacity used when neither the model file nor KITTEN_MAX_SEQ_LEN
    # yields a positive value.
    _FALLBACK_MAX_SEQ_LEN = 512

    # Samples removed from the start / end of every clip long enough to
    # survive the cut (same trimming as the original implementation).
    _TRIM_HEAD = 5000
    _TRIM_TAIL = 10000

    def __init__(
        self,
        model_path: str = "kitten_tts_nano_v0_2.onnx",
        voices_path: str = "voices.npz",
        providers: Optional[List[str]] = None,
    ) -> None:
        """Load voices, the phonemizer and the ONNX Runtime session.

        Args:
            model_path: Path to the ONNX model file.
            voices_path: Path to the archive of per-voice style vectors,
                loaded with ``np.load``.
            providers: Preferred ONNX Runtime execution providers. Names
                that are not available are ignored; see
                ``_select_providers`` for the fallback rules.
        """
        self.model_path = model_path
        self.voices = np.load(voices_path)
        self._phonemizer = phonemizer.backend.EspeakBackend(
            language="en-us", preserve_punctuation=True, with_stress=True
        )
        self._cleaner = TextCleaner()

        # Derive available voices from file when possible, else fall back
        # to defaults. Best effort: any failure means "no names known".
        try:
            files = list(getattr(self.voices, "files", []))
        except Exception:
            files = []
        known_defaults = [v for v in self._DEFAULT_VOICES if v in files]
        # Copy the class-level list so instances never share mutable state.
        self.available_voices: List[str] = (
            known_defaults or files or list(self._DEFAULT_VOICES)
        )

        # ONNX Runtime session with aggressive graph optimizations.
        # Thread counts are left at ORT defaults to avoid over-constraining
        # environments like Spaces.
        sess_opt = ort.SessionOptions()
        sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        chosen_providers = self._select_providers(
            providers, ort.get_available_providers()
        )
        self.session = ort.InferenceSession(
            self.model_path,
            sess_options=sess_opt,
            providers=chosen_providers,
        )

        # Max-length detection and per-chunk budget.
        self.max_seq_len = self._resolve_max_seq_len(
            self._infer_max_seq_len(), os.getenv("KITTEN_MAX_SEQ_LEN")
        )
        # Reserve 2 slots for the start/end tokens added around each chunk.
        self._chunk_budget = max(1, self.max_seq_len - 2)

    @staticmethod
    def _select_providers(
        requested: Optional[List[str]], available: List[str]
    ) -> List[str]:
        """Pick the execution providers to hand to ONNX Runtime.

        Args:
            requested: Caller preference; ``None`` or a list with no
                usable names means "CPUExecutionProvider".
            available: Providers reported by ONNX Runtime, in the order it
                reported them.

        Returns:
            The requested names that are available, in requested order.
            If none are, every available provider in the order given, so
            the runtime's own ordering is preserved instead of being
            scrambled through a ``set``.
        """
        # Tolerate " name " and empty entries (e.g. from splitting an
        # environment variable on commas).
        wanted = [p.strip() for p in (requested or []) if p and p.strip()]
        if not wanted:
            wanted = ["CPUExecutionProvider"]
        usable = [p for p in wanted if p in available]
        return usable or list(available)

    @classmethod
    def _resolve_max_seq_len(
        cls, inferred: Optional[int], env_value: Optional[str]
    ) -> int:
        """Combine the inferred capacity with the environment override.

        Args:
            inferred: Value read from the model file, or ``None``.
            env_value: Raw ``KITTEN_MAX_SEQ_LEN`` value, or ``None``.

        Returns:
            ``inferred`` when it is positive; otherwise the environment
            value when it parses to a positive integer; otherwise 512.
            A malformed or non-positive environment value is ignored
            rather than crashing start-up or producing a 1-token budget.
        """
        if inferred is not None and inferred > 0:
            return int(inferred)
        try:
            from_env = int(env_value) if env_value is not None else 0
        except ValueError:
            from_env = 0
        return from_env if from_env > 0 else cls._FALLBACK_MAX_SEQ_LEN

    @staticmethod
    def _wrap_ids(ids: List[int]) -> List[int]:
        """Return ``ids`` framed by id 0 as start and end marker."""
        return [0, *ids, 0]

    def _check_voice(self, voice: str) -> None:
        """Raise ``ValueError`` when ``voice`` is not an available voice."""
        if voice not in self.available_voices:
            raise ValueError(
                f"Voice '{voice}' not available. Choose from: {self.available_voices}"
            )

    def _build_feed(
        self, token_ids: List[int], voice: str, speed: float
    ) -> Dict[str, np.ndarray]:
        """Assemble the named inputs for one session run."""
        return {
            # A batch of one sequence.
            "input_ids": np.asarray([token_ids], dtype=np.int64),
            "style": self.voices[voice],
            "speed": np.asarray([speed], dtype=np.float32),
        }

    def _prepare_inputs(
        self, text: str, voice: str, speed: float
    ) -> Dict[str, np.ndarray]:
        """Build the model inputs for ``text`` as one unchunked sequence.

        Kept for compatibility with the original interface; ``generate``
        performs its own chunk-aware preparation.

        Raises:
            ValueError: If ``voice`` is not available.
        """
        self._check_voice(voice)
        # Phonemize, map to token ids, then add start/end tokens.
        tokens = self._wrap_ids(self._cleaner(self._phonemize_to_clean(text)))
        return self._build_feed(tokens, voice, speed)

    def _infer_max_seq_len(self) -> Optional[int]:
        """Try to read a positional-embedding length from the ONNX file.

        Uses the optional ``onnx`` package. Returns ``None`` when the
        package or the model cannot be loaded, or when no 2-D initializer
        with "position" in its name exists.
        """
        # Best effort: a broken optional dependency must not block start-up.
        try:
            import onnx  # optional
        except Exception:
            return None
        try:
            model = onnx.load(self.model_path)
        except Exception:
            return None
        for tensor in model.graph.initializer:
            name = tensor.name.lower()
            if "position" in name and len(tensor.dims) == 2:
                # dims[0] is taken as max positions, dims[1] as hidden dim.
                return int(tensor.dims[0])
        return None

    def _phonemize_to_clean(self, text: str) -> str:
        """Phonemize once and keep only characters present in the symbol set."""
        phonemes = self._phonemizer.phonemize([text])[0]
        token_str = " ".join(basic_english_tokenize(phonemes))
        known = self._cleaner._dict
        # Filtering here makes len(result) equal the number of token ids.
        return "".join(c for c in token_str if c in known)

    def _run_onnx(self, token_ids: List[int], voice: str, speed: float) -> np.ndarray:
        """Run one inference call and trim the edges of the result.

        Clips with more than 15000 samples lose their first 5000 and last
        10000 samples; shorter clips are returned untouched.
        """
        outputs = self.session.run(None, self._build_feed(token_ids, voice, speed))
        audio = np.asarray(outputs[0], dtype=np.float32)
        if audio.size > self._TRIM_HEAD + self._TRIM_TAIL:
            audio = audio[self._TRIM_HEAD:-self._TRIM_TAIL]
        return audio

    def _chunk_token_ids(self, clean: str) -> Iterator[List[int]]:
        """Yield start/segment/end token-id sequences within model capacity.

        Args:
            clean: Phoneme string already restricted to known symbols.

        Yields:
            One id list per segment of at most ``_chunk_budget`` symbols,
            each framed by id 0.
        """
        total = len(clean)
        start = 0
        while start < total:
            stop = min(start + self._chunk_budget, total)
            # Prefer to cut at a space to keep phrasing natural, but only
            # when that keeps the chunk above 60% of the budget.
            cut = clean.rfind(" ", start, stop)
            if cut != -1 and cut > start + int(0.6 * self._chunk_budget):
                stop = cut + 1  # include the space
            yield self._wrap_ids(self._cleaner(clean[start:stop]))
            start = stop

    def generate(self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0) -> np.ndarray:
        """Synthesize speech with automatic chunking at the model's max length.

        Args:
            text: Text to speak.
            voice: One of ``available_voices``.
            speed: Speed value passed to the model unchanged.

        Returns:
            The audio samples as a float32 array; chunked input yields
            the per-chunk audio concatenated in order.

        Raises:
            ValueError: If ``voice`` is not available.
        """
        self._check_voice(voice)
        # Phonemize once, then either run single-shot or chunked.
        clean = self._phonemize_to_clean(text)

        # Fast path: fits in one pass including the two framing tokens.
        if len(clean) + 2 <= self.max_seq_len:
            return self._run_onnx(self._wrap_ids(self._cleaner(clean)), voice, speed)

        # Chunked path: concatenate per-chunk audio.
        pieces: List[np.ndarray] = [
            self._run_onnx(ids, voice, speed) for ids in self._chunk_token_ids(clean)
        ]
        if not pieces:
            return np.array([], dtype=np.float32)
        return pieces[0] if len(pieces) == 1 else np.concatenate(pieces)

    def generate_to_file(
        self,
        text: str,
        output_path: str,
        voice: str = "expr-voice-5-m",
        speed: float = 1.0,
        sample_rate: int = 24000,
    ) -> None:
        """Synthesize ``text`` and write it to ``output_path`` via soundfile."""
        audio = self.generate(text, voice, speed)
        sf.write(output_path, audio, sample_rate)
| # --------------------------- | |
| # HF download wrapper (consolidated) | |
| # --------------------------- | |
class KittenTTS:
    """High-level wrapper that fetches model assets from Hugging Face."""

    def __init__(
        self,
        model_name: str = "KittenML/kitten-tts-nano-0.2",
        cache_dir: Optional[str] = None,
        providers: Optional[List[str]] = None,
    ) -> None:
        """Download the model assets and build the underlying ONNX model.

        Args:
            model_name: Hugging Face repo id; a name without "/" is
                resolved inside the "KittenML" organization.
            cache_dir: Optional download cache directory.
            providers: Optional ONNX Runtime execution providers.
        """
        repo_id = model_name if "/" in model_name else f"KittenML/{model_name}"
        self._model = download_from_huggingface(
            repo_id=repo_id, cache_dir=cache_dir, providers=providers
        )

    def generate(
        self, text: str, voice: str = "expr-voice-5-m", speed: float = 1.0
    ) -> np.ndarray:
        """Synthesize ``text``; delegates to the underlying model."""
        return self._model.generate(text, voice=voice, speed=speed)

    def generate_to_file(
        self,
        text: str,
        output_path: str,
        voice: str = "expr-voice-5-m",
        speed: float = 1.0,
        sample_rate: int = 24000,
    ) -> None:
        """Synthesize ``text`` to ``output_path``; delegates to the model."""
        return self._model.generate_to_file(
            text, output_path, voice=voice, speed=speed, sample_rate=sample_rate
        )

    # Exposed as a property because the rest of this file reads it as an
    # attribute (membership test, indexing, dropdown choices); as a plain
    # method those uses would operate on a bound method and fail.
    @property
    def available_voices(self) -> List[str]:
        """Voice names offered by the underlying model."""
        return self._model.available_voices
def download_from_huggingface(
    repo_id: str = "KittenML/kitten-tts-nano-0.2",
    cache_dir: Optional[str] = None,
    providers: Optional[List[str]] = None,
) -> KittenTTS_1_Onnx:
    """Fetch config, model and voices from the Hub and build the ONNX model.

    Args:
        repo_id: Hugging Face repository to download from.
        cache_dir: Optional download cache directory.
        providers: Optional ONNX Runtime execution providers, forwarded
            to the model unchanged.

    Returns:
        A model instance built from the downloaded files.

    Raises:
        ValueError: If ``config.json`` does not declare type "ONNX1".
    """

    def _fetch(filename: str):
        # Every asset comes from the same repo and cache location.
        return hf_hub_download(
            repo_id=repo_id, filename=filename, cache_dir=cache_dir
        )

    with open(_fetch("config.json"), "r", encoding="utf-8") as handle:
        config = json.load(handle)

    model_type = config.get("type")
    if model_type != "ONNX1":
        raise ValueError("Unsupported model type in config.json.")

    # The config names the model and voices files inside the repo.
    onnx_file = _fetch(config["model_file"])
    voices_file = _fetch(config["voices"])
    return KittenTTS_1_Onnx(
        model_path=onnx_file, voices_path=voices_file, providers=providers
    )
def get_model(
    repo_id: str = "KittenML/kitten-tts-nano-0.2", cache_dir: Optional[str] = None
) -> KittenTTS:
    """Backward-compatible alias: build a ``KittenTTS`` for ``repo_id``.

    Args:
        repo_id: Hugging Face repo id (or bare model name).
        cache_dir: Optional download cache directory.
    """
    return KittenTTS(model_name=repo_id, cache_dir=cache_dir)
| # --------------------------- | |
| # Gradio app | |
| # --------------------------- | |
# Allow overriding model repo and providers via env on Spaces.
_MODEL_REPO = os.getenv("MODEL_REPO", "KittenML/kitten-tts-nano-0.2")
# Use CPU by default on Spaces; adjust if GPU EPs are available.
# ORT_PROVIDERS is a comma-separated list. Whitespace around names and empty
# entries are dropped so values like "A, B" or "A," still match provider
# names; an empty value falls back to the CPU provider.
_DEFAULT_PROVIDERS = [
    name.strip()
    for name in os.getenv("ORT_PROVIDERS", "CPUExecutionProvider").split(",")
    if name.strip()
] or ["CPUExecutionProvider"]
# Single global instance for efficiency.
_TTS = KittenTTS(_MODEL_REPO, providers=_DEFAULT_PROVIDERS)
def _synthesize(text: str, voice: str, speed: float) -> Tuple[int, np.ndarray]:
    """Gradio callback: turn the form values into an audio clip.

    Args:
        text: Text from the input box.
        voice: Selected voice name.
        speed: Selected speed value.

    Returns:
        ``(24000, samples)`` with ``samples`` as float32.

    Raises:
        gr.Error: If ``text`` is empty or only whitespace.
    """
    if not (text and text.strip()):
        raise gr.Error("Please enter text.")

    waveform = _TTS.generate(text, voice=voice, speed=speed)
    # Gradio expects (sample_rate, np.ndarray[float32]); copy=False avoids
    # a copy when the array already is float32.
    samples = waveform.astype(np.float32, copy=False)
    return 24000, samples
# NOTE(review): the nesting below was reconstructed from a listing that had
# lost its indentation — confirm which components belong inside each Row.
with gr.Blocks(title="Kitten TTS Nano 0.2 😻") as demo:
    gr.Markdown("# Kitten TTS Nano 0.2 😻\nText-to-Speech using ONNX on CPU")

    # Row 1: the text to speak.
    with gr.Row():
        inp_text = gr.Textbox(
            label="Text",
            lines=6,
            placeholder='Type something like: "The quick brown fox jumps over the lazy dog."',
        )

    # Row 2: voice and speed controls.
    with gr.Row():
        voice_names = _TTS.available_voices
        preferred_voice = "expr-voice-5-m"
        # Preselect the preferred voice when offered, else the first one.
        if preferred_voice in voice_names:
            initial_voice = preferred_voice
        else:
            initial_voice = voice_names[0]
        voice = gr.Dropdown(label="Voice", choices=voice_names, value=initial_voice)
        speed = gr.Slider(minimum=0.5, maximum=1.5, step=0.05, value=1.0, label="Speed")

    out_audio = gr.Audio(label="Output Audio", type="numpy")
    btn = gr.Button("Generate")
    form_inputs = [inp_text, voice, speed]
    btn.click(_synthesize, inputs=form_inputs, outputs=out_audio)

    # Each example row is (text, voice, speed), matching the form inputs.
    example_rows = [
        ["Hello from KittenTTS Nano.", "expr-voice-5-m", 1.0],
        [
            "It begins with an Ugh. Another mysterious stain appears on a favorite shirt.",
            "expr-voice-2-f",
            1.0,
        ],
    ]
    gr.Examples(examples=example_rows, inputs=[inp_text, voice, speed])

# Spaces will auto-run app.py. Local dev can still call launch().
if __name__ == "__main__":
    demo.launch()