Update app.py
app.py (CHANGED)
--- a/app.py
+++ b/app.py
@@ -12,24 +12,21 @@ import nltk
 from pydub import AudioSegment

 # --- App Setup ---
-app = FastAPI(title="Kitten TTS API", version="0.2.
+app = FastAPI(title="Kitten TTS API", version="0.2.1") # Version bump

 # --- Model & Tokenizer Loading ---
-# Download NLTK data (one-time)
 try:
     nltk.data.find('tokenizers/punkt')
 except LookupError:
     nltk.download('punkt')

-# Load the TTS model once at startup
 print("Loading KittenTTS model...")
 model = KittenTTS("KittenML/kitten-tts-nano-0.2")
 print("Model loaded.")

-# List of available voices
 voices = [
-    "expr-voice-
-    "expr-voice-
+    "expr-voice-1-f", "expr-voice-2-m", "expr-voice-3-f", "expr-voice-4-m",
+    "expr-voice-5-f", "expr-voice-6-m", "expr-voice-7-f", "expr-voice-8-m"
 ]

 # --- Request Models ---
@@ -37,14 +34,42 @@ class SpeechRequest(BaseModel):
     input: str
     model: str = "kitten-nano-0.2"
     voice: str = "expr-voice-1-f"
-    speed: Optional[float] = 1.0
-    response_format: str = "mp3"
+    speed: Optional[float] = 1.0
+    response_format: str = "mp3"

-# ---
-
-
+# --- --- --- --- --- --- --- --- --- ---
+# --- THIS IS THE FIX ---
+# --- --- --- --- --- --- --- --- --- ---
+def get_text_batches(text: str, min_batch_chars: int = 150):
+    """
+    Joins small NLTK sentences into larger, 'audio-safe' batches.
+    This prevents sending tiny chunks to ffmpeg, which fails to encode.
+    """
     sentences = nltk.sent_tokenize(text)
-
+    current_batch = ""
+
+    for sentence in sentences:
+        if not sentence.strip():
+            continue
+
+        # Add the sentence to the current batch
+        if current_batch:
+            current_batch += " " + sentence
+        else:
+            current_batch = sentence
+
+        # If the batch is large enough, yield it
+        if len(current_batch) >= min_batch_chars:
+            yield current_batch
+            current_batch = ""
+
+    # Yield any remaining text in the last batch
+    if current_batch.strip():
+        yield current_batch
+# --- --- --- --- --- --- --- --- --- ---
+# --- END OF FIX ---
+# --- --- --- --- --- --- --- --- --- ---
+

 # --- Streaming Generator Logic ---
 async def audio_stream_generator(text: str, voice: str) -> AsyncGenerator[bytes, None]:
@@ -56,15 +81,6 @@ async def audio_stream_generator(text: str, voice: str) -> AsyncGenerator[bytes, None]:
         yield b""
         return

-    # 1. Define FFmpeg command for on-the-fly conversion
-    # -f s16le: Input is signed 16-bit little-endian (raw PCM from WAV)
-    # -ar 24000: KittenTTS sample rate is 24kHz
-    # -ac 1: Mono audio
-    # -i pipe:0: Read from stdin
-    # -f mp3: Output format is MP3
-    # -b:a 96k: 96 kbps bitrate (good for speech)
-    # -ar 44100: Resample to 44.1kHz for better compatibility (optional)
-    # pipe:1: Write to stdout
     ffmpeg_command = [
         "ffmpeg",
         "-f", "s16le",
@@ -77,17 +93,14 @@ async def audio_stream_generator(text: str, voice: str) -> AsyncGenerator[bytes, None]:
         "pipe:1"
     ]

-    # 2. Start the FFmpeg subprocess
     process = await asyncio.create_subprocess_exec(
         *ffmpeg_command,
         stdin=subprocess.PIPE,
         stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE
+        stderr=subprocess.PIPE
     )

     try:
-        # 3. Create an asyncio task to read from stdout (non-blocking)
-        # This task will yield MP3 chunks as FFmpeg produces them.
         async def read_stdout():
             while True:
                 chunk = await process.stdout.read(4096)
@@ -97,53 +110,50 @@ async def audio_stream_generator(text: str, voice: str) -> AsyncGenerator[bytes, None]:

         stdout_reader = read_stdout()

-        # 4. Generate and feed WAV data into FFmpeg's stdin
         async def feed_stdin():
             try:
-
-
-
-
+                # --- USE THE BATCHER ---
+                batches = get_text_batches(text)
+                for batch in batches:
+                    if not batch.strip():
+                        continue
+
+                    wav_bytes = model.generate(text=batch, voice=voice)

-                # Use pydub to strip the WAV header and get raw PCM data
                     wav_audio = AudioSegment.from_wav(io.BytesIO(wav_bytes))
                     raw_pcm_data = wav_audio.raw_data

-
+                    if not raw_pcm_data:
+                        print(f"Warning: No audio data for batch: {batch}")
+                        continue
+
                     process.stdin.write(raw_pcm_data)
                     await process.stdin.drain()

-                # Small sleep to allow generation to be non-blocking
                     await asyncio.sleep(0.01)

             except Exception as e:
                 print(f"Error in feed_stdin: {e}")
             finally:
-
-                if not process.stdin.is_closing():
+                if process.stdin and not process.stdin.is_closing():
                     process.stdin.close()
                     await process.stdin.wait_closed()

-        # 5. Run the feeder and reader tasks concurrently
         feeder_task = asyncio.create_task(feed_stdin())

         async for mp3_chunk in stdout_reader:
             yield mp3_chunk

-        # Wait for feeder to finish
         await feeder_task

-        # Wait for FFmpeg to finish
         await process.wait()

-        # Check for FFmpeg errors
         stderr_data = await process.stderr.read()
         if stderr_data:
             print(f"FFmpeg stderr: {stderr_data.decode()}")

     except Exception as e:
         print(f"Streaming generator error: {e}")
-        # Clean up process if it's still running
         if process.returncode is None:
             process.terminate()
             await process.wait()
@@ -153,19 +163,13 @@ async def audio_stream_generator(text: str, voice: str) -> AsyncGenerator[bytes, None]:
 # --- API Endpoints ---
 @app.post("/v1/audio/speech")
 async def generate_speech(request: SpeechRequest):
-    """
-    Generates speech audio, streaming the response.
-    Supports 'mp3' (streaming) and 'wav' (blocking, file-stream).
-    """
-    # Validation
     if request.voice not in voices:
         raise HTTPException(status_code=400, detail=f"Voice must be one of {voices}")
-    if len(request.input) > 2000:
+    if len(request.input) > 2000:
         raise HTTPException(status_code=413, detail="Input text too long; max 2000 chars")

     try:
         if request.response_format == "mp3":
-            # The new TRUE-STREAMING MP3 path
             return StreamingResponse(
                 audio_stream_generator(text=request.input, voice=request.voice),
                 media_type="audio/mpeg",
@@ -184,7 +188,7 @@ async def generate_speech(request: SpeechRequest):
             return StreamingResponse(
                 iter_bytes(),
                 media_type="audio/wav",
-                headers={"Content-
+                headers={"Content-Disposition": "attachment; filename=speech.wav"}
             )

         else:
@@ -200,7 +204,7 @@ async def list_voices():

 @app.get("/")
 async def root():
-    return {"message": "Kitten TTS API v0.2.
+    return {"message": "Kitten TTS API v0.2.1 (Batching Fix)"}

 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
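For reference, a minimal client sketch that exercises the streaming endpoint touched by this change. It is not part of the commit: it assumes the server from this diff is running locally (uvicorn on port 7860, as in the `__main__` block) and that the third-party `requests` package is installed; the endpoint path, JSON fields, and voice name are taken from the diff above.

# Minimal client sketch -- assumptions: `python app.py` is running locally on
# port 7860 and `requests` is installed; neither is part of app.py itself.
import requests

payload = {
    "input": (
        "Hello from Kitten TTS. These sentences are short, so the new "
        "get_text_batches() helper should merge them into one larger batch "
        "before the audio is piped through FFmpeg."
    ),
    "voice": "expr-voice-1-f",
    "response_format": "mp3",
}

# stream=True consumes the MP3 chunks as the server's StreamingResponse emits them.
with requests.post(
    "http://localhost:7860/v1/audio/speech", json=payload, stream=True
) as resp:
    resp.raise_for_status()
    with open("speech.mp3", "wb") as f:
        for chunk in resp.iter_content(chunk_size=4096):
            if chunk:
                f.write(chunk)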