"""Flask web service exposing a stroke-risk prediction pipeline.

Routes:
    GET  /health   -- liveness probe returning ``{"status": "ok"}``.
    GET  /         -- HTML form pre-filled with sample values.
    POST /predict  -- accepts JSON or form data and returns the predicted
                      probability and thresholded label.
"""

import os

from flask import Flask, render_template, request, jsonify
import joblib
import numpy as np
import pandas as pd

APP_PORT = int(os.getenv("PORT", "8080"))

app = Flask(__name__)
MODEL_PATH = os.getenv("MODEL_PATH", "model/stroke_pipeline.joblib")

# Load the model pipeline once at startup so a broken deployment fails fast.
# NOTE(review): joblib.load unpickles the file and can execute arbitrary code;
# MODEL_PATH must only ever point at a trusted artifact.
try:
    pipeline = joblib.load(MODEL_PATH)
except Exception as e:
    raise RuntimeError(f"Failed to load model at {MODEL_PATH}: {e}") from e

# Column names, in the order they are handed to the pipeline.
FEATURE_ORDER = [
    "gender",
    "age",
    "hypertension",
    "heart_disease",
    "ever_married",
    "work_type",
    "Residence_type",
    "avg_glucose_level",
    "bmi",
    "smoking_status",
]

# Fields coerced to float / int before prediction (form posts arrive as str).
NUMERIC_FIELDS = ("age", "avg_glucose_level", "bmi")
INT_FIELDS = ("hypertension", "heart_disease")

# A probability at or above this value yields predicted_label == 1.
DECISION_THRESHOLD = 0.3


@app.route("/health", methods=["GET"])
def health():
    """Simple healthcheck: always answers 200 with ``{"status": "ok"}``."""
    return jsonify({"status": "ok"}), 200


@app.route("/", methods=["GET"])
def index():
    """Render the input form with default values to make testing easy."""
    defaults = {
        "gender": "Female",
        "age": 45,
        "hypertension": 0,
        "heart_disease": 0,
        "ever_married": "Yes",
        "work_type": "Private",
        "Residence_type": "Urban",
        "avg_glucose_level": 95.0,
        "bmi": 28.0,
        "smoking_status": "never smoked",
    }
    return render_template("index.html", defaults=defaults)


def _normalize_types(payload: dict) -> None:
    """Convert the numeric fields present in *payload* to float/int, in place.

    Absent fields are left absent. Raises ``ValueError`` or ``TypeError`` when
    a present value cannot be converted (e.g. an empty string or ``None``).
    """
    for key in NUMERIC_FIELDS:
        if key in payload:
            payload[key] = float(payload[key])
    for key in INT_FIELDS:
        if key in payload:
            payload[key] = int(payload[key])


@app.route("/predict", methods=["POST"])
def predict():
    """Score one record and return the probability and thresholded label.

    Accepts either a JSON object or form-encoded fields keyed by the names in
    ``FEATURE_ORDER``; features that are not supplied are passed to the
    pipeline as ``None``.

    Returns:
        For JSON requests, a JSON body with ``stroke_probability`` and
        ``predicted_label``. For form requests, the rendered ``index.html``
        with the result and the submitted values. Any failure yields HTTP 400
        carrying the error message, in the same format as the request.
    """
    wants_json = request.is_json
    # Bound before the try block so the error handler can always re-render
    # the form, even when reading the request body itself fails.
    payload = {}
    try:
        raw = request.get_json() if wants_json else request.form.to_dict()
        if not isinstance(raw, dict):
            # e.g. a JSON body of `null` or `[1, 2]`.
            raise ValueError("Request body must be a JSON object or form data")

        # Shallow copy: the parsed request body is never mutated.
        payload = dict(raw)
        _normalize_types(payload)

        # ALWAYS send a DataFrame with named columns.
        row = {feature: payload.get(feature) for feature in FEATURE_ORDER}
        X = pd.DataFrame([row])[FEATURE_ORDER]

        # Column 1 of predict_proba is presumably the positive (stroke)
        # class -- confirm against the trained pipeline's classes_.
        prob = float(pipeline.predict_proba(X)[0][1])
        pred = int(prob >= DECISION_THRESHOLD)
        result = {"stroke_probability": prob, "predicted_label": pred}

        if wants_json:
            return jsonify(result)
        return render_template("index.html", result=result, defaults=payload)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log
        # instead of discarding it, then report the failure to the client.
        app.logger.exception("Prediction request failed")
        if wants_json:
            return jsonify({"error": str(e)}), 400
        return (
            render_template("index.html", error=str(e), defaults=payload),
            400,
        )


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=APP_PORT, debug=False)