Sarthak279 commited on
Commit
562c785
·
verified ·
1 Parent(s): 01785b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py CHANGED
@@ -7,6 +7,48 @@ model_name = "Sarthak279/Disease-symptom-prediction"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  # Set model to eval mode and move to device
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  model.to(device)
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
10
+
11
+
12
+ # Access label mapping from config if available
13
+ id2label = model.config.id2label # This is a dict like {0: 'Flu', 1: 'Typhoid', ...}
14
+
15
+ # Set device
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ model.to(device)
18
+ model.eval()
19
+
20
+ # Prediction function
21
+ def predict_disease(patient_note):
22
+ inputs = tokenizer(
23
+ patient_note,
24
+ return_tensors="pt",
25
+ truncation=True,
26
+ padding=True,
27
+ max_length=512
28
+ ).to(device)
29
+
30
+ with torch.no_grad():
31
+ outputs = model(**inputs)
32
+ logits = outputs.logits
33
+ predicted_class = torch.argmax(logits, dim=1).item()
34
+
35
+ # Fetch label from config, or fallback
36
+ predicted_label = id2label.get(predicted_class, "Unknown")
37
+ return predicted_label
38
+
39
+ # Gradio Interface
40
+ demo = gr.Interface(
41
+ fn=predict_disease,
42
+ inputs=gr.Textbox(lines=4, label="📝 Enter Symptoms or Clinical Notes"),
43
+ outputs=gr.Textbox(label="Predicted Disease"),
44
+ title="🩺 Disease Prediction Model",
45
+ description="Predicts likely disease based on symptoms using a fine-tuned model from Hugging Face Hub."
46
+ )
47
+
48
+ if __name__ == "__main__":
49
+ demo.launch()
50
+
51
+
52
  # Set model to eval mode and move to device
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
  model.to(device)