WissMah commited on
Commit
1fb619d
·
verified ·
1 Parent(s): 08fbc2b

Update stroke-flask-docker/app.py

Browse files
Files changed (1) hide show
  1. stroke-flask-docker/app.py +95 -94
stroke-flask-docker/app.py CHANGED
@@ -1,94 +1,95 @@
1
- from flask import Flask, render_template, request, jsonify
2
- import joblib
3
- import numpy as np
4
- import os
5
-
6
- APP_PORT = int(os.getenv("PORT", "8080"))
7
-
8
- app = Flask(__name__)
9
-
10
- MODEL_PATH = os.getenv("MODEL_PATH", "model/stroke_pipeline.joblib")
11
-
12
- # Load model pipeline at startup
13
- try:
14
- pipeline = joblib.load(MODEL_PATH)
15
- except Exception as e:
16
- raise RuntimeError(f"Failed to load model at {MODEL_PATH}: {e}")
17
-
18
- FEATURE_ORDER = [
19
- "gender",
20
- "age",
21
- "hypertension",
22
- "heart_disease",
23
- "ever_married",
24
- "work_type",
25
- "Residence_type",
26
- "avg_glucose_level",
27
- "bmi",
28
- "smoking_status",
29
- ]
30
-
31
- # Simple healthcheck
32
- @app.route("/health", methods=["GET"])
33
- def health():
34
- return jsonify({"status": "ok"}), 200
35
-
36
- @app.route("/", methods=["GET"])
37
- def index():
38
- # Provide default values to make testing easy
39
- defaults = {
40
- "gender": "Female",
41
- "age": 45,
42
- "hypertension": 0,
43
- "heart_disease": 0,
44
- "ever_married": "Yes",
45
- "work_type": "Private",
46
- "Residence_type": "Urban",
47
- "avg_glucose_level": 95.0,
48
- "bmi": 28.0,
49
- "smoking_status": "never smoked",
50
- }
51
- return render_template("index.html", defaults=defaults)
52
-
53
- @app.route("/predict", methods=["POST"])
54
- def predict():
55
- try:
56
- # Read input either from JSON (API) or form (UI)
57
- if request.is_json:
58
- payload = request.get_json()
59
- else:
60
- payload = request.form.to_dict()
61
-
62
- # Ensure types
63
- # Map numeric fields
64
- numeric_fields = ["age", "avg_glucose_level", "bmi"]
65
- int_fields = ["hypertension", "heart_disease"]
66
-
67
- for k in numeric_fields:
68
- if k in payload:
69
- payload[k] = float(payload[k])
70
- for k in int_fields:
71
- if k in payload:
72
- payload[k] = int(payload[k])
73
-
74
- # Build row in fixed feature order
75
- row = [[payload.get(f, None) for f in FEATURE_ORDER]]
76
-
77
- # Predict proba (stroke = 1)
78
- prob = float(pipeline.predict_proba(row)[0][1])
79
- pred = int(prob >= 0.5)
80
-
81
- result = {"stroke_probability": prob, "predicted_label": pred}
82
- if request.is_json:
83
- return jsonify(result)
84
- else:
85
- return render_template("index.html", result=result, defaults=payload)
86
- except Exception as e:
87
- msg = {"error": str(e)}
88
- if request.is_json:
89
- return jsonify(msg), 400
90
- else:
91
- return render_template("index.html", error=str(e), defaults=request.form), 400
92
-
93
- if __name__ == "__main__":
94
- app.run(host="0.0.0.0", port=APP_PORT, debug=False)
 
 
1
+ from flask import Flask, render_template, request, jsonify
2
+ import joblib
3
+ import numpy as np
4
+ import os
5
+
6
+ APP_PORT = int(os.getenv("PORT", "8080"))
7
+
8
+ app = Flask(__name__)
9
+
10
+ MODEL_PATH = os.getenv("MODEL_PATH", "model/stroke_pipeline.joblib")
11
+
12
+ # Load model pipeline at startup
13
+ try:
14
+ pipeline = joblib.load(MODEL_PATH)
15
+ except Exception as e:
16
+ raise RuntimeError(f"Failed to load model at {MODEL_PATH}: {e}")
17
+
18
+ FEATURE_ORDER = [
19
+ "gender",
20
+ "age",
21
+ "hypertension",
22
+ "heart_disease",
23
+ "ever_married",
24
+ "work_type",
25
+ "Residence_type",
26
+ "avg_glucose_level",
27
+ "bmi",
28
+ "smoking_status",
29
+ ]
30
+
31
+ # Simple healthcheck
32
+ @app.route("/health", methods=["GET"])
33
+ def health():
34
+ return jsonify({"status": "ok"}), 200
35
+
36
+ @app.route("/", methods=["GET"])
37
+ def index():
38
+ # Provide default values to make testing easy
39
+ defaults = {
40
+ "gender": "Female",
41
+ "age": 45,
42
+ "hypertension": 0,
43
+ "heart_disease": 0,
44
+ "ever_married": "Yes",
45
+ "work_type": "Private",
46
+ "Residence_type": "Urban",
47
+ "avg_glucose_level": 95.0,
48
+ "bmi": 28.0,
49
+ "smoking_status": "never smoked",
50
+ }
51
+ return render_template("index.html", defaults=defaults)
52
+
53
+ @app.route("/predict", methods=["POST"])
54
+ def predict():
55
+ try:
56
+ # Read input either from JSON (API) or form (UI)
57
+ if request.is_json:
58
+ payload = request.get_json()
59
+ else:
60
+ payload = request.form.to_dict()
61
+
62
+ # Ensure types
63
+ # Map numeric fields
64
+ numeric_fields = ["age", "avg_glucose_level", "bmi"]
65
+ int_fields = ["hypertension", "heart_disease"]
66
+
67
+ for k in numeric_fields:
68
+ if k in payload:
69
+ payload[k] = float(payload[k])
70
+ for k in int_fields:
71
+ if k in payload:
72
+ payload[k] = int(payload[k])
73
+
74
+ # Build row in fixed feature order
75
+ X = pd.DataFrame([{f: payload.get(f, None) for f in FEATURE_ORDER}])[FEATURE_ORDER]
76
+
77
+
78
+ # Predict proba (stroke = 1)
79
+ prob = float(pipeline.predict_proba(X)[0][1])
80
+ pred = int(prob >= 0.5)
81
+
82
+ result = {"stroke_probability": prob, "predicted_label": pred}
83
+ if request.is_json:
84
+ return jsonify(result)
85
+ else:
86
+ return render_template("index.html", result=result, defaults=payload)
87
+ except Exception as e:
88
+ msg = {"error": str(e)}
89
+ if request.is_json:
90
+ return jsonify(msg), 400
91
+ else:
92
+ return render_template("index.html", error=str(e), defaults=request.form), 400
93
+
94
+ if __name__ == "__main__":
95
+ app.run(host="0.0.0.0", port=APP_PORT, debug=False)