WissMah committed · Commit fbec100 (verified) · Parent(s): d028f2b

Update model/train_and_save.py

Files changed (1)
  1. model/train_and_save.py +142 -119
model/train_and_save.py CHANGED
@@ -1,119 +1,142 @@
- """
- Train & save a full sklearn Pipeline for stroke prediction.
-
- - If ./data/healthcare-dataset-stroke-data.csv exists, trains on it (matching the notebook structure).
- - Otherwise, trains on a synthetic dataset with the same schema.
- Saves: model/stroke_pipeline.joblib
- """
- from pathlib import Path
- import pandas as pd
- import numpy as np
- import joblib
-
- from sklearn.compose import ColumnTransformer
- from sklearn.preprocessing import OneHotEncoder, StandardScaler
- from sklearn.impute import SimpleImputer
- from sklearn.linear_model import LogisticRegression
- from sklearn.pipeline import Pipeline
- from sklearn.model_selection import train_test_split
- from sklearn.metrics import classification_report, roc_auc_score
-
- DATA_PATH = Path("C:\Users\wissa\Downloads\data\stroke-flask-docker\data\healthcare-dataset-stroke-data.csv")
- OUT_PATH = Path("C:\Users\wissa\Downloads\data\stroke-flask-docker\model/stroke_pipeline.joblib")
- OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
-
- CATEGORICAL = ["gender","ever_married","work_type","Residence_type","smoking_status"]
- NUMERIC = ["age","avg_glucose_level","bmi"]
- BINARY_INT = ["hypertension","heart_disease"]  # keep as numeric ints
-
- def load_real_or_synthetic():
-     if DATA_PATH.exists():
-         df = pd.read_csv(DATA_PATH)
-         # expected columns from the Kaggle stroke dataset
-         must_have = ["gender","age","hypertension","heart_disease","ever_married",
-                      "work_type","Residence_type","avg_glucose_level","bmi",
-                      "smoking_status","stroke"]
-         missing = set(must_have) - set(df.columns)
-         if missing:
-             raise ValueError(f"Dataset is missing columns: {missing}")
-         # drop id if present
-         df = df[[c for c in df.columns if c in must_have]]
-         return df
-     else:
-         # Synthetic data with the right columns
-         rng = np.random.RandomState(42)
-         N = 2000
-         df = pd.DataFrame({
-             "gender": rng.choice(["Male","Female","Other"], size=N, p=[0.49,0.50,0.01]),
-             "age": rng.randint(1, 90, size=N),
-             "hypertension": rng.binomial(1, 0.15, size=N),
-             "heart_disease": rng.binomial(1, 0.08, size=N),
-             "ever_married": rng.choice(["Yes","No"], size=N, p=[0.7,0.3]),
-             "work_type": rng.choice(["Private","Self-employed","Govt_job","children","Never_worked"], size=N, p=[0.6,0.2,0.18,0.01,0.01]),
-             "Residence_type": rng.choice(["Urban","Rural"], size=N, p=[0.55,0.45]),
-             "avg_glucose_level": rng.normal(100, 30, size=N).clip(50, 300),
-             "bmi": rng.normal(28, 6, size=N).clip(10, 60),
-             "smoking_status": rng.choice(["formerly smoked","never smoked","smokes","Unknown"], size=N, p=[0.2,0.6,0.15,0.05]),
-         })
-         # Fabricate a signal for stroke
-         logit = (
-             0.03*df["age"] +
-             0.02*(df["avg_glucose_level"]-100) +
-             0.05*(df["bmi"]-28) +
-             0.8*df["hypertension"] +
-             0.9*df["heart_disease"] +
-             0.3*(df["ever_married"]=="Yes").astype(int)
-         )
-         prob = 1/(1+np.exp(- (logit-4.0)))  # bias to keep prevalence low
-         df["stroke"] = (rng.rand(len(df)) < prob).astype(int)
-         return df
-
- def build_pipeline():
-     cat_proc = Pipeline(steps=[
-         ("impute", SimpleImputer(strategy="most_frequent")),
-         ("ohe", OneHotEncoder(handle_unknown="ignore"))
-     ])
-     num_proc = Pipeline(steps=[
-         ("impute", SimpleImputer(strategy="median")),
-         ("scale", StandardScaler())
-     ])
-     # Binary int -> treat as numeric (no scaling needed, but fine to scale)
-     bin_proc = Pipeline(steps=[
-         ("impute", SimpleImputer(strategy="most_frequent")),
-         ("scale", StandardScaler(with_mean=False))  # keep sparse-friendly path
-     ])
-
-     pre = ColumnTransformer(transformers=[
-         ("cat", cat_proc, CATEGORICAL),
-         ("num", num_proc, NUMERIC),
-         ("bin", bin_proc, BINARY_INT),
-     ])
-
-     clf = LogisticRegression(max_iter=1000, n_jobs=None)
-     pipeline = Pipeline([("pre", pre), ("clf", clf)])
-     return pipeline
-
- def main():
-     df = load_real_or_synthetic()
-
-     X = df.drop(columns=["stroke"])
-     y = df["stroke"].astype(int)
-
-     X_train, X_test, y_train, y_test = train_test_split(
-         X, y, test_size=0.2, random_state=42, stratify=y
-     )
-
-     pipeline = build_pipeline()
-     pipeline.fit(X_train, y_train)
-
-     y_prob = pipeline.predict_proba(X_test)[:,1]
-     y_pred = (y_prob >= 0.5).astype(int)
-
-     print("AUC:", roc_auc_score(y_test, y_prob))
-     print("Report:\n", classification_report(y_test, y_pred))
-
-     joblib.dump(pipeline, OUT_PATH)
-     print(f"Saved pipeline to {OUT_PATH.resolve()}")
-
- if __name__ == "__main__":
-     main()
+ # Import required libraries
+ from pathlib import Path
+ import pandas as pd
+ import numpy as np
+ import joblib
+
+ # Scikit-learn imports for building ML pipeline
+ from sklearn.compose import ColumnTransformer
+ from sklearn.preprocessing import OneHotEncoder, StandardScaler
+ from sklearn.impute import SimpleImputer
+ from sklearn.linear_model import LogisticRegression
+ from sklearn.pipeline import Pipeline
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import classification_report, roc_auc_score
+
+ DATA_PATH = Path(r"C:\Users\wissa\Downloads\data\stroke-flask-docker\data\healthcare-dataset-stroke-data.csv")  # raw string: keeps Windows backslashes from being escapes
+ OUT_PATH = Path(r"C:\Users\wissa\Downloads\data\stroke-flask-docker\model\stroke_pipeline.joblib")
+ OUT_PATH.parent.mkdir(parents=True, exist_ok=True)  # Make sure output folder exists
+
+ # Define feature groups
+ CATEGORICAL = ["gender","ever_married","work_type","Residence_type","smoking_status"]
+ NUMERIC = ["age","avg_glucose_level","bmi"]
+ BINARY_INT = ["hypertension","heart_disease"]  # Already numeric (0/1), but treated separately
+
+
+ def load_real_or_synthetic():
+     if DATA_PATH.exists():
+         # Load dataset from CSV
+         df = pd.read_csv(DATA_PATH)
+
+         # Define which columns we MUST have
+         must_have = ["gender","age","hypertension","heart_disease","ever_married",
+                      "work_type","Residence_type","avg_glucose_level","bmi",
+                      "smoking_status","stroke"]
+
+         # Check if any required columns are missing
+         missing = set(must_have) - set(df.columns)
+         if missing:
+             raise ValueError(f"Dataset is missing columns: {missing}")
+
+         # Drop extra columns like "id" if present, keep only required ones
+         df = df[[c for c in df.columns if c in must_have]]
+         return df
+     else:
+         # If dataset file is not found, generate synthetic (random but realistic) data
+         rng = np.random.RandomState(42)  # Random seed for reproducibility
+         N = 2000  # number of synthetic rows
+
+         # Generate random values for each feature
+         df = pd.DataFrame({
+             "gender": rng.choice(["Male","Female","Other"], size=N, p=[0.49,0.50,0.01]),
+             "age": rng.randint(1, 90, size=N),
+             "hypertension": rng.binomial(1, 0.15, size=N),   # 15% chance of hypertension
+             "heart_disease": rng.binomial(1, 0.08, size=N),  # 8% chance of heart disease
+             "ever_married": rng.choice(["Yes","No"], size=N, p=[0.7,0.3]),
+             "work_type": rng.choice(["Private","Self-employed","Govt_job","children","Never_worked"],
+                                     size=N, p=[0.6,0.2,0.18,0.01,0.01]),
+             "Residence_type": rng.choice(["Urban","Rural"], size=N, p=[0.55,0.45]),
+             "avg_glucose_level": rng.normal(100, 30, size=N).clip(50, 300),  # realistic range
+             "bmi": rng.normal(28, 6, size=N).clip(10, 60),
+             "smoking_status": rng.choice(["formerly smoked","never smoked","smokes","Unknown"],
+                                          size=N, p=[0.2,0.6,0.15,0.05]),
+         })
+
+         # Define a "logit" (linear combination of features) that influences stroke probability
+         logit = (
+             0.03*df["age"] +
+             0.02*(df["avg_glucose_level"]-100) +
+             0.05*(df["bmi"]-28) +
+             0.8*df["hypertension"] +
+             0.9*df["heart_disease"] +
+             0.3*(df["ever_married"]=="Yes").astype(int)
+         )
+
+         # Convert logit to probability using sigmoid function
+         prob = 1/(1+np.exp(- (logit-4.0)))  # shift so stroke is rare (imbalanced dataset)
+
+         # Assign stroke label (1 = stroke, 0 = no stroke) based on probability
+         df["stroke"] = (rng.rand(len(df)) < prob).astype(int)
+         return df
+
+
+ def build_pipeline():
+     # For categorical features: fill missing with most frequent, then one-hot encode
+     cat_proc = Pipeline(steps=[
+         ("impute", SimpleImputer(strategy="most_frequent")),
+         ("ohe", OneHotEncoder(handle_unknown="ignore"))
+     ])
+
+     # For numeric features: fill missing with median, then scale to mean=0, std=1
+     num_proc = Pipeline(steps=[
+         ("impute", SimpleImputer(strategy="median")),
+         ("scale", StandardScaler())
+     ])
+
+     # For binary integer features: impute, then scale (optional but safe for pipeline)
+     bin_proc = Pipeline(steps=[
+         ("impute", SimpleImputer(strategy="most_frequent")),
+         ("scale", StandardScaler(with_mean=False))  # keep sparse-friendly format
+     ])
+
+     # Combine all processors into one column transformer
+     pre = ColumnTransformer(transformers=[
+         ("cat", cat_proc, CATEGORICAL),
+         ("num", num_proc, NUMERIC),
+         ("bin", bin_proc, BINARY_INT),
+     ])
+
+     # Define classifier (logistic regression for binary classification)
+     clf = LogisticRegression(max_iter=1000, n_jobs=None)
+
+     # Final pipeline: preprocessing → model
+     pipeline = Pipeline([("pre", pre), ("clf", clf)])
+     return pipeline
+
+
+ def main():
+     df = load_real_or_synthetic()
+
+     # Split into features (X) and target (y = stroke)
+     X = df.drop(columns=["stroke"])
+     y = df["stroke"].astype(int)
+
+     X_train, X_test, y_train, y_test = train_test_split(
+         X, y, test_size=0.2, random_state=42, stratify=y
+     )
+
+     pipeline = build_pipeline()
+     pipeline.fit(X_train, y_train)
+
+     y_prob = pipeline.predict_proba(X_test)[:,1]  # probability of stroke
+     y_pred = (y_prob >= 0.3).astype(int)          # classify as 1 if prob ≥ 0.3
+
+     print("AUC:", roc_auc_score(y_test, y_prob))  # area under ROC curve
+     print("Report:\n", classification_report(y_test, y_pred))  # precision/recall/F1
+
+     joblib.dump(pipeline, OUT_PATH)
+     print(f"Saved pipeline to {OUT_PATH.resolve()}")
+
+
+ if __name__ == "__main__":
+     main()
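
For reference, a minimal sketch of how the saved stroke_pipeline.joblib might be consumed downstream (the repo name suggests a Flask app, but that integration is not part of this commit; the relative path and the example values below are assumptions, only the column names are fixed by the training schema above):

# Hypothetical usage sketch, not part of this commit
import joblib
import pandas as pd

pipeline = joblib.load("model/stroke_pipeline.joblib")  # assumed path, relative to repo root

# One patient, with exactly the columns the pipeline was fitted on
patient = pd.DataFrame([{
    "gender": "Female", "age": 67, "hypertension": 1, "heart_disease": 0,
    "ever_married": "Yes", "work_type": "Private", "Residence_type": "Urban",
    "avg_glucose_level": 228.69, "bmi": 36.6, "smoking_status": "formerly smoked",
}])

prob = pipeline.predict_proba(patient)[0, 1]  # probability of stroke
print("stroke risk:", round(prob, 3), "| flagged:", prob >= 0.3)  # same 0.3 cutoff as the script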
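The substantive change in this commit, beyond the added comments and the backslash in OUT_PATH, is that the decision threshold drops from 0.5 to 0.3, trading precision for recall on the rare stroke class. A minimal sketch of how such a cutoff could instead be chosen from the held-out split (the F1 criterion is an assumption, not something the author states; y_test and y_prob are the variables from main() above):

# Hypothetical sketch: choose the cutoff that maximizes F1 instead of hard-coding it
import numpy as np
from sklearn.metrics import precision_recall_curve

precision, recall, thresholds = precision_recall_curve(y_test, y_prob)
# precision and recall have one more entry than thresholds; drop the final point
f1 = 2 * precision[:-1] * recall[:-1] / np.clip(precision[:-1] + recall[:-1], 1e-12, None)
best = thresholds[np.argmax(f1)]
print(f"best F1 threshold: {best:.3f} (F1 = {f1.max():.3f})")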