TTS-Demo / app.py
CVNSS's picture
Update app.py
491db3e verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CVNSS4.0 Vietnamese TTS Studio
- Architecture: Modular CSS & Component Separation
- UX: High Contrast Input Fields
- Core: Optimized Logic Flow
"""
import os
import sys
import json
import time
import glob
import re
import hashlib
import tempfile
from pathlib import Path
import torch
import numpy as np
import soundfile as sf
import gradio as gr
# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent))
# Import core modules
try:
    from src.vietnamese.text_processor import process_vietnamese_text
    from src.vietnamese.phonemizer import text_to_phonemes, VIPHONEME_AVAILABLE
    from src.models.synthesizer import SynthesizerTrn
    from src.text.symbols import symbols
except ImportError as exc:
    # Fallback for environment setup if src is missing during init.
    # Print the underlying error as well: a missing 'src' package and a
    # dependency of src that fails to import both land here, and the generic
    # message alone cannot tell them apart.
    print("⚠️ Core modules not found. Ensure 'src' directory exists.")
    print(f"   ↳ {exc}")
    VIPHONEME_AVAILABLE = False
    symbols = []
# =========================================================
# 1) SYSTEM CONFIGURATION & CSS (The Expert Layer)
# =========================================================
# Global stylesheet passed to gr.Blocks(css=...) in create_ui. It defines:
# - the dark "neon" page theme;
# - the .panelNeon card used to group controls, whose text inputs get a light
#   background with dark-blue text for contrast;
# - primary-button styling;
# - the .statusCard / .pill / .alert(.alertOk/.alertWarn) classes used by the
#   status-card HTML.
# NOTE(review): create_ui's markup uses class "panelTitle", but there is no
# rule for it in this stylesheet.
NEON_CSS = r"""
:root {
--bg-dark: #0f172a;
--bg-panel: rgba(30, 41, 59, 0.7);
--line: rgba(148, 163, 184, 0.1);
--text-primary: #e2e8f0;
--neon-cyan: #06b6d4;
--neon-accent: #38bdf8;
--radius-lg: 16px;
--radius-sm: 8px;
/* UX Color Palette for Inputs */
--input-bg: #f1f5f9; /* Light Blue-Grey for readability */
--input-text: #0f4c81; /* Classic Blue (Dark Blue) for high contrast */
--input-placeholder: #64748b;
}
body, .gradio-container, .app {
background: radial-gradient(circle at 50% 0%, #1e293b 0%, #0f172a 100%) !important;
color: var(--text-primary) !important;
font-family: 'Inter', 'Segoe UI', sans-serif;
}
/* --- ISOLATION FULL: CVNSS4.0 Vietnamese TTS Studio --- */
.panelNeon {
border: 1px solid rgba(255,255,255,0.08);
border-radius: var(--radius-lg);
background: var(--bg-panel);
backdrop-filter: blur(12px);
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
padding: 20px;
position: relative;
isolation: isolate;
z-index: 1;
margin-bottom: 20px;
}
/* UX IMPROVEMENT: High Contrast Input Styling */
.panelNeon textarea, .panelNeon input[type="text"] {
background: var(--input-bg) !important;
color: var(--input-text) !important; /* DARK BLUE TEXT requested */
border: 2px solid transparent !important;
border-radius: var(--radius-sm) !important;
font-weight: 500 !important;
font-size: 1rem !important;
line-height: 1.5 !important;
padding: 12px !important;
transition: all 0.2s ease;
z-index: 10 !important;
position: relative !important;
}
.panelNeon textarea::placeholder {
color: var(--input-placeholder) !important;
}
.panelNeon textarea:focus, .panelNeon input:focus {
background: #ffffff !important;
border-color: var(--neon-cyan) !important;
box-shadow: 0 0 0 4px rgba(6, 182, 212, 0.15) !important;
color: #000000 !important; /* Even darker on focus */
}
/* Label Styling */
.panelNeon label span {
color: var(--neon-accent) !important;
font-weight: 600;
font-size: 0.85rem;
text-transform: uppercase;
letter-spacing: 0.05em;
margin-bottom: 8px;
display: block;
}
/* Dropdown & Slider fixes */
.panelNeon .wrap, .panelNeon .range-compact {
z-index: 10 !important;
}
/* Button Upgrades */
button.primary, .gr-button-primary {
background: linear-gradient(135deg, #06b6d4 0%, #3b82f6 100%) !important;
border: none !important;
color: white !important;
font-weight: 700 !important;
transition: transform 0.1s ease, box-shadow 0.2s ease;
}
button.primary:hover, .gr-button-primary:hover {
box-shadow: 0 10px 15px -3px rgba(6, 182, 212, 0.3) !important;
transform: translateY(-1px);
}
button.primary:active {
transform: translateY(0px);
}
/* Status Panel */
.statusCard {
background: rgba(15, 23, 42, 0.6);
border-radius: var(--radius-sm);
padding: 16px;
border: 1px solid rgba(255,255,255,0.05);
}
.pill {
display: inline-flex;
align-items: center;
padding: 4px 12px;
border-radius: 99px;
background: rgba(56, 189, 248, 0.1);
color: #38bdf8;
border: 1px solid rgba(56, 189, 248, 0.2);
font-size: 0.8rem;
font-weight: 600;
margin-right: 6px;
margin-bottom: 6px;
}
.alert { padding: 12px; border-radius: 8px; margin-top: 12px; font-size: 0.9rem; font-weight: 500; display: flex; align-items: center; gap: 8px;}
.alertOk { background: rgba(34, 197, 94, 0.1); color: #4ade80; border: 1px solid rgba(34, 197, 94, 0.2); }
.alertWarn { background: rgba(234, 179, 8, 0.1); color: #facc15; border: 1px solid rgba(234, 179, 8, 0.2); }
"""
# =========================================================
# 2) UTILITIES & HELPERS
# =========================================================
def check_viphoneme():
    """Report whether the viphoneme phonemizer backend can be used.

    Returns False, after printing a warning, in two cases: when
    ``VIPHONEME_AVAILABLE`` is falsy, or when a trial
    ``text_to_phonemes("Test", use_viphoneme=True)`` call raises or does not
    unpack into three values. Otherwise it prints a success line and returns
    True.
    """
    usable = bool(VIPHONEME_AVAILABLE)
    if usable:
        try:
            _phones, _tones, _ = text_to_phonemes("Test", use_viphoneme=True)
            print("✅ Viphoneme active.")
        except Exception as err:
            print(f"❌ Viphoneme error: {err}")
            usable = False
    else:
        print("⚠️ Viphoneme not available.")
    return usable
def md5_key(*parts: str) -> str:
    """Return the hex MD5 digest of *parts* joined with ``"|"`` (UTF-8 encoded).

    In this file the digest is used to name cached WAV files.
    """
    joined = "|".join(parts)
    digest = hashlib.md5(joined.encode("utf-8"))
    return digest.hexdigest()
def _wrap_long_sentence(sentence, max_chars):
    """Split *sentence* at whitespace into pieces of at most *max_chars* chars.

    Each piece keeps its leading whitespace, so joining the pieces gives back
    the sentence (minus any trailing whitespace). A single word longer than
    *max_chars* is kept whole rather than cut mid-word.
    """
    if len(sentence) <= max_chars:
        return [sentence]
    pieces = []
    current = ""
    for word in re.findall(r"\s*\S+", sentence):
        if current and len(current) + len(word) > max_chars:
            pieces.append(current)
            current = word
        else:
            current += word
    if current:
        pieces.append(current)
    return pieces


def split_sentences_vi(text: str, max_chars: int):
    """Split *text* into chunks of at most *max_chars* characters for TTS.

    Processing steps:
    1. Collapse whitespace runs to single spaces.
    2. Cut the text after each ``. ? ! ; :``; the delimiter stays with its
       sentence.
    3. Pack consecutive sentences greedily into a chunk while the chunk fits
       within *max_chars*.
    4. Split a sentence that is itself longer than *max_chars* at word
       boundaries.

    Only a single word longer than *max_chars* can still produce an oversized
    chunk.

    Args:
        text: Input text; falsy input yields an empty list.
        max_chars: Upper bound on chunk length in characters.

    Returns:
        List of stripped, non-empty chunk strings, in original order.
    """
    if not text:
        return []
    text = re.sub(r"\s+", " ", text).strip()
    # With one capture group, re.split returns [seg, delim, seg, delim, ..., tail].
    parts = re.split(r"([.?!;:])", text)
    sentences = [parts[i] + parts[i + 1] for i in range(0, len(parts) - 1, 2)]
    if parts[-1]:
        sentences.append(parts[-1])

    chunks = []
    current = ""
    for sentence in sentences:
        # Without wrapping, an over-long sentence would become a single chunk
        # far above max_chars.
        for piece in _wrap_long_sentence(sentence, max_chars):
            if len(current) + len(piece) <= max_chars:
                current += piece
            else:
                if current.strip():
                    chunks.append(current.strip())
                current = piece
    if current.strip():
        chunks.append(current.strip())
    return chunks
# =========================================================
# 3) CORE ENGINE WRAPPER
# =========================================================
class TTSManager:
    """Own the loaded TTS model and a file cache of synthesized clips.

    On construction it:
    - picks ``"cuda"`` when ``torch.cuda.is_available()``, else ``"cpu"``;
    - resolves the model directory via ``download_model()``;
    - selects the newest generator checkpoint with ``find_latest_checkpoint``;
    - builds a ``VietnameseTTS``;
    - creates ``<tempdir>/neon_tts_cache`` for WAV output.

    Raises:
        FileNotFoundError: if no checkpoint is found in the model directory.
    """

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🔧 Initializing TTS on {self.device}...")
        self.model_dir = self._get_model_dir()
        self.ckpt_path = find_latest_checkpoint(self.model_dir, "G")
        self.cfg_path = os.path.join(self.model_dir, "config.json")
        if not self.ckpt_path:
            raise FileNotFoundError(f"No checkpoint found in {self.model_dir}")
        self.tts = VietnameseTTS(self.ckpt_path, self.cfg_path, self.device)
        # Synthesized WAVs are cached here, named by a hash of every input.
        self.temp_dir = Path(tempfile.gettempdir()) / "neon_tts_cache"
        self.temp_dir.mkdir(parents=True, exist_ok=True)

    def _get_model_dir(self):
        """Return the local model directory (downloading it if needed)."""
        return download_model()

    def synthesize(self, text, speaker, speed, noise_scale, noise_scale_w, sdp_ratio):
        """Synthesize *text* into a WAV file, reusing a cached file when possible.

        Args:
            text: Text to speak. Empty or whitespace-only input is rejected.
            speaker: Speaker name forwarded to ``VietnameseTTS.synthesize``.
            speed: Forwarded as ``length_scale``. NOTE(review): in VITS-style
                models a larger length_scale usually means slower speech,
                while the UI labels this value "speed". Confirm the mapping.
            noise_scale, noise_scale_w, sdp_ratio: Forwarded unchanged.

        Returns:
            ``(path, message)``. ``path`` is the WAV file path as a string, or
            ``None`` on failure. ``message`` is a status string. Exceptions are
            caught and reported through the message rather than raised.
        """
        try:
            if not text or not text.strip():
                return None, "⚠️ Empty input"
            # The key must cover every input that changes the audio. The
            # previous key used only text[:20] + len(text) and ignored the
            # noise/SDP settings, so a different text or different settings
            # could get another request's clip back from the cache. Text goes
            # last so a "|" inside it cannot shift the other fields.
            key = md5_key(
                str(speaker),
                f"{speed:.4f}",
                f"{noise_scale:.4f}",
                f"{noise_scale_w:.4f}",
                f"{sdp_ratio:.4f}",
                text,
            )
            out_path = self.temp_dir / f"{key}.wav"
            if out_path.exists():
                return str(out_path), "✅ Cached (From history)"
            audio, sr = self.tts.synthesize(
                text=text, speaker=speaker, length_scale=speed,
                noise_scale=noise_scale, noise_scale_w=noise_scale_w, sdp_ratio=sdp_ratio
            )
            sf.write(str(out_path), audio, sr)
            return str(out_path), "✅ Generated successfully"
        except Exception as e:
            # Report a clean one-line message to the UI instead of raising.
            return None, f"❌ Error: {str(e)}"
# =========================================================
# 4) MODEL LOGIC (PRESERVED & FIXED)
# =========================================================
def find_latest_checkpoint(model_dir, prefix="G"):
    """Return the path of the highest-numbered ``<prefix>*.pth`` checkpoint.

    Both ``G_12345.pth`` (the naming ``download_model`` looks for) and
    ``G12345.pth`` are recognized. The step number is parsed from the file's
    basename. Files that match the glob but carry no parsable step number
    rank below every numbered checkpoint.

    Args:
        model_dir: Directory to search (non-recursive).
        prefix: Checkpoint family prefix, e.g. ``"G"`` for the generator.

    Returns:
        The path with the largest step number (the first one in glob order on
        ties), or ``None`` if nothing matches.
    """
    pattern = os.path.join(glob.escape(os.fspath(model_dir)), f"{glob.escape(prefix)}*.pth")
    checkpoints = glob.glob(pattern)
    if not checkpoints:
        return None
    # The old regex ``G(\d+)\.pth`` never matched ``G_<step>.pth``, so every
    # key was 0 and the "latest" pick was whatever glob listed first.
    step_re = re.compile(rf"^{re.escape(prefix)}_?(\d+)\.pth$")

    def _step(path):
        match = step_re.match(os.path.basename(path))
        return int(match.group(1)) if match else -1

    return max(checkpoints, key=_step)
def download_model():
    """Ensure the pretrained Vietnamese VITS model is present locally.

    The cache root is ``%LOCALAPPDATA%`` on Windows (falling back to
    ``~/AppData/Local``) and ``$XDG_CACHE_HOME`` elsewhere (falling back to
    ``~/.cache``). The model lives under
    ``<root>/valtec_tts/models/vits-vietnamese``.

    If that directory already holds ``config.json`` and at least one
    ``G_*.pth``, nothing is downloaded. Otherwise the Hugging Face repo is
    snapshotted into it.

    Returns:
        The model directory as a string.
    """
    from huggingface_hub import snapshot_download

    hf_repo = "valtecAI-team/valtec-tts-pretrained"
    if os.name == "nt":
        base_dir = Path(os.environ.get("LOCALAPPDATA", Path.home() / "AppData" / "Local"))
    else:
        base_dir = Path(os.environ.get("XDG_CACHE_HOME", Path.home() / ".cache"))
    target_dir = base_dir / "valtec_tts" / "models" / "vits-vietnamese"

    has_config = (target_dir / "config.json").exists()
    if has_config and any(target_dir.glob("G_*.pth")):
        return str(target_dir)

    print(f"⬇️ Downloading {hf_repo}...")
    snapshot_download(repo_id=hf_repo, local_dir=str(target_dir))
    return str(target_dir)
class VietnameseTTS:
    """Load a ``SynthesizerTrn`` checkpoint and synthesize Vietnamese speech.

    Constructor args:
        ckpt: Checkpoint path. ``torch.load(ckpt)["model"]`` must be a state dict.
        cfg: Path to the JSON config. The class reads ``data.spk2id``,
            ``data.filter_length``, ``data.hop_length``, ``data.n_speakers``,
            ``data.sampling_rate``, ``train.segment_size`` and the whole
            ``model`` section (splatted into ``SynthesizerTrn``).
        device: Torch device string for the model and all input tensors.
    """
    def __init__(self, ckpt, cfg, device="cpu"):
        self.device = device
        with open(cfg, "r", encoding="utf-8") as f: self.config = json.load(f)
        # Speaker name -> integer speaker id, taken from the config.
        self.spk2id = self.config["data"]["spk2id"]
        self.speakers = list(self.spk2id.keys())
        self._load(ckpt)
    def _load(self, ckpt):
        """Build ``SynthesizerTrn`` from the config, load *ckpt* weights, set eval mode."""
        self.model = SynthesizerTrn(
            len(symbols),
            # Presumably the spectrogram channel count (n_fft // 2 + 1); confirm
            # against SynthesizerTrn's signature.
            self.config["data"]["filter_length"] // 2 + 1,
            # Presumably the training segment length in frames (samples / hop).
            self.config["train"]["segment_size"] // self.config["data"]["hop_length"],
            n_speakers=self.config["data"]["n_speakers"],
            **self.config["model"]
        ).to(self.device)
        state = torch.load(ckpt, map_location=self.device)["model"]
        # Strip "module." (typically left by DataParallel/DDP wrapping).
        # NOTE(review): str.replace removes it anywhere in a key, not only as a
        # prefix. strict=False means missing/unexpected keys are ignored
        # silently, so a mismatched checkpoint will not raise here.
        self.model.load_state_dict({k.replace("module.", ""): v for k,v in state.items()}, strict=False)
        self.model.eval()
    def synthesize(self, text, speaker, **kwargs):
        """Synthesize *text* with *speaker* and return ``(audio, sampling_rate)``.

        Args:
            text: Raw text, normalized by ``process_vietnamese_text`` and then
                phonemized (viphoneme is used when VIPHONEME_AVAILABLE).
            speaker: Key into ``spk2id``; unknown names silently fall back to id 0.
            **kwargs: Forwarded unchanged to ``self.model.infer``. TTSManager
                passes length_scale, noise_scale, noise_scale_w and sdp_ratio.

        Returns:
            ``(audio, sr)``. ``audio`` is ``outputs[0][0, 0]`` of the inference
            result converted to a NumPy array. ``sr`` is
            ``config["data"]["sampling_rate"]``.
        """
        from src.text import cleaned_text_to_sequence
        from src.nn import commons
        # 1. Text Processing
        norm_text = process_vietnamese_text(text)
        phones, tones, _ = text_to_phonemes(norm_text, use_viphoneme=VIPHONEME_AVAILABLE)
        phone_ids, tone_ids, lang_ids = cleaned_text_to_sequence(phones, tones, "VI")
        # Presumably inserts a blank id 0 between symbols (VITS "add_blank");
        # applied identically to phone, tone and language ids so they stay aligned.
        phone_ids = commons.intersperse(phone_ids, 0)
        tone_ids = commons.intersperse(tone_ids, 0)
        lang_ids = commons.intersperse(lang_ids, 0)
        # 2. Prepare Tensors: batch of one sequence, plus its length and speaker id.
        x = torch.LongTensor(phone_ids).unsqueeze(0).to(self.device)
        x_len = torch.LongTensor([len(phone_ids)]).to(self.device)
        tone = torch.LongTensor(tone_ids).unsqueeze(0).to(self.device)
        lang = torch.LongTensor(lang_ids).unsqueeze(0).to(self.device)
        sid = torch.LongTensor([self.spk2id.get(speaker, 0)]).to(self.device)
        # 3. Inference without autograd tracking.
        with torch.no_grad():
            # All-zero placeholder features of shape (1, 1024, T) and (1, 768, T).
            # No BERT embeddings are computed here. NOTE(review): the dims
            # presumably match the model's bert/ja_bert inputs; confirm.
            bert = torch.zeros(1024, len(phone_ids)).unsqueeze(0).to(self.device)
            ja_bert = torch.zeros(768, len(phone_ids)).unsqueeze(0).to(self.device)
            # Run inference
            # Outputs produced under no_grad should not require grad. .detach()
            # is kept as a guard so .numpy() can never fail with
            # "Can't call numpy() on Tensor that requires grad".
            outputs = self.model.infer(
                x, x_len, sid, tone, lang,
                bert, ja_bert,
                **kwargs
            )
            audio = outputs[0][0,0].detach().cpu().numpy()
        return audio, self.config["data"]["sampling_rate"]
# =========================================================
# 5) UI CONSTRUCTION (REFACTORED)
# =========================================================
def create_ui(manager: TTSManager):
    """Build the two-tab Gradio Blocks app bound to *manager*.

    - The quick tab synthesizes one clip with fixed noise/SDP settings.
    - The advanced tab splits long text with ``split_sentences_vi``,
      synthesizes each chunk, and concatenates the clips with an optional
      pause between them.

    Args:
        manager: Initialized TTSManager. ``manager.tts.speakers`` fills the
            speaker dropdown, and merged WAVs are written to ``manager.temp_dir``.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    from html import escape

    # Synthesis defaults shared by the quick tab and the advanced sliders.
    default_ns, default_nsw, default_sdp = 0.667, 0.8, 0.2

    def ui_header():
        return gr.HTML("""
<div style="border-bottom: 1px solid rgba(255,255,255,0.08); padding-bottom: 20px; margin-bottom: 25px;">
<h1 style="color: #38bdf8; margin:0; font-weight:800; font-size: 2rem; letter-spacing: -0.02em;">
🎛️ CVNSS4.0 Vietnamese TTS Studio
</h1>
<div style="color: #94a3b8; font-size: 1rem; margin-top: 5px; font-weight: 400;">
Công nghệ tích hợp giọng nói AI tiên tiến • Phiên bản 1.0.0 Demo • Dự án mã nguồn mở
</div>
</div>
""")

    def ui_status_render(text, speaker, speed, chunks, dur, msg):
        """Render the status card.

        ``speaker`` and ``msg`` are HTML-escaped because ``msg`` can carry raw
        exception text from synthesis. A duration pill is shown when ``dur``
        (seconds) is positive.
        """
        msg = str(msg)
        alert_cls = "alertOk" if "✅" in msg else "alertWarn"
        n_chars = len(text or "")
        dur_pill = f'<span class="pill">⏱️ {dur:.1f}s</span>' if dur and dur > 0 else ""
        safe_speaker = escape(str(speaker))
        safe_msg = escape(msg)
        return f"""
<div class="statusCard">
<div style="margin-bottom:12px; font-weight:700; color:#38bdf8; font-size: 0.9rem; text-transform: uppercase;">
📟 Trạng thái hoạt động
</div>
<div style="display:flex; flex-wrap:wrap; gap:8px;">
<span class="pill">🎤 {safe_speaker}</span>
<span class="pill">⚡ {speed}x</span>
<span class="pill">📄 {n_chars} ký tự</span>
<span class="pill">🧩 {chunks} đoạn</span>
{dur_pill}
</div>
<div class="alert {alert_cls}">
{safe_msg}
</div>
</div>
"""

    with gr.Blocks(theme=gr.themes.Base(), css=NEON_CSS, title="Neon TTS Expert") as app:
        ui_header()
        with gr.Tabs():
            # --- TAB BASIC ---
            with gr.Tab("⚡ Chế độ Nhanh"):
                with gr.Row():
                    # INPUT COLUMN
                    with gr.Column(scale=2):
                        # elem_id gives CSS a stable hook for this panel.
                        with gr.Group(elem_classes=["panelNeon"], elem_id="input-panel-basic"):
                            gr.HTML('<div class="panelTitle">📝 Văn bản đầu vào</div>')
                            txt_basic = gr.Textbox(
                                label="",
                                show_label=False,
                                placeholder="Nhập nội dung tiếng Việt vào... (Ví dụ: Xin chào, bạn đã học qua CVNSS4.0 chưa?)",
                                lines=6,
                                elem_id="main-input-basic"
                            )
                            with gr.Row():
                                spk_basic = gr.Dropdown(choices=manager.tts.speakers, value=manager.tts.speakers[0], label="Giọng đọc")
                                spd_basic = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Tốc độ đọc")
                            btn_basic = gr.Button("🔊 Đọc ngay", variant="primary")
                    # OUTPUT COLUMN
                    with gr.Column(scale=1):
                        with gr.Group(elem_classes=["panelNeon"]):
                            gr.HTML('<div class="panelTitle">🔊 Kết quả</div>')
                            out_audio_basic = gr.Audio(label="Audio Output", type="filepath", interactive=False)
                            out_status_basic = gr.HTML()
            # --- TAB ADVANCED ---
            with gr.Tab("🧠 Chế độ Chuyên sâu"):
                with gr.Row():
                    with gr.Column(scale=2):
                        with gr.Group(elem_classes=["panelNeon"], elem_id="input-panel-adv"):
                            gr.HTML('<div class="panelTitle">📝 Xử lý văn bản dài</div>')
                            txt_adv = gr.Textbox(
                                label="",
                                show_label=False,
                                lines=8,
                                placeholder="Nhập văn bản dài cần ngắt câu tự động...",
                                elem_id="main-input-adv"
                            )
                            gr.HTML('<div style="height:15px"></div>')
                            gr.HTML('<div class="panelTitle">🎚️ Tham số âm thanh</div>')
                            with gr.Row():
                                ns = gr.Slider(0.1, 1.5, default_ns, step=0.01, label="Noise Scale (Độ biến thiên)")
                                nsw = gr.Slider(0.1, 1.5, default_nsw, step=0.01, label="Duration Noise (Độ động)")
                            with gr.Row():
                                sdp = gr.Slider(0.0, 1.0, default_sdp, step=0.1, label="SDP Ratio (Ngẫu nhiên)")
                                max_chars = gr.Slider(50, 500, 300, step=10, label="Ngắt đoạn (ký tự)")
                                pause = gr.Slider(0, 1000, 250, step=50, label="Nghỉ câu (ms)")
                            btn_adv = gr.Button("🧠 Xử lý & Ghép nối", variant="primary")
                    with gr.Column(scale=1):
                        with gr.Group(elem_classes=["panelNeon"]):
                            out_audio_adv = gr.Audio(label="Merged Audio", type="filepath", interactive=False)
                            out_status_adv = gr.HTML()

        # --- LOGIC BINDING ---
        def run_basic(text, spk, spd):
            """Synthesize a single clip with the default noise/SDP settings."""
            path, msg = manager.synthesize(text, spk, spd, default_ns, default_nsw, default_sdp)
            status_html = ui_status_render(text, spk, spd, 1, 0, msg)
            return path, status_html

        def run_adv(text, spk, spd, ns, nsw, sdp, mc, p, progress=gr.Progress()):
            """Synthesize *text* chunk by chunk and merge the clips into one WAV.

            ``mc`` is the max characters per chunk. ``p`` is the pause in
            milliseconds inserted between consecutive clips (none after the
            last clip). Failed chunks are skipped, and the status reports how
            many failed instead of claiming full success.
            """
            chunks = split_sentences_vi(text, int(mc))
            total = len(chunks)
            audios = []
            sr = 44100
            errors = []
            for i, chunk in enumerate(chunks):
                progress(i / total, desc=f"Đang xử lý đoạn {i+1}/{total}")
                path, chunk_msg = manager.synthesize(chunk, spk, spd, ns, nsw, sdp)
                if not path:
                    errors.append(chunk_msg)
                    continue
                data, rate = sf.read(path)
                if audios and p > 0:
                    # Silence only between clips; shape matches the clip's channels.
                    gap = np.zeros((int(rate * p / 1000),) + data.shape[1:], dtype=data.dtype)
                    audios.append(gap)
                audios.append(data)
                sr = rate
            if not audios:
                return None, ui_status_render(text, spk, spd, total, 0, "❌ Không tạo được âm thanh")
            full_audio = np.concatenate(audios)
            # Nanosecond timestamp: two merges in the same second must not
            # overwrite each other's file.
            out_path = manager.temp_dir / f"merged_{time.time_ns()}.wav"
            sf.write(str(out_path), full_audio, sr)
            if errors:
                msg = f"⚠️ Đã ghép nối {total - len(errors)}/{total} đoạn — lỗi: {errors[0]}"
            else:
                msg = "✅ Đã ghép nối thành công"
            status_html = ui_status_render(text, spk, spd, total, len(full_audio) / sr, msg)
            return str(out_path), status_html

        btn_basic.click(run_basic, [txt_basic, spk_basic, spd_basic], [out_audio_basic, out_status_basic])
        # The advanced tab has no speaker/speed controls of its own; it reuses
        # the quick tab's dropdown and slider.
        btn_adv.click(run_adv, [txt_adv, spk_basic, spd_basic, ns, nsw, sdp, max_chars, pause], [out_audio_adv, out_status_adv])
    return app
# =========================================================
# 6) ENTRY POINT
# =========================================================
if __name__ == "__main__":
    import traceback

    print("🚀 Starting Expert Neon TTS...")
    # Check dependencies (logs only; startup continues either way).
    check_viphoneme()
    # Init Manager & UI
    try:
        tts_manager = TTSManager()
        app = create_ui(tts_manager)
        port = int(os.environ.get("PORT", "7860"))
        app.queue(max_size=10).launch(server_name="0.0.0.0", server_port=port)
    except Exception as e:
        print(f"❌ Fatal Error: {e}")
        # Keep the full traceback for debugging, and exit non-zero so a process
        # supervisor or container runtime sees the failure instead of a clean exit.
        traceback.print_exc()
        sys.exit(1)