Enterwar99 commited on
Commit
bc70f50
verified
1 Parent(s): 5919385

Update api_app.py

Browse files
Files changed (1) hide show
  1. api_app.py +161 -153
api_app.py CHANGED
@@ -1,153 +1,161 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from fastapi.responses import JSONResponse
3
- import torch
4
- import torchvision.models as models
5
- import torchvision.transforms as transforms
6
- from PIL import Image
7
- import torch.nn as nn
8
- import io
9
- import numpy as np
10
- import os
11
- from typing import List, Dict, Any
12
-
13
- # Importy dla Grad-CAM
14
- from pytorch_grad_cam import GradCAMPlusPlus
15
- from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
16
- from huggingface_hub import hf_hub_download # Do pobierania modelu z Huba
17
-
18
- # --- Konfiguracja ---
19
- # Upewnij si臋, 偶e te warto艣ci s膮 zgodne z Twoim repozytorium modelu
20
- HF_MODEL_REPO_ID = "Enterwar99/MODEL_MAMMOGRAFII"
21
- MODEL_FILENAME = "best_model.pth" # Nazwa pliku modelu w repozytorium
22
-
23
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
- IMAGENET_MEAN = [0.485, 0.456, 0.406]
25
- IMAGENET_STD = [0.229, 0.224, 0.225]
26
-
27
- # Globalne zmienne dla modelu i transformacji
28
- model_instance = None
29
- transform_pipeline = None
30
-
31
- interpretations_dict = {
32
- 1: "Wynik negatywny - brak zmian nowotworowych",
33
- 2: "Zmiana 艂agodna",
34
- 3: "Prawdopodobnie zmiana 艂agodna - zalecana kontrola",
35
- 4: "Podejrzenie zmiany z艂o艣liwej - zalecana biopsja",
36
- 5: "Wysoka podejrzliwo艣膰 z艂o艣liwo艣ci - wymagana biopsja"
37
- }
38
-
39
- # --- Inicjalizacja modelu ---
40
- def initialize_model():
41
- global model_instance, transform_pipeline
42
- if model_instance is not None:
43
- return
44
-
45
- print(f"Pobieranie modelu {MODEL_FILENAME} z repozytorium {HF_MODEL_REPO_ID}...")
46
- try:
47
- model_pt_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=MODEL_FILENAME)
48
- except Exception as e:
49
- print(f"B艂膮d podczas pobierania modelu z Hugging Face Hub: {e}")
50
- raise RuntimeError(f"Nie mo偶na pobra膰 modelu: {e}")
51
-
52
- print(f"Inicjalizacja architektury modelu ResNet-18...")
53
- model_arch = models.resnet18(weights=None)
54
- num_feats = model_arch.fc.in_features
55
- model_arch.fc = nn.Sequential(
56
- nn.Dropout(0.5),
57
- nn.Linear(num_feats, 5)
58
- )
59
-
60
- print(f"艁adowanie wag modelu z {model_pt_path}...")
61
- model_arch.load_state_dict(torch.load(model_pt_path, map_location=DEVICE))
62
- model_arch.to(DEVICE)
63
- model_arch.eval()
64
- model_instance = model_arch
65
-
66
- transform_pipeline = transforms.Compose([
67
- transforms.Resize((224, 224)),
68
- transforms.ToTensor(),
69
- transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
70
- ])
71
- print(f"Model BI-RADS classifier initialized successfully on device: {DEVICE}")
72
-
73
- # --- Aplikacja FastAPI ---
74
- app = FastAPI(title="BI-RADS Mammography Classification API")
75
-
76
- @app.on_event("startup")
77
- async def startup_event():
78
- """Wywo艂ywane przy starcie aplikacji FastAPI."""
79
- initialize_model()
80
-
81
- @app.post("/predict/", response_model=List[Dict[str, Any]])
82
- async def predict_image(file: UploadFile = File(...)):
83
- """
84
- Endpoint do klasyfikacji obrazu mammograficznego.
85
- Oczekuje pliku obrazu (JPG, PNG).
86
- Zwraca list臋 z wynikami (nawet je艣li tylko jeden obraz).
87
- """
88
- if model_instance is None or transform_pipeline is None:
89
- raise HTTPException(status_code=503, detail="Model nie jest zainicjalizowany. Spr贸buj ponownie za chwil臋.")
90
-
91
- try:
92
- contents = await file.read()
93
- image = Image.open(io.BytesIO(contents)).convert("RGB")
94
- except Exception as e:
95
- raise HTTPException(status_code=400, detail=f"Nie mo偶na odczyta膰 pliku obrazu: {e}")
96
-
97
- # Preprocessing
98
- input_tensor = transform_pipeline(image).unsqueeze(0).to(DEVICE)
99
-
100
- # Inference
101
- with torch.no_grad():
102
- model_outputs = model_instance(input_tensor)
103
-
104
- # Postprocessing
105
- probs = torch.nn.functional.softmax(model_outputs, dim=1)
106
- confidences, predicted_indices = torch.max(probs, 1)
107
-
108
- results = []
109
- for i in range(len(predicted_indices)): # P臋tla na wypadek przysz艂ego batch processingu
110
- birads_category = predicted_indices[i].item() + 1
111
- confidence = confidences[i].item()
112
- interpretation = interpretations_dict.get(birads_category, "Nieznana klasyfikacja")
113
-
114
- all_class_probs_tensor = probs[i].cpu().numpy()
115
- class_probabilities = {str(j+1): float(all_class_probs_tensor[j]) for j in range(len(all_class_probs_tensor))}
116
-
117
- # Generowanie Grad-CAM
118
- grad_cam_map_serialized = None
119
- try:
120
- for param in model_instance.parameters():
121
- param.requires_grad_(True)
122
- model_instance.eval()
123
-
124
- target_layers = [model_instance.layer4[-1]] # Dla ResNet-18
125
- cam_algorithm = GradCAMPlusPlus(model=model_instance, target_layers=target_layers)
126
-
127
- current_input_tensor_for_cam = input_tensor[i].unsqueeze(0).clone().detach().requires_grad_(True)
128
- targets_for_cam = [ClassifierOutputTarget(predicted_indices[i].item())]
129
-
130
- grayscale_cam = cam_algorithm(input_tensor=current_input_tensor_for_cam, targets=targets_for_cam)
131
-
132
- if grayscale_cam is not None:
133
- grad_cam_map_np = grayscale_cam[0, :]
134
- grad_cam_map_serialized = grad_cam_map_np.tolist()
135
-
136
- except Exception as e:
137
- print(f"B艂膮d podczas generowania Grad-CAM w API: {e}")
138
-
139
- results.append({
140
- "birads": birads_category,
141
- "confidence": confidence,
142
- "interpretation": interpretation,
143
- "class_probabilities": class_probabilities,
144
- "grad_cam_map": grad_cam_map_serialized
145
- })
146
-
147
- return JSONResponse(content=results)
148
-
149
- @app.get("/")
150
- async def root():
151
- return {"message": "Witaj w BI-RADS Classification API! U偶yj endpointu /predict/ do wysy艂ania obraz贸w."}
152
-
153
- # Do uruchomienia lokalnie: uvicorn api_app:app --reload
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ import torch
4
+ import torchvision.models as models
5
+ import torchvision.transforms as transforms
6
+ from PIL import Image
7
+ import torch.nn as nn
8
+ import io
9
+ import numpy as np
10
+ import os
11
+ from typing import List, Dict, Any
12
+
13
+ # Importy dla Grad-CAM
14
+ from pytorch_grad_cam import GradCAMPlusPlus
15
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
16
+ from huggingface_hub import hf_hub_download # Do pobierania modelu z Huba
17
+
18
+ # --- Konfiguracja ---
19
+ # Upewnij si臋, 偶e te warto艣ci s膮 zgodne z Twoim repozytorium modelu
20
+ HF_MODEL_REPO_ID = "Enterwar99/MODEL_MAMMOGRAFII"
21
+ MODEL_FILENAME = "best_model.pth" # Nazwa pliku modelu w repozytorium
22
+
23
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
25
+ IMAGENET_STD = [0.229, 0.224, 0.225]
26
+
27
+ # Globalne zmienne dla modelu i transformacji
28
+ model_instance = None
29
+ transform_pipeline = None
30
+
31
+ interpretations_dict = {
32
+ 1: "Wynik negatywny - brak zmian nowotworowych",
33
+ 2: "Zmiana 艂agodna",
34
+ 3: "Prawdopodobnie zmiana 艂agodna - zalecana kontrola",
35
+ 4: "Podejrzenie zmiany z艂o艣liwej - zalecana biopsja",
36
+ 5: "Wysoka podejrzliwo艣膰 z艂o艣liwo艣ci - wymagana biopsja"
37
+ }
38
+
39
+ # --- Inicjalizacja modelu ---
40
+ def initialize_model():
41
+ global model_instance, transform_pipeline
42
+ if model_instance is not None:
43
+ return
44
+
45
+ print(f"Pobieranie modelu {MODEL_FILENAME} z repozytorium {HF_MODEL_REPO_ID}...")
46
+ try:
47
+ # Odczytaj token z sekret贸w, je艣li jest dost臋pny
48
+ # Nazwa zmiennej 艣rodowiskowej musi by膰 taka sama jak nazwa sekretu w ustawieniach Space
49
+ hf_auth_token = os.environ.get("HF_TOKEN_MODEL_READ")
50
+ if hf_auth_token:
51
+ print("U偶ywam tokenu HF_TOKEN_MODEL_READ do pobrania modelu.")
52
+ else:
53
+ print("OSTRZE呕ENIE: Sekret HF_TOKEN_MODEL_READ nie zosta艂 znaleziony. Pr贸ba pobrania modelu bez tokenu (mo偶e si臋 nie uda膰 dla prywatnych repozytori贸w).")
54
+
55
+ model_pt_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=MODEL_FILENAME, token=hf_auth_token)
56
+ except Exception as e:
57
+ print(f"B艂膮d podczas pobierania modelu z Hugging Face Hub: {e}")
58
+ raise RuntimeError(f"Nie mo偶na pobra膰 modelu: {e}")
59
+
60
+ print(f"Inicjalizacja architektury modelu ResNet-18...")
61
+ model_arch = models.resnet18(weights=None)
62
+ num_feats = model_arch.fc.in_features
63
+ model_arch.fc = nn.Sequential(
64
+ nn.Dropout(0.5),
65
+ nn.Linear(num_feats, 5)
66
+ )
67
+
68
+ print(f"艁adowanie wag modelu z {model_pt_path}...")
69
+ model_arch.load_state_dict(torch.load(model_pt_path, map_location=DEVICE))
70
+ model_arch.to(DEVICE)
71
+ model_arch.eval()
72
+ model_instance = model_arch
73
+
74
+ transform_pipeline = transforms.Compose([
75
+ transforms.Resize((224, 224)),
76
+ transforms.ToTensor(),
77
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
78
+ ])
79
+ print(f"Model BI-RADS classifier initialized successfully on device: {DEVICE}")
80
+
81
+ # --- Aplikacja FastAPI ---
82
+ app = FastAPI(title="BI-RADS Mammography Classification API")
83
+
84
+ @app.on_event("startup")
85
+ async def startup_event():
86
+ """Wywo艂ywane przy starcie aplikacji FastAPI."""
87
+ initialize_model()
88
+
89
+ @app.post("/predict/", response_model=List[Dict[str, Any]])
90
+ async def predict_image(file: UploadFile = File(...)):
91
+ """
92
+ Endpoint do klasyfikacji obrazu mammograficznego.
93
+ Oczekuje pliku obrazu (JPG, PNG).
94
+ Zwraca list臋 z wynikami (nawet je艣li tylko jeden obraz).
95
+ """
96
+ if model_instance is None or transform_pipeline is None:
97
+ raise HTTPException(status_code=503, detail="Model nie jest zainicjalizowany. Spr贸buj ponownie za chwil臋.")
98
+
99
+ try:
100
+ contents = await file.read()
101
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
102
+ except Exception as e:
103
+ raise HTTPException(status_code=400, detail=f"Nie mo偶na odczyta膰 pliku obrazu: {e}")
104
+
105
+ # Preprocessing
106
+ input_tensor = transform_pipeline(image).unsqueeze(0).to(DEVICE)
107
+
108
+ # Inference
109
+ with torch.no_grad():
110
+ model_outputs = model_instance(input_tensor)
111
+
112
+ # Postprocessing
113
+ probs = torch.nn.functional.softmax(model_outputs, dim=1)
114
+ confidences, predicted_indices = torch.max(probs, 1)
115
+
116
+ results = []
117
+ for i in range(len(predicted_indices)): # P臋tla na wypadek przysz艂ego batch processingu
118
+ birads_category = predicted_indices[i].item() + 1
119
+ confidence = confidences[i].item()
120
+ interpretation = interpretations_dict.get(birads_category, "Nieznana klasyfikacja")
121
+
122
+ all_class_probs_tensor = probs[i].cpu().numpy()
123
+ class_probabilities = {str(j+1): float(all_class_probs_tensor[j]) for j in range(len(all_class_probs_tensor))}
124
+
125
+ # Generowanie Grad-CAM
126
+ grad_cam_map_serialized = None
127
+ try:
128
+ for param in model_instance.parameters():
129
+ param.requires_grad_(True)
130
+ model_instance.eval()
131
+
132
+ target_layers = [model_instance.layer4[-1]] # Dla ResNet-18
133
+ cam_algorithm = GradCAMPlusPlus(model=model_instance, target_layers=target_layers)
134
+
135
+ current_input_tensor_for_cam = input_tensor[i].unsqueeze(0).clone().detach().requires_grad_(True)
136
+ targets_for_cam = [ClassifierOutputTarget(predicted_indices[i].item())]
137
+
138
+ grayscale_cam = cam_algorithm(input_tensor=current_input_tensor_for_cam, targets=targets_for_cam)
139
+
140
+ if grayscale_cam is not None:
141
+ grad_cam_map_np = grayscale_cam[0, :]
142
+ grad_cam_map_serialized = grad_cam_map_np.tolist()
143
+
144
+ except Exception as e:
145
+ print(f"B艂膮d podczas generowania Grad-CAM w API: {e}")
146
+
147
+ results.append({
148
+ "birads": birads_category,
149
+ "confidence": confidence,
150
+ "interpretation": interpretation,
151
+ "class_probabilities": class_probabilities,
152
+ "grad_cam_map": grad_cam_map_serialized
153
+ })
154
+
155
+ return JSONResponse(content=results)
156
+
157
+ @app.get("/")
158
+ async def root():
159
+ return {"message": "Witaj w BI-RADS Classification API! U偶yj endpointu /predict/ do wysy艂ania obraz贸w."}
160
+
161
+ # Do uruchomienia lokalnie: uvicorn api_app:app --reload