Rajhuggingface4253 commited on
Commit
aeeb4b5
·
verified ·
1 Parent(s): 8d1ea02

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -49
app.py CHANGED
@@ -12,24 +12,21 @@ import nltk
12
  from pydub import AudioSegment
13
 
14
  # --- App Setup ---
15
- app = FastAPI(title="Kitten TTS API", version="0.2.0")
16
 
17
  # --- Model & Tokenizer Loading ---
18
- # Download NLTK data (one-time)
19
  try:
20
  nltk.data.find('tokenizers/punkt')
21
  except LookupError:
22
  nltk.download('punkt')
23
 
24
- # Load the TTS model once at startup
25
  print("Loading KittenTTS model...")
26
  model = KittenTTS("KittenML/kitten-tts-nano-0.2")
27
  print("Model loaded.")
28
 
29
- # List of available voices
30
  voices = [
31
- "expr-voice-2-f", "expr-voice-2-m", "expr-voice-3-f", "expr-voice-3-m",
32
- "expr-voice-4-f", "expr-voice-4-m", "expr-voice-5-f", "expr-voice-5-m"
33
  ]
34
 
35
  # --- Request Models ---
@@ -37,14 +34,42 @@ class SpeechRequest(BaseModel):
37
  input: str
38
  model: str = "kitten-nano-0.2"
39
  voice: str = "expr-voice-1-f"
40
- speed: Optional[float] = 1.0 # Note: speed is not yet implemented
41
- response_format: str = "mp3" # Defaulting to mp3
42
 
43
- # --- Text Chunking Logic ---
44
- def get_text_chunks(text: str):
45
- """Splits text into semantically safe sentences."""
 
 
 
 
 
46
  sentences = nltk.sent_tokenize(text)
47
- return [sentence for sentence in sentences if sentence.strip()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  # --- Streaming Generator Logic ---
50
  async def audio_stream_generator(text: str, voice: str) -> AsyncGenerator[bytes, None]:
@@ -56,15 +81,6 @@ async def audio_stream_generator(text: str, voice: str) -> AsyncGenerator[bytes,
56
  yield b""
57
  return
58
 
59
- # 1. Define FFmpeg command for on-the-fly conversion
60
- # -f s16le: Input is signed 16-bit little-endian (raw PCM from WAV)
61
- # -ar 24000: KittenTTS sample rate is 24kHz
62
- # -ac 1: Mono audio
63
- # -i pipe:0: Read from stdin
64
- # -f mp3: Output format is MP3
65
- # -b:a 96k: 96 kbps bitrate (good for speech)
66
- # -ar 44100: Resample to 44.1kHz for better compatibility (optional)
67
- # pipe:1: Write to stdout
68
  ffmpeg_command = [
69
  "ffmpeg",
70
  "-f", "s16le",
@@ -77,17 +93,14 @@ async def audio_stream_generator(text: str, voice: str) -> AsyncGenerator[bytes,
77
  "pipe:1"
78
  ]
79
 
80
- # 2. Start the FFmpeg subprocess
81
  process = await asyncio.create_subprocess_exec(
82
  *ffmpeg_command,
83
  stdin=subprocess.PIPE,
84
  stdout=subprocess.PIPE,
85
- stderr=subprocess.PIPE # Capture errors
86
  )
87
 
88
  try:
89
- # 3. Create an asyncio task to read from stdout (non-blocking)
90
- # This task will yield MP3 chunks as FFmpeg produces them.
91
  async def read_stdout():
92
  while True:
93
  chunk = await process.stdout.read(4096)
@@ -97,53 +110,50 @@ async def audio_stream_generator(text: str, voice: str) -> AsyncGenerator[bytes,
97
 
98
  stdout_reader = read_stdout()
99
 
100
- # 4. Generate and feed WAV data into FFmpeg's stdin
101
  async def feed_stdin():
102
  try:
103
- chunks = get_text_chunks(text)
104
- for chunk in chunks:
105
- # Generate the raw WAV bytes for this chunk
106
- wav_bytes = model.generate(text=chunk, voice=voice)
 
 
 
107
 
108
- # Use pydub to strip the WAV header and get raw PCM data
109
  wav_audio = AudioSegment.from_wav(io.BytesIO(wav_bytes))
110
  raw_pcm_data = wav_audio.raw_data
111
 
112
- # Feed the raw PCM data into FFmpeg
 
 
 
113
  process.stdin.write(raw_pcm_data)
114
  await process.stdin.drain()
115
 
116
- # Small sleep to allow generation to be non-blocking
117
  await asyncio.sleep(0.01)
118
 
119
  except Exception as e:
120
  print(f"Error in feed_stdin: {e}")
121
  finally:
122
- # Close stdin when done to signal end to FFmpeg
123
- if not process.stdin.is_closing():
124
  process.stdin.close()
125
  await process.stdin.wait_closed()
126
 
127
- # 5. Run the feeder and reader tasks concurrently
128
  feeder_task = asyncio.create_task(feed_stdin())
129
 
130
  async for mp3_chunk in stdout_reader:
131
  yield mp3_chunk
132
 
133
- # Wait for feeder to finish
134
  await feeder_task
135
 
136
- # Wait for FFmpeg to finish
137
  await process.wait()
138
 
139
- # Check for FFmpeg errors
140
  stderr_data = await process.stderr.read()
141
  if stderr_data:
142
  print(f"FFmpeg stderr: {stderr_data.decode()}")
143
 
144
  except Exception as e:
145
  print(f"Streaming generator error: {e}")
146
- # Clean up process if it's still running
147
  if process.returncode is None:
148
  process.terminate()
149
  await process.wait()
@@ -153,19 +163,13 @@ async def audio_stream_generator(text: str, voice: str) -> AsyncGenerator[bytes,
153
  # --- API Endpoints ---
154
  @app.post("/v1/audio/speech")
155
  async def generate_speech(request: SpeechRequest):
156
- """
157
- Generates speech audio, streaming the response.
158
- Supports 'mp3' (streaming) and 'wav' (blocking, file-stream).
159
- """
160
- # Validation
161
  if request.voice not in voices:
162
  raise HTTPException(status_code=400, detail=f"Voice must be one of {voices}")
163
- if len(request.input) > 2000: # New, higher limit for streaming
164
  raise HTTPException(status_code=413, detail="Input text too long; max 2000 chars")
165
 
166
  try:
167
  if request.response_format == "mp3":
168
- # The new TRUE-STREAMING MP3 path
169
  return StreamingResponse(
170
  audio_stream_generator(text=request.input, voice=request.voice),
171
  media_type="audio/mpeg",
@@ -184,7 +188,7 @@ async def generate_speech(request: SpeechRequest):
184
  return StreamingResponse(
185
  iter_bytes(),
186
  media_type="audio/wav",
187
- headers={"Content-DIsiI've": "attachment; filename=speech.wav"}
188
  )
189
 
190
  else:
@@ -200,7 +204,7 @@ async def list_voices():
200
 
201
  @app.get("/")
202
  async def root():
203
- return {"message": "Kitten TTS API v0.2.0 (True-Streaming Enabled)"}
204
 
205
  if __name__ == "__main__":
206
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
12
  from pydub import AudioSegment
13
 
14
  # --- App Setup ---
15
+ app = FastAPI(title="Kitten TTS API", version="0.2.1") # Version bump
16
 
17
  # --- Model & Tokenizer Loading ---
 
18
  try:
19
  nltk.data.find('tokenizers/punkt')
20
  except LookupError:
21
  nltk.download('punkt')
22
 
 
23
  print("Loading KittenTTS model...")
24
  model = KittenTTS("KittenML/kitten-tts-nano-0.2")
25
  print("Model loaded.")
26
 
 
27
# Voice identifiers accepted by the /v1/audio/speech endpoint; requests
# naming any other voice are rejected with HTTP 400 at validation time.
voices = [
    "expr-voice-1-f", "expr-voice-2-m", "expr-voice-3-f", "expr-voice-4-m",
    "expr-voice-5-f", "expr-voice-6-m", "expr-voice-7-f", "expr-voice-8-m",
]
31
 
32
  # --- Request Models ---
 
34
  input: str
35
  model: str = "kitten-nano-0.2"
36
  voice: str = "expr-voice-1-f"
37
+ speed: Optional[float] = 1.0
38
+ response_format: str = "mp3"
39
 
40
def get_text_batches(text: str, min_batch_chars: int = 150):
    """Yield sentence batches of at least ``min_batch_chars`` characters.

    Sentences are detected with NLTK's ``sent_tokenize`` and joined with a
    single space into a running batch; once the batch reaches the threshold
    it is yielded and a new one starts. Any leftover text is yielded last.
    Batching exists because feeding very short chunks to ffmpeg can fail to
    encode, so tiny sentences are coalesced into 'audio-safe' units.

    Args:
        text: The input text to split into batches.
        min_batch_chars: Minimum size of a batch before it is yielded.

    Yields:
        Non-empty strings containing one or more whole sentences.
    """
    batch = ""
    for sentence in nltk.sent_tokenize(text):
        # Skip whitespace-only tokenizer output.
        if not sentence.strip():
            continue
        batch = f"{batch} {sentence}" if batch else sentence
        if len(batch) >= min_batch_chars:
            yield batch
            batch = ""
    # Flush whatever remains after the final sentence.
    if batch.strip():
        yield batch
72
+
73
 
74
  # --- Streaming Generator Logic ---
75
  async def audio_stream_generator(text: str, voice: str) -> AsyncGenerator[bytes, None]:
 
81
  yield b""
82
  return
83
 
 
 
 
 
 
 
 
 
 
84
  ffmpeg_command = [
85
  "ffmpeg",
86
  "-f", "s16le",
 
93
  "pipe:1"
94
  ]
95
 
 
96
  process = await asyncio.create_subprocess_exec(
97
  *ffmpeg_command,
98
  stdin=subprocess.PIPE,
99
  stdout=subprocess.PIPE,
100
+ stderr=subprocess.PIPE
101
  )
102
 
103
  try:
 
 
104
  async def read_stdout():
105
  while True:
106
  chunk = await process.stdout.read(4096)
 
110
 
111
  stdout_reader = read_stdout()
112
 
 
113
  async def feed_stdin():
114
  try:
115
+ # --- USE THE BATCHER ---
116
+ batches = get_text_batches(text)
117
+ for batch in batches:
118
+ if not batch.strip():
119
+ continue
120
+
121
+ wav_bytes = model.generate(text=batch, voice=voice)
122
 
 
123
  wav_audio = AudioSegment.from_wav(io.BytesIO(wav_bytes))
124
  raw_pcm_data = wav_audio.raw_data
125
 
126
+ if not raw_pcm_data:
127
+ print(f"Warning: No audio data for batch: {batch}")
128
+ continue
129
+
130
  process.stdin.write(raw_pcm_data)
131
  await process.stdin.drain()
132
 
 
133
  await asyncio.sleep(0.01)
134
 
135
  except Exception as e:
136
  print(f"Error in feed_stdin: {e}")
137
  finally:
138
+ if process.stdin and not process.stdin.is_closing():
 
139
  process.stdin.close()
140
  await process.stdin.wait_closed()
141
 
 
142
  feeder_task = asyncio.create_task(feed_stdin())
143
 
144
  async for mp3_chunk in stdout_reader:
145
  yield mp3_chunk
146
 
 
147
  await feeder_task
148
 
 
149
  await process.wait()
150
 
 
151
  stderr_data = await process.stderr.read()
152
  if stderr_data:
153
  print(f"FFmpeg stderr: {stderr_data.decode()}")
154
 
155
  except Exception as e:
156
  print(f"Streaming generator error: {e}")
 
157
  if process.returncode is None:
158
  process.terminate()
159
  await process.wait()
 
163
  # --- API Endpoints ---
164
  @app.post("/v1/audio/speech")
165
  async def generate_speech(request: SpeechRequest):
 
 
 
 
 
166
  if request.voice not in voices:
167
  raise HTTPException(status_code=400, detail=f"Voice must be one of {voices}")
168
+ if len(request.input) > 2000:
169
  raise HTTPException(status_code=413, detail="Input text too long; max 2000 chars")
170
 
171
  try:
172
  if request.response_format == "mp3":
 
173
  return StreamingResponse(
174
  audio_stream_generator(text=request.input, voice=request.voice),
175
  media_type="audio/mpeg",
 
188
  return StreamingResponse(
189
  iter_bytes(),
190
  media_type="audio/wav",
191
+ headers={"Content-Disposition": "attachment; filename=speech.wav"}
192
  )
193
 
194
  else:
 
204
 
205
@app.get("/")
async def root():
    """Landing endpoint: report the API name and current version string."""
    banner = {"message": "Kitten TTS API v0.2.1 (Batching Fix)"}
    return banner
208
 
209
  if __name__ == "__main__":
210
  uvicorn.run(app, host="0.0.0.0", port=7860)