import json

import gradio as gr
import torch
import torchaudio.transforms as T

from services.embeddings import configure_embeddings
from services.indexing import build_symptom_index
from services.llm import build_llm
from utils.model_configuration_utils import select_best_model, ensure_model
from utils.voice_input_utils import enhanced_process_speech, format_response_for_user, get_asr_pipeline
# 1) Model selection & download
MODEL_NAME, REPO_ID = select_best_model()
model_path = ensure_model()
print(f"Using model: {MODEL_NAME} from {REPO_ID}")
print(f"Model path: {model_path}")
if torch.cuda.is_available():
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB")
print(f"Model requirements: {MODEL_NAME} requires at least 4GB VRAM and 8GB RAM.")
print(f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")
# 2) LLM and embeddings config
llm = build_llm(model_path)
configure_embeddings()
print(f"LLM configured with model: {model_path}")
print("Embeddings configured successfully.")

# 3) Index setup
symptom_index = build_symptom_index()
print("Symptom index built successfully.")
print("Ready for queries.")
# --- System prompt ---
SYSTEM_PROMPT = """
You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?")
or, if you have enough info, output a final JSON with fields:
{"diagnoses":[…], "confidences":[…]}.
"""
# Build enhanced Gradio interface
with gr.Blocks(theme="default") as demo:
    gr.Markdown("""
    # 🏥 Medical Symptom to ICD-10 Code Assistant

    ## About
    This application is part of the Agents+MCP Hackathon. It helps medical professionals
    and patients understand potential diagnoses based on described symptoms.

    ### How it works:
    1. Either click the record button and describe your symptoms, or type them into the textbox
    2. The AI will analyze your description and suggest possible diagnoses
    3. Answer follow-up questions to refine the diagnosis
    """)
    with gr.Row():
        with gr.Column(scale=2):
            # Text input above the microphone
            with gr.Row():
                text_input = gr.Textbox(
                    label="Type your symptoms",
                    placeholder="Or type your symptoms here...",
                    lines=3
                )
                submit_btn = gr.Button("Submit", variant="primary")

            # Microphone row
            with gr.Row():
                microphone = gr.Audio(
                    sources=["microphone"],
                    streaming=True,
                    type="numpy",
                    label="Describe your symptoms"
                )
                transcript_box = gr.Textbox(
                    label="Transcribed Text",
                    interactive=False,
                    show_label=True
                )

            clear_btn = gr.Button("Clear Chat", variant="secondary")
            chatbot = gr.Chatbot(
                label="Medical Consultation",
                height=500,
                container=True,
                type="messages"  # Matches the {"role": ..., "content": ...} dicts built below
            )
        with gr.Column(scale=1):
            with gr.Accordion("Enter an API Key to give it more power!", open=False):
                api_key = gr.Textbox(
                    label="OpenAI API Key (optional)",
                    type="password",
                    placeholder="sk-..."
                )
                with gr.Row():
                    with gr.Column():
                        modal_key = gr.Textbox(
                            label="Modal Labs API Key",
                            type="password",
                            placeholder="mk-..."
                        )
                        anthropic_key = gr.Textbox(
                            label="Anthropic API Key",
                            type="password",
                            placeholder="sk-ant-..."
                        )
                        mistral_key = gr.Textbox(
                            label="MistralAI API Key",
                            type="password",
                            placeholder="..."
                        )
                    with gr.Column():
                        nebius_key = gr.Textbox(
                            label="Nebius API Key",
                            type="password",
                            placeholder="..."
                        )
                        hyperbolic_key = gr.Textbox(
                            label="Hyperbolic Labs API Key",
                            type="password",
                            placeholder="hyp-..."
                        )
                        sambanova_key = gr.Textbox(
                            label="SambaNova API Key",
                            type="password",
                            placeholder="..."
                        )
            with gr.Row():
                model_selector = gr.Dropdown(
                    choices=["OpenAI", "Modal", "Anthropic", "MistralAI", "Nebius", "Hyperbolic", "SambaNova"],
                    value="OpenAI",
                    label="Model Provider"
                )
                temperature = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.7,
                    label="Temperature"
                )
    # Self-promotion at the bottom of the page
    gr.Markdown("""
    ---
    ### 👋 About the Creator
    Hi! I'm Graham Paasch, an experienced technology professional!

    🎥 **Check out my YouTube channel** for more tech content:
    [Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)

    💼 **Looking for a skilled developer?**
    I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)

    ⭐ If you found this tool helpful, please consider:
    - Subscribing to my YouTube channel
    - Connecting on LinkedIn
    - Sharing this tool with others in healthcare tech
    """)
    # Event handlers
    microphone.stream(
        fn=enhanced_process_speech,
        inputs=[microphone, chatbot, api_key, model_selector, temperature],
        outputs=chatbot,
        show_progress="hidden",
        api_name=False,
        queue=True  # Enable queuing for better stream handling
    )
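    # Note: with streaming=True and type="numpy", Gradio delivers each microphone
    # chunk as a (sample_rate, numpy_array) tuple, which the handlers below unpack.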
    def process_audio(audio_array, sample_rate):
        """Pre-process audio for Whisper."""
        # Downmix multi-channel audio to mono
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        # Convert to tensor for resampling
        audio_tensor = torch.FloatTensor(audio_array)
        # Resample to 16kHz if needed
        if sample_rate != 16000:
            resampler = T.Resample(sample_rate, 16000)
            audio_tensor = resampler(audio_tensor)
        # Normalize to [-1, 1], guarding against division by zero on silent chunks
        peak = torch.max(torch.abs(audio_tensor))
        if peak > 0:
            audio_tensor = audio_tensor / peak
        # Convert back to numpy and return in the format the ASR pipeline expects
        return {
            "raw": audio_tensor.numpy(),  # Key must be "raw"
            "sampling_rate": 16000        # Key must be "sampling_rate"
        }
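    # Example (hypothetical input; assumes numpy imported as np): one second of
    # 44.1 kHz mono audio comes back resampled and peak-normalized:
    #   features = process_audio(np.random.randn(44100), 44100)
    #   features["raw"].shape       -> (16000,)
    #   features["sampling_rate"]   -> 16000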
    # Live transcription handler
    def update_live_transcription(audio):
        """Real-time transcription updates."""
        if not audio or not isinstance(audio, tuple):
            return ""
        try:
            sample_rate, audio_array = audio
            features = process_audio(audio_array, sample_rate)
            asr = get_asr_pipeline()
            result = asr(features)
            return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
        except Exception as e:
            print(f"Transcription error: {e}")
            return ""

    microphone.stream(
        fn=update_live_transcription,
        inputs=[microphone],
        outputs=transcript_box,
        show_progress="hidden",
        queue=True
    )
    clear_btn.click(
        fn=lambda: (None, "", ""),
        outputs=[chatbot, transcript_box, text_input],
        queue=False
    )
    def cleanup_memory():
        """Release unused Python and CUDA memory after a request."""
        import gc
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    def process_text_input(text, history):
        """Process typed symptoms with the locally configured LLM."""
        print("process_text_input received:", text)
        if not text:
            return history, ""  # Return a tuple so the input box is also cleared

        # Query the configured LLM for structured output
        prompt = f"""Given these symptoms: '{text}'

Please provide:
1. Most likely ICD-10 codes
2. Confidence levels for each diagnosis
3. Key follow-up questions

Format as JSON with diagnoses, confidences, and follow_up fields."""
        response = llm.complete(prompt)

        try:
            # Try to parse as JSON first
            result = json.loads(response.text)
        except json.JSONDecodeError:
            # If not JSON, wrap the raw text in our expected format
            result = {
                "diagnoses": [],
                "confidences": [],
                "follow_up": str(response.text)[:1000]  # Limit response length
            }

        new_history = history + [
            {"role": "user", "content": text},
            {"role": "assistant", "content": format_response_for_user(result)}
        ]
        return new_history, ""  # Empty string clears the input box
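    # Expected parse result (illustrative values; only the field names come from
    # the prompt above):
    #   {"diagnoses": ["R05.9"], "confidences": [0.65],
    #    "follow_up": "Is the cough dry or productive?"}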
    # Submit button handler; .success (rather than .then) only runs the cleanup
    # step when process_text_input finished without raising
    submit_btn.click(
        fn=process_text_input,
        inputs=[text_input, chatbot],
        outputs=[chatbot, text_input],
        queue=True
    ).success(
        fn=cleanup_memory,
        inputs=None,
        outputs=None,
        queue=False
    )
| if __name__ == "__main__": | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=True, # Enable sharing via Gradio's temporary URLs | |
| show_api=True # Shows the API documentation | |
| ) | |
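# With show_api=True the endpoints are documented at /?view=api once the app is
# running. A minimal client sketch (assumes the separate gradio_client package
# and a locally running server):
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   print(client.view_api())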