Enterwar99 committed
Commit e915833 (verified)
1 parent: 35bb529

Upload 3 files

Files changed (3):
  1. DOCKERFILE +21 -0
  2. api_app.py +153 -0
  3. requirements_api_space.txt +9 -0
DOCKERFILE ADDED
@@ -0,0 +1,21 @@
+ FROM python:3.10-slim
+
+ WORKDIR /app
+
+ COPY requirements_api_space.txt .
+ RUN pip install --no-cache-dir -r requirements_api_space.txt
+
+ COPY api_app.py .
+ # Copy any other required files here, if there are any
+
+ # Port the FastAPI app will listen on. Uvicorn defaults to 8000, but Hugging Face
+ # Spaces expects the app on port 7860 for automatic detection, so we set it
+ # explicitly here and in the CMD below (see the Spaces FastAPI template docs).
+ ENV PORT=7860
+ EXPOSE 7860
+
+ CMD ["uvicorn", "api_app:app", "--host", "0.0.0.0", "--port", "7860"]
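For a quick local check of this image, a minimal sketch follows. It assumes you have built and started the container yourself (for example with "docker build -t birads-api ." and "docker run -p 7860:7860 birads-api", where the birads-api tag is just an illustrative name) and that the app from api_app.py below is serving on port 7860; it calls the root endpoint and prints the welcome message.

import json
import urllib.request

# Placeholder base URL: the locally published container port.
# For a deployed Space, replace it with the Space's public URL.
BASE_URL = "http://localhost:7860"

# GET / returns {"message": "..."} as defined in api_app.py
with urllib.request.urlopen(BASE_URL + "/") as response:
    payload = json.loads(response.read().decode("utf-8"))

print(payload["message"])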
api_app.py ADDED
@@ -0,0 +1,153 @@
+ from fastapi import FastAPI, File, UploadFile, HTTPException
+ from fastapi.responses import JSONResponse
+ import torch
+ import torchvision.models as models
+ import torchvision.transforms as transforms
+ from PIL import Image
+ import torch.nn as nn
+ import io
+ import numpy as np
+ import os
+ from typing import List, Dict, Any
+
+ # Imports for Grad-CAM
+ from pytorch_grad_cam import GradCAMPlusPlus
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+ from huggingface_hub import hf_hub_download  # For downloading the model from the Hub
+
+ # --- Configuration ---
+ # Make sure these values match your model repository
+ HF_MODEL_REPO_ID = "Enterwar99/MODEL_MAMMOGRAFII"
+ MODEL_FILENAME = "best_model.pth"  # Name of the model file in the repository
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
+ IMAGENET_STD = [0.229, 0.224, 0.225]
+
+ # Global variables for the model and the transform pipeline
+ model_instance = None
+ transform_pipeline = None
+
+ interpretations_dict = {
+     1: "Negative result - no cancerous lesions",
+     2: "Benign lesion",
+     3: "Probably benign lesion - follow-up recommended",
+     4: "Suspicious for malignancy - biopsy recommended",
+     5: "Highly suggestive of malignancy - biopsy required"
+ }
+
+ # --- Model initialization ---
+ def initialize_model():
+     global model_instance, transform_pipeline
+     if model_instance is not None:
+         return
+
+     print(f"Downloading model {MODEL_FILENAME} from repository {HF_MODEL_REPO_ID}...")
+     try:
+         model_pt_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=MODEL_FILENAME)
+     except Exception as e:
+         print(f"Error downloading the model from the Hugging Face Hub: {e}")
+         raise RuntimeError(f"Could not download the model: {e}")
+
+     print("Initializing the ResNet-18 model architecture...")
+     model_arch = models.resnet18(weights=None)
+     num_feats = model_arch.fc.in_features
+     model_arch.fc = nn.Sequential(
+         nn.Dropout(0.5),
+         nn.Linear(num_feats, 5)
+     )
+
+     print(f"Loading model weights from {model_pt_path}...")
+     model_arch.load_state_dict(torch.load(model_pt_path, map_location=DEVICE))
+     model_arch.to(DEVICE)
+     model_arch.eval()
+     model_instance = model_arch
+
+     transform_pipeline = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
+     ])
+     print(f"BI-RADS classifier model initialized successfully on device: {DEVICE}")
+
+ # --- FastAPI application ---
+ app = FastAPI(title="BI-RADS Mammography Classification API")
+
+ @app.on_event("startup")
+ async def startup_event():
+     """Called when the FastAPI application starts."""
+     initialize_model()
+
+ @app.post("/predict/", response_model=List[Dict[str, Any]])
+ async def predict_image(file: UploadFile = File(...)):
+     """
+     Endpoint for classifying a mammography image.
+     Expects an image file (JPG, PNG).
+     Returns a list of results (even if only one image was sent).
+     """
+     if model_instance is None or transform_pipeline is None:
+         raise HTTPException(status_code=503, detail="Model is not initialized. Please try again in a moment.")
+
+     try:
+         contents = await file.read()
+         image = Image.open(io.BytesIO(contents)).convert("RGB")
+     except Exception as e:
+         raise HTTPException(status_code=400, detail=f"Could not read the image file: {e}")
+
+     # Preprocessing
+     input_tensor = transform_pipeline(image).unsqueeze(0).to(DEVICE)
+
+     # Inference
+     with torch.no_grad():
+         model_outputs = model_instance(input_tensor)
+
+     # Postprocessing
+     probs = torch.nn.functional.softmax(model_outputs, dim=1)
+     confidences, predicted_indices = torch.max(probs, 1)
+
+     results = []
+     for i in range(len(predicted_indices)):  # Loop in case of future batch processing
+         birads_category = predicted_indices[i].item() + 1
+         confidence = confidences[i].item()
+         interpretation = interpretations_dict.get(birads_category, "Unknown classification")
+
+         all_class_probs_tensor = probs[i].cpu().numpy()
+         class_probabilities = {str(j + 1): float(all_class_probs_tensor[j]) for j in range(len(all_class_probs_tensor))}
+
+         # Grad-CAM generation
+         grad_cam_map_serialized = None
+         try:
+             for param in model_instance.parameters():
+                 param.requires_grad_(True)
+             model_instance.eval()
+
+             target_layers = [model_instance.layer4[-1]]  # Last residual block of ResNet-18
+             cam_algorithm = GradCAMPlusPlus(model=model_instance, target_layers=target_layers)
+
+             current_input_tensor_for_cam = input_tensor[i].unsqueeze(0).clone().detach().requires_grad_(True)
+             targets_for_cam = [ClassifierOutputTarget(predicted_indices[i].item())]
+
+             grayscale_cam = cam_algorithm(input_tensor=current_input_tensor_for_cam, targets=targets_for_cam)
+
+             if grayscale_cam is not None:
+                 grad_cam_map_np = grayscale_cam[0, :]
+                 grad_cam_map_serialized = grad_cam_map_np.tolist()
+
+         except Exception as e:
+             print(f"Error while generating Grad-CAM in the API: {e}")
+
+         results.append({
+             "birads": birads_category,
+             "confidence": confidence,
+             "interpretation": interpretation,
+             "class_probabilities": class_probabilities,
+             "grad_cam_map": grad_cam_map_serialized
+         })
+
+     return JSONResponse(content=results)
+
+ @app.get("/")
+ async def root():
+     return {"message": "Welcome to the BI-RADS Classification API! Use the /predict/ endpoint to send images."}
+
+ # To run locally: uvicorn api_app:app --reload
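For illustration, a minimal client sketch for the /predict/ endpoint above. It assumes the requests library (not listed in requirements_api_space.txt, since it is only needed on the client side) and uses placeholder values for the URL and image path; the multipart field name "file" and the response keys match the code above.

import requests  # client-side assumption; any HTTP client works

API_URL = "http://localhost:7860/predict/"  # or your Space URL followed by /predict/
IMAGE_PATH = "example_mammogram.png"        # placeholder path to a local image

with open(IMAGE_PATH, "rb") as f:
    # The endpoint expects a multipart upload under the field name "file"
    response = requests.post(API_URL, files={"file": f})
response.raise_for_status()

result = response.json()[0]  # the API returns a list with one entry per image
print(f"BI-RADS {result['birads']} ({result['confidence']:.1%}): {result['interpretation']}")

# The Grad-CAM heatmap comes back as a nested list; rebuild an array if you want to plot it
if result["grad_cam_map"] is not None:
    import numpy as np
    heatmap = np.array(result["grad_cam_map"])
    print("Grad-CAM map shape:", heatmap.shape)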
requirements_api_space.txt ADDED
@@ -0,0 +1,9 @@
+ fastapi
+ uvicorn[standard]
+ torch
+ torchvision
+ Pillow
+ pytorch-grad-cam
+ numpy
+ huggingface_hub
+ python-multipart  # Needed by FastAPI to handle file uploads