Spaces:
Running
Running
| import joblib | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torchvision.transforms as transforms | |
# The 10 yes/no condition questions shown to the user. predict_chest() requires
# exactly 10 binary `conditions` values.
# NOTE(review): presumably one answer per question, in this order, matching the
# feature order chest_model was trained on — confirm before reordering, adding
# or removing questions.
condition_questions = [
    "Do you have a cough?",
    "Do you feel shortness of breath?",
    "Are you experiencing chest pain?",
    "Do you smoke?",
    "Do you have a fever?",
    "Do you have fatigue?",
    "Have you had a recent respiratory infection?",
    "Do you have a family history of lung issues?",
    "Do you feel wheezing or noisy breathing?",
    "Have you been exposed to pollution or chemicals recently?",
]
# Tabular classifier used by predict_chest().
# joblib.load unpickles the file, which can execute arbitrary code: only ever
# point this at a trusted, locally shipped artifact.
# The path is resolved against the current working directory, not this module.
# NOTE(review): confirm the app is always launched from the folder that
# contains chest_model.joblib.
_TABULAR_MODEL_FILE = "chest_model.joblib"
chest_model = joblib.load(_TABULAR_MODEL_FILE)

# TODO: image inference is not wired up yet. To enable it, load your CNN here
# (replace with your own model class / loading code), e.g.:
#     cnn_model = torch.load("cnn_model.pth", map_location=torch.device("cpu"))
#     cnn_model.eval()
# NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True), which
# refuses a fully pickled nn.Module; saving/loading a state_dict avoids this.
| # Encoding helpers | |
def encode_gender(gender):
    """Return the integer feature code for a gender label.

    "Male" maps to 0; every other value (including "Female") maps to 1.
    The comparison is exact and case-sensitive, so e.g. "male" also maps to 1.
    """
    if gender == "Male":
        return 0
    return 1
def encode_view_position(position):
    """Return the integer feature code for a view-position label.

    "PA" maps to 0; every other value (including "AP") maps to 1.
    The comparison is exact and case-sensitive, so e.g. "pa" also maps to 1.
    """
    if position == "PA":
        return 0
    return 1
| # Main prediction function | |
def predict_chest(age, gender, view_position, conditions, uploaded_image=None):
    """
    Predicts chest disease from tabular data or uploaded image.

    Parameters:
    - age: int
    - gender: 'Male' or 'Female' (any value other than 'Male' is encoded as 1)
    - view_position: 'PA' or 'AP' (any value other than 'PA' is encoded as 1)
    - conditions: 10 binary values (0 or 1); a list, tuple, NumPy array or any
      other iterable is accepted
    - uploaded_image: image file from Streamlit uploader (optional). When given,
      the image path is used and the tabular inputs are ignored.

    Returns:
    - prediction string

    Raises:
    - ValueError: if `conditions` does not contain exactly 10 values
      (tabular path only).
    """
    # ===== IMAGE-BASED PREDICTION (optional) =====
    if uploaded_image is not None:
        input_tensor = _preprocess_image(uploaded_image)  # noqa: F841 - used once the CNN is wired up
        # NOTE(review): placeholder — the result below does NOT depend on the
        # image and always reports disease. Replace with real CNN inference:
        # output = cnn_model(input_tensor)
        # predicted = torch.argmax(output, dim=1).item()
        # return "Chest Disease Detected" if predicted == 1 else "No Chest Disease Detected"
        return "Chest Disease Detected (image-based)"

    # ===== TABULAR PREDICTION =====
    # Materialize first so tuples / arrays / generators behave like a list:
    # `list + tuple` raises TypeError and `list + ndarray` tries to broadcast.
    conditions = list(conditions)
    if len(conditions) != 10:
        raise ValueError("Expected 10 binary values for conditions.")

    # Single-row feature matrix: [age, gender code, view code, 10 conditions].
    features = np.array([[
        age,
        encode_gender(gender),
        encode_view_position(view_position),
        *conditions,
    ]])
    prediction = chest_model.predict(features)[0]
    return "Chest Disease Detected" if prediction == 1 else "No Chest Disease Detected"


def _preprocess_image(uploaded_image):
    """Open an uploaded image and turn it into a (1, 3, 224, 224) float tensor."""
    # Image.open is lazy and holds a file handle; the context manager releases
    # it once convert() has read the pixels into a new, independent image.
    with Image.open(uploaded_image) as img:
        rgb_image = img.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # match your CNN input size
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # one mean/std, broadcast over all channels
    ])
    return transform(rgb_image).unsqueeze(0)  # add batch dimension