import os
import time
import numpy as np
import gradio as gr
import librosa
import soundfile as sf
import torch
import traceback
import threading
from spaces import GPU
from datetime import datetime
from contextlib import contextmanager

from modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from processor.vibevoice_processor import VibeVoiceProcessor
from modular.streamer import AudioStreamer
from transformers.utils import logging
from transformers import set_seed

logging.set_verbosity_info()
logger = logging.get_logger(__name__)


class VibeVoiceDemo:
    def __init__(self, model_paths: dict, device: str = "cuda", inference_steps: int = 5):
        """
        model_paths: dict like
            {"VibeVoice-1.5B": "microsoft/VibeVoice-1.5B",
             "VibeVoice-7B": "microsoft/VibeVoice-7B"}
        """
        self.model_paths = model_paths
        self.device = device
        self.inference_steps = inference_steps
        self.is_generating = False

        # Multi-model holders
        self.models = {}      # name -> model
        self.processors = {}  # name -> processor
        self.current_model_name = None
        self.available_voices = {}

        # Set compiler flags for better performance
        if torch.cuda.is_available() and hasattr(torch, '_inductor'):
            if hasattr(torch._inductor, 'config'):
                torch._inductor.config.conv_1x1_as_mm = True
                torch._inductor.config.coordinate_descent_tuning = True
                torch._inductor.config.epilogue_fusion = False
                torch._inductor.config.coordinate_descent_check_all_directions = True

        self.load_models()  # load all models on CPU
        self.setup_voice_presets()
        self.load_example_scripts()

    def load_models(self):
        print("Loading processors and models on CPU...")

        # Debug: show cache location and contents
        cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
        print(f"HuggingFace cache directory: {cache_dir}")
        if os.path.exists(cache_dir):
            cache_size = sum(
                os.path.getsize(os.path.join(dirpath, filename))
                for dirpath, _, filenames in os.walk(cache_dir)
                for filename in filenames
            )
            print(f"Cache exists. Size: {cache_size / (1024**3):.2f} GB")
            print("Cached models:")
            for item in os.listdir(cache_dir):
                if item.startswith("models--"):
                    print(f"  - {item}")

        for name, path in self.model_paths.items():
            print(f"  - {name} from {path}")
            proc = VibeVoiceProcessor.from_pretrained(path)
            # Try to use flash attention if available
            try:
                mdl = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    path,
                    torch_dtype=torch.bfloat16,
                    attn_implementation="flash_attention_2",
                )
                print(f"  Flash Attention 2 enabled for {name}")
            except Exception:
                # Fall back to the default attention implementation
                mdl = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    path,
                    torch_dtype=torch.bfloat16,
                )
            # Keep on CPU initially
            self.processors[name] = proc
            self.models[name] = mdl

        # Choose the default model
        self.current_model_name = next(iter(self.models))
        print(f"Default model is {self.current_model_name}")

    def _place_model(self, target_name: str):
        """
        Move the selected model to CUDA and push all others back to CPU.
        """
        for name, mdl in self.models.items():
            if name == target_name:
                self.models[name] = mdl.to(self.device)
            else:
                self.models[name] = mdl.to("cpu")
        self.current_model_name = target_name
        print(f"Model {target_name} is now on {self.device}. Others moved to CPU.")
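    # Illustrative use of the placement helper (a sketch, not executed here;
    # model names are assumptions matching the `model_paths` example above):
    #   demo._place_model("VibeVoice-7B")
    #   next(demo.models["VibeVoice-7B"].parameters()).device    # -> cuda:0
    #   next(demo.models["VibeVoice-1.5B"].parameters()).device  # -> cpu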
    def setup_voice_presets(self):
        voices_dir = os.path.join(os.path.dirname(__file__), "voices")
        if not os.path.exists(voices_dir):
            print(f"Warning: Voices directory not found at {voices_dir}")
            return
        wav_files = [
            f for f in os.listdir(voices_dir)
            if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac'))
        ]
        for wav_file in wav_files:
            name = os.path.splitext(wav_file)[0]
            self.available_voices[name] = os.path.join(voices_dir, wav_file)
        print(f"Voices loaded: {list(self.available_voices.keys())}")

        # Organize voices by gender
        self.male_voices = [
            "en-Carter_man",
            "en-Frank_man",
            "en-Yasser_man",
            "in-Samuel_man",
            "zh-Anchen_man_bgm",
            "zh-Bowen_man",
        ]
        self.female_voices = [
            "en-Alice_woman_bgm",
            "en-Alice_woman",
            "en-Maya_woman",
            "zh-Xinran_woman",
        ]

    def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
        try:
            wav, sr = sf.read(audio_path)
            if len(wav.shape) > 1:
                # Downmix multi-channel audio to mono
                wav = np.mean(wav, axis=1)
            if sr != target_sr:
                wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
            return wav
        except Exception as e:
            print(f"Error reading audio {audio_path}: {e}")
            return np.array([])

    @GPU(duration=120)
    def generate_podcast(self,
                         num_speakers: int,
                         script: str,
                         speaker_1: str = None,
                         speaker_2: str = None,
                         speaker_3: str = None,
                         speaker_4: str = None,
                         cfg_scale: float = 1.3,
                         model_name: str = None):
        """
        Generates a conference as a single audio file from a script and saves it.
        Non-streaming.
        """
        try:
            # Pick the model
            model_name = model_name or self.current_model_name
            if model_name not in self.models:
                raise gr.Error(f"Unknown model: {model_name}")

            # Place models on their devices
            self._place_model(model_name)
            model = self.models[model_name]
            processor = self.processors[model_name]
            print(f"Using model {model_name} on {self.device}")

            model.eval()
            model.set_ddpm_inference_steps(num_steps=self.inference_steps)

            self.is_generating = True

            if not script.strip():
                raise gr.Error("Error: Please provide a script.")
            # Normalize curly apostrophes to plain ASCII ones
            script = script.replace("’", "'")

            if not 1 <= num_speakers <= 4:
                raise gr.Error("Error: Number of speakers must be between 1 and 4.")

            selected_speakers = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers]
            for i, speaker_name in enumerate(selected_speakers):
                if not speaker_name or speaker_name not in self.available_voices:
                    raise gr.Error(f"Error: Please select a valid speaker for Speaker {i+1}.")

            log = f"Generating conference with {num_speakers} speakers\n"
            log += f"Model: {model_name}\n"
            log += f"Parameters: CFG Scale={cfg_scale}\n"
            log += f"Speakers: {', '.join(selected_speakers)}\n"

            voice_samples = []
            for speaker_name in selected_speakers:
                audio_path = self.available_voices[speaker_name]
                audio_data = self.read_audio(audio_path)
                if len(audio_data) == 0:
                    raise gr.Error(f"Error: Failed to load audio for {speaker_name}")
                voice_samples.append(audio_data)
            log += f"Loaded {len(voice_samples)} voice samples\n"

            # Normalize the script: keep 'Speaker X:' lines as-is and prefix
            # unlabeled lines with rotating speaker tags
            lines = script.strip().split('\n')
            formatted_script_lines = []
            for line in lines:
                line = line.strip()
                if not line:
                    continue
                if line.startswith('Speaker ') and ':' in line:
                    formatted_script_lines.append(line)
                else:
                    speaker_id = len(formatted_script_lines) % num_speakers
                    formatted_script_lines.append(f"Speaker {speaker_id}: {line}")
            formatted_script = '\n'.join(formatted_script_lines)
            log += f"Formatted script with {len(formatted_script_lines)} turns\n"
            log += "Processing with VibeVoice...\n"

            inputs = processor(
                text=[formatted_script],
                voice_samples=[voice_samples],
                padding=True,
                return_tensors="pt",
                return_attention_mask=True,
            )

            start_time = time.time()
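            # Note on the block below: sdpa_kernel() constrains
            # torch.nn.functional.scaled_dot_product_attention to the listed
            # backend(s) for the duration of the `with` block. A minimal
            # standalone sketch (illustrative only; q, k, v are hypothetical
            # attention tensors, not variables in this function):
            #   from torch.nn.attention import SDPBackend, sdpa_kernel
            #   with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
            #       out = torch.nn.functional.scaled_dot_product_attention(q, k, v)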
            # Use the efficient attention backend where available
            if torch.cuda.is_available() and hasattr(torch.nn.attention, 'SDPBackend'):
                from torch.nn.attention import SDPBackend, sdpa_kernel
                with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=None,
                        cfg_scale=cfg_scale,
                        tokenizer=processor.tokenizer,
                        generation_config={'do_sample': False},
                        verbose=False,
                    )
            else:
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=None,
                    cfg_scale=cfg_scale,
                    tokenizer=processor.tokenizer,
                    generation_config={'do_sample': False},
                    verbose=False,
                )
            generation_time = time.time() - start_time

            if hasattr(outputs, 'speech_outputs') and outputs.speech_outputs[0] is not None:
                audio_tensor = outputs.speech_outputs[0]
                audio = audio_tensor.cpu().float().numpy()
            else:
                raise gr.Error("Error: No audio was generated by the model. Please try again.")

            if audio.ndim > 1:
                audio = audio.squeeze()

            sample_rate = 24000
            output_dir = "outputs"
            os.makedirs(output_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            file_path = os.path.join(output_dir, f"conference_{timestamp}.wav")
            sf.write(file_path, audio, sample_rate)
            print(f"Conference saved to {file_path}")

            total_duration = len(audio) / sample_rate
            log += f"Generation completed in {generation_time:.2f} seconds\n"
            log += f"Final audio duration: {total_duration:.2f} seconds\n"
            log += f"Successfully saved conference to: {file_path}\n"

            self.is_generating = False
            return (sample_rate, audio), log

        except gr.Error as e:
            self.is_generating = False
            error_msg = f"Input Error: {str(e)}"
            print(error_msg)
            return None, error_msg
        except Exception as e:
            self.is_generating = False
            error_msg = f"An unexpected error occurred: {str(e)}"
            print(error_msg)
            traceback.print_exc()
            return None, error_msg

    @staticmethod
    def _infer_num_speakers_from_script(script: str) -> int:
        """
        Infer the number of speakers by counting distinct 'Speaker X:' tags
        in the script. Robust to 0- or 1-indexed labels and repeated turns.
        Falls back to 1 if none are found.
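
        Example (illustrative):
            >>> s = "Speaker 1: Hi.\\nSpeaker 2: Hello.\\nSpeaker 1: Bye."
            >>> VibeVoiceDemo._infer_num_speakers_from_script(s)
            2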
""" import re ids = re.findall(r'(?mi)^\s*Speaker\s+(\d+)\s*:', script) return len({int(x) for x in ids}) if ids else 1 def load_example_scripts(self): examples_dir = os.path.join(os.path.dirname(__file__), "text_examples") self.example_scripts = [] self.example_scripts_natural = [] if not os.path.exists(examples_dir): return original_files = [ "1p_ai_tedtalk.txt", "1p_politcal_speech.txt", "2p_financeipo_meeting.txt", "2p_telehealth_meeting.txt", "3p_military_meeting.txt", "3p_oil_meeting.txt", "4p_gamecreation_meeting.txt", "4p_product_meeting.txt" ] # Gender mapping for each script's speakers self.script_speaker_genders = [ ["female"], # AI TED Talk - Rachel ["neutral"], # Political Speech - generic speaker ["male", "female"], # Finance IPO - James, Patricia ["female", "male"], # Telehealth - Jennifer, Tom ["female", "male", "female"], # Military - Sarah, David, Lisa ["male", "female", "male"], # Oil - Robert, Lisa, Michael ["male", "female", "male", "male"], # Game Creation - Alex, Sarah, Marcus, Emma ["female", "male", "female", "male"] # Product Meeting - Sarah, Marcus, Jennifer, David ] for txt_file in original_files: try: with open(os.path.join(examples_dir, txt_file), 'r', encoding='utf-8') as f: script_content = f.read().strip() if script_content: num_speakers = self._infer_num_speakers_from_script(script_content) self.example_scripts.append([num_speakers, script_content]) natural_file = txt_file.replace('.txt', '_natural.txt') natural_path = os.path.join(examples_dir, natural_file) if os.path.exists(natural_path): with open(natural_path, 'r', encoding='utf-8') as f: natural_content = f.read().strip() if natural_content: num_speakers = self._infer_num_speakers_from_script(natural_content) self.example_scripts_natural.append([num_speakers, natural_content]) else: self.example_scripts_natural.append([num_speakers, script_content]) except Exception as e: print(f"Error loading {txt_file}: {e}") def convert_to_16_bit_wav(data): if torch.is_tensor(data): data = data.detach().cpu().numpy() data = np.array(data) if np.max(np.abs(data)) > 1.0: data = data / np.max(np.abs(data)) return (data * 32767).astype(np.int16) # Set synthwave theme theme = gr.themes.Ocean( primary_hue="indigo", secondary_hue="fuchsia", neutral_hue="slate", ).set( button_large_radius='*radius_sm' ) def set_working_state(*components, transcript_box=None): """ Disable all interactive components and show progress in transcript/log box. Usage: set_working_state(generate_btn, random_example_btn, transcript_box=log_output) """ updates = [gr.update(interactive=False) for _ in components] if transcript_box is not None: updates.append(gr.update(value="Generating... please wait", interactive=False)) return tuple(updates) def set_idle_state(*components, transcript_box=None): """ Re-enable all interactive components and transcript/log box. Usage: set_idle_state(generate_btn, random_example_btn, transcript_box=log_output) """ updates = [gr.update(interactive=True) for _ in components] if transcript_box is not None: updates.append(gr.update(interactive=True)) return tuple(updates) def create_demo_interface(demo_instance: VibeVoiceDemo): custom_css = """ """ with gr.Blocks( title="VibeVoice - Conference Generator", css=custom_css, theme=theme, ) as interface: # Simple image gr.HTML("""