import io
import logging
from typing import List

import soundfile as sf
import torch
from fastapi import FastAPI
from parler_tts import ParlerTTSForConditionalGeneration
from pydantic import BaseModel
from starlette.concurrency import run_in_threadpool
from starlette.responses import StreamingResponse
from transformers import AutoTokenizer
| |
|
| | |
# The ASGI application; the @app.get / @app.post routes in this module register onto it.
app = FastAPI()
| |
|
| | |
class Item(BaseModel):
    """A single text snippet to synthesize, labelled with a name and a section.

    Instances are validated from the JSON request body of ``POST /``.
    """

    # The text handed to the TTS model as the spoken prompt.
    text: str
    # Caller-supplied label; echoed back unchanged in the response.
    name: str
    # Caller-supplied grouping label; echoed back unchanged in the response.
    section: str
| |
|
| | |
# Hugging Face repo id used for both the model weights and its matching tokenizer.
MODEL_NAME = "parler-tts/parler-tts-mini-v1"

# Use the first CUDA device when torch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Loaded once at module import so every request reuses the same weights.
model = ParlerTTSForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
| |
|
| | |
@app.get("/")
def greet_json():
    """Answer ``GET /`` with a constant greeting (handy as a liveness probe)."""
    payload = {"Hello": "World!"}
    return payload
| |
|
| | |
def generate_audio(
    text,
    description="Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise.",
):
    """Synthesize speech for ``text`` with the module-level Parler-TTS model.

    Args:
        text: The words to be spoken; tokenized and passed as
            ``prompt_input_ids``.
        description: Natural-language description of the voice and recording
            conditions; tokenized and passed as ``input_ids``.

    Returns:
        A ``(audio, sampling_rate)`` tuple: the generated waveform as a NumPy
        array with size-1 axes squeezed out, and ``model.config.sampling_rate``.
    """
    # Parler-TTS conditions on two token sequences: the voice description goes
    # in ``input_ids`` and the text to speak goes in ``prompt_input_ids``.
    input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    # Copy back to host memory before NumPy conversion; squeeze drops size-1
    # axes such as the batch of one produced by tokenizing a single string.
    audio_arr = generation.cpu().numpy().squeeze()
    return audio_arr, model.config.sampling_rate
| |
|
| | |
@app.post("/")
async def create_items(items: List[Item]):
    """Run speech synthesis for each item and return the items' metadata.

    Args:
        items: Request body; a JSON list of ``Item`` objects.

    Returns:
        ``{"processed_items": [...]}`` where each entry echoes the item's
        ``text``, ``name`` and ``section`` plus ``"processed": True``.
    """
    logger = logging.getLogger(__name__)
    processed_items = []
    for item in items:
        logger.info("Processing item: %s", item.text)
        # generate_audio performs synchronous, compute-bound model inference.
        # Calling it directly inside an ``async def`` would block the event loop
        # (and therefore every other request) for the whole generation, so it is
        # run in Starlette's worker thread pool instead.
        # NOTE(review): concurrent requests may now run inference in parallel
        # against the shared model; add a lock if GPU memory becomes an issue.
        audio_arr, sample_rate = await run_in_threadpool(generate_audio, item.text)
        # TODO(review): the generated audio (audio_arr, sample_rate) is neither
        # returned nor stored, so the inference result is discarded. The module
        # imports soundfile, io and StreamingResponse, which suggests the intent
        # was to return it -- confirm the desired response format.
        processed_items.append(
            {
                "text": item.text,
                "name": item.name,
                "section": item.section,
                "processed": True,
            }
        )
    return {"processed_items": processed_items}
| |
|
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn on the loopback interface.
    import uvicorn

    server_options = {"host": "127.0.0.1", "port": 8000}
    uvicorn.run(app, **server_options)
| |
|