MedCodeMCP / app.py
import json

import gradio as gr
import torch
import torchaudio.transforms as T

from utils.model_configuration_utils import select_best_model, ensure_model
from utils.voice_input_utils import enhanced_process_speech, format_response_for_user, get_asr_pipeline
from services.llm import build_llm
from services.embeddings import configure_embeddings
from services.indexing import build_symptom_index
# 1) Model selection & download
MODEL_NAME, REPO_ID = select_best_model()
model_path = ensure_model()
print(f"Using model: {MODEL_NAME} from {REPO_ID}")
print(f"Model path: {model_path}")
if torch.cuda.is_available():
    # Only query GPU properties when CUDA is present, so CPU-only hosts don't crash
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB")
print(f"Model requirements: {MODEL_NAME} requires at least 4GB VRAM and 8GB RAM.")
print(f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")
# 2) LLM and embeddings config
llm = build_llm(model_path)
configure_embeddings()
print(f"LLM configured with model: {model_path}")
print("Embeddings configured successfully.")
# 3) Index setup
symptom_index = build_symptom_index()
print("Symptom index built successfully.")
print("Ready for queries.")
# --- System prompt ---
SYSTEM_PROMPT = """
You are a medical assistant helping a user narrow down to the most likely ICD-10 code.
At each turn, EITHER ask one focused clarifying question (e.g. "Is your cough dry or productive?")
or, if you have enough info, output a final JSON with fields:
{"diagnoses":[…], "confidences":[…]}.
"""
# Build enhanced Gradio interface
with gr.Blocks(theme="default") as demo:
gr.Markdown("""
# 🏥 Medical Symptom to ICD-10 Code Assistant
## About
This application is part of the Agents+MCP Hackathon. It helps medical professionals
and patients understand potential diagnoses based on described symptoms.
### How it works:
1. Either click the record button and describe your symptoms or type them into the textbox
2. The AI will analyze your description and suggest possible diagnoses
3. Answer follow-up questions to refine the diagnosis
""")
    with gr.Row():
        with gr.Column(scale=2):
            # Text input above the microphone
            with gr.Row():
                text_input = gr.Textbox(
                    label="Type your symptoms",
                    placeholder="Or type your symptoms here...",
                    lines=3
                )
                submit_btn = gr.Button("Submit", variant="primary")

            # Microphone row
            with gr.Row():
                microphone = gr.Audio(
                    sources=["microphone"],
                    streaming=True,
                    type="numpy",
                    label="Describe your symptoms"
                )
                transcript_box = gr.Textbox(
                    label="Transcribed Text",
                    interactive=False,
                    show_label=True
                )

            clear_btn = gr.Button("Clear Chat", variant="secondary")
            chatbot = gr.Chatbot(
                label="Medical Consultation",
                height=500,
                container=True,
                type="messages"  # Matches the {"role": ..., "content": ...} format used below
            )
        with gr.Column(scale=1):
            with gr.Accordion("Enter an API Key to give it more power!", open=False):
                api_key = gr.Textbox(
                    label="OpenAI API Key (optional)",
                    type="password",
                    placeholder="sk-..."
                )
                with gr.Row():
                    with gr.Column():
                        modal_key = gr.Textbox(
                            label="Modal Labs API Key",
                            type="password",
                            placeholder="mk-..."
                        )
                        anthropic_key = gr.Textbox(
                            label="Anthropic API Key",
                            type="password",
                            placeholder="sk-ant-..."
                        )
                        mistral_key = gr.Textbox(
                            label="MistralAI API Key",
                            type="password",
                            placeholder="..."
                        )
                    with gr.Column():
                        nebius_key = gr.Textbox(
                            label="Nebius API Key",
                            type="password",
                            placeholder="..."
                        )
                        hyperbolic_key = gr.Textbox(
                            label="Hyperbolic Labs API Key",
                            type="password",
                            placeholder="hyp-..."
                        )
                        sambanova_key = gr.Textbox(
                            label="SambaNova API Key",
                            type="password",
                            placeholder="..."
                        )
            with gr.Row():
                model_selector = gr.Dropdown(
                    choices=["OpenAI", "Modal", "Anthropic", "MistralAI", "Nebius", "Hyperbolic", "SambaNova"],
                    value="OpenAI",
                    label="Model Provider"
                )
                temperature = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.7,
                    label="Temperature"
                )
    # Self promotion at bottom of page
    gr.Markdown("""
    ---
    ### 👋 About the Creator

    Hi! I'm Graham Paasch, an experienced technology professional!

    🎥 **Check out my YouTube channel** for more tech content:
    [Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)

    💼 **Looking for a skilled developer?**
    I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)

    ⭐ If you found this tool helpful, please consider:
    - Subscribing to my YouTube channel
    - Connecting on LinkedIn
    - Sharing this tool with others in healthcare tech
    """)
    # Event handlers (the single clear handler below resets chat, transcript, and text input)
    microphone.stream(
        fn=enhanced_process_speech,
        inputs=[microphone, chatbot, api_key, model_selector, temperature],
        outputs=chatbot,
        show_progress="hidden",
        api_name=False,
        queue=True  # Enable queuing for better stream handling
    )
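    # With streaming=True, Gradio fires this event repeatedly with successive
    # audio chunks while the user records, so handlers must tolerate partial input.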
    def process_audio(audio_array, sample_rate):
        """Pre-process audio for Whisper."""
        # Downmix multi-channel audio to mono
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        # Convert to tensor for resampling
        audio_tensor = torch.FloatTensor(audio_array)
        # Resample to 16kHz if needed
        if sample_rate != 16000:
            resampler = T.Resample(sample_rate, 16000)
            audio_tensor = resampler(audio_tensor)
        # Normalize, guarding against division by zero on silent input
        peak = torch.max(torch.abs(audio_tensor))
        if peak > 0:
            audio_tensor = audio_tensor / peak
        # Convert back to numpy array and return in the format Whisper expects
        return {
            "raw": audio_tensor.numpy(),  # Key must be "raw"
            "sampling_rate": 16000        # Key must be "sampling_rate"
        }
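    # Example (hypothetical shapes): Gradio delivers (48000, int16 ndarray of
    # shape (n, 2)); process_audio() downmixes, resamples, and normalizes this
    # into {"raw": float32 ndarray at 16 kHz, "sampling_rate": 16000}.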
    # Live transcription handler
    def update_live_transcription(audio):
        """Real-time transcription updates."""
        if not audio or not isinstance(audio, tuple):
            return ""
        try:
            sample_rate, audio_array = audio
            features = process_audio(audio_array, sample_rate)
            asr = get_asr_pipeline()
            result = asr(features)
            return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
        except Exception as e:
            print(f"Transcription error: {e}")
            return ""
    microphone.stream(
        fn=update_live_transcription,
        inputs=[microphone],
        outputs=transcript_box,
        show_progress="hidden",
        queue=True
    )

    clear_btn.click(
        fn=lambda: (None, "", ""),
        outputs=[chatbot, transcript_box, text_input],
        queue=False
    )
    def cleanup_memory():
        """Release unused memory (placeholder for future memory management)."""
        import gc
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    def process_text_input(text, history):
        """Process text input with memory management."""
        print("process_text_input received:", text)
        if not text:
            return history, ""  # Return a tuple so the input box is cleared

        # Process the symptoms using the configured LLM
        prompt = f"""Given these symptoms: '{text}'
Please provide:
1. Most likely ICD-10 codes
2. Confidence levels for each diagnosis
3. Key follow-up questions
Format as JSON with diagnoses, confidences, and follow_up fields."""
        response = llm.complete(prompt)

        try:
            # Try to parse the LLM output as JSON first
            result = json.loads(response.text)
        except json.JSONDecodeError:
            # If not valid JSON, wrap the raw text in our response format
            result = {
                "diagnoses": [],
                "confidences": [],
                "follow_up": str(response.text)[:1000]  # Limit response length
            }

        new_history = history + [
            {"role": "user", "content": text},
            {"role": "assistant", "content": format_response_for_user(result)}
        ]
        return new_history, ""  # Empty string clears the input box
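    # Well-formed LLM output looks like (illustrative values only):
    #   {"diagnoses": ["J20.9"], "confidences": [0.72],
    #    "follow_up": "Is your cough dry or productive?"}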
    # Submit button handler
    submit_btn.click(
        fn=process_text_input,
        inputs=[text_input, chatbot],
        outputs=[chatbot, text_input],
        queue=True
    ).success(  # Run cleanup only after a successful turn
        fn=cleanup_memory,
        inputs=None,
        outputs=None,
        queue=False
    )
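    # Gradio's .success() (vs .then()) only fires when the upstream handler
    # finishes without raising, so failed turns skip the extra GC pass.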
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,    # Enable sharing via Gradio's temporary URLs
        show_api=True  # Show the API documentation
    )
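
# A minimal client sketch (hypothetical usage; gradio_client is a separate
# install, and the exposed endpoints should be discovered via view_api()):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860/")
#   print(client.view_api())  # lists the endpoints published by show_api=True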