from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn
import io
import numpy as np
import os
from typing import List, Dict, Any

# Imports for Grad-CAM
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from huggingface_hub import hf_hub_download  # Used to download the model from the Hub

# --- Configuration ---
# Make sure these values match your model repository
HF_MODEL_REPO_ID = "Enterwar99/MODEL_MAMMOGRAFII"
MODEL_FILENAME = "best_model.pth"  # Name of the model file in the repository

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
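# Standard ImageNet normalization statistics; they should match the preprocessing
# used when the model was trained.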
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Global variables for the model and the transform pipeline
model_instance = None
transform_pipeline = None

interpretations_dict = {
    1: "Negative result - no cancerous lesions",
    2: "Benign finding",
    3: "Probably benign finding - follow-up recommended",
    4: "Suspicious for malignancy - biopsy recommended",
    5: "Highly suggestive of malignancy - biopsy required"
}
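# The model outputs class indices 0-4; the /predict/ endpoint adds 1 to map them
# onto the BI-RADS categories used as keys above.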

# --- Model initialization ---
def initialize_model():
    global model_instance, transform_pipeline
    if model_instance is not None:
        return

    print(f"Pobieranie modelu {MODEL_FILENAME} z repozytorium {HF_MODEL_REPO_ID}...")
    try:
        # Read the token from secrets if one is available.
        # The environment variable name must match the secret name in the Space settings.
        hf_auth_token = os.environ.get("HF_TOKEN_MODEL_READ")
        if hf_auth_token:
            print("Using the HF_TOKEN_MODEL_READ token to download the model.")
        else:
            print("WARNING: The HF_TOKEN_MODEL_READ secret was not found. Attempting to download the model without a token (this may fail for private repositories).")

        model_pt_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=MODEL_FILENAME, token=hf_auth_token)
    except Exception as e:
        print(f"B艂膮d podczas pobierania modelu z Hugging Face Hub: {e}")
        raise RuntimeError(f"Nie mo偶na pobra膰 modelu: {e}")

    print(f"Inicjalizacja architektury modelu ResNet-18...")
    model_arch = models.resnet18(weights=None)
    num_feats = model_arch.fc.in_features
    model_arch.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_feats, 5)
    )
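    # The classifier head above must match the head used during training,
    # otherwise load_state_dict() below will fail on mismatched keys/shapes.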
    
    print(f"艁adowanie wag modelu z {model_pt_path}...")
    model_arch.load_state_dict(torch.load(model_pt_path, map_location=DEVICE))
    model_arch.to(DEVICE)
    model_arch.eval()
    model_instance = model_arch

    transform_pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])
    print(f"Model BI-RADS classifier initialized successfully on device: {DEVICE}")

# --- FastAPI application ---
app = FastAPI(title="BI-RADS Mammography Classification API")
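# FastAPI serves interactive documentation for these endpoints at /docs (Swagger UI) by default.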

@app.on_event("startup")
async def startup_event():
    """Wywo艂ywane przy starcie aplikacji FastAPI."""
    initialize_model()

@app.post("/predict/", response_model=List[Dict[str, Any]])
async def predict_image(file: UploadFile = File(...)):
    """
    Endpoint do klasyfikacji obrazu mammograficznego.
    Oczekuje pliku obrazu (JPG, PNG).
    Zwraca list臋 z wynikami (nawet je艣li tylko jeden obraz).
    """
    if model_instance is None or transform_pipeline is None:
        raise HTTPException(status_code=503, detail="The model is not initialized. Please try again in a moment.")

    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not read the image file: {e}")

    # Preprocessing
    input_tensor = transform_pipeline(image).unsqueeze(0).to(DEVICE)

    # Inference
    with torch.no_grad():
        model_outputs = model_instance(input_tensor)
    
    # Postprocessing
    probs = torch.nn.functional.softmax(model_outputs, dim=1)
    confidences, predicted_indices = torch.max(probs, 1)
    
    results = []
    for i in range(len(predicted_indices)):  # Loop in case batch processing is added later
        birads_category = predicted_indices[i].item() + 1
        confidence = confidences[i].item()
        interpretation = interpretations_dict.get(birads_category, "Unknown classification")
        
        all_class_probs_tensor = probs[i].cpu().numpy()
        class_probabilities = {str(j+1): float(all_class_probs_tensor[j]) for j in range(len(all_class_probs_tensor))}
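        # Keys are the BI-RADS categories "1"-"5" as strings (JSON object keys),
        # values are the softmax probabilities for this image.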

        # Generate the Grad-CAM heatmap
        grad_cam_map_serialized = None
        try:
            # Grad-CAM needs a backward pass through the network, so make sure
            # gradients are enabled on the model parameters for this step.
            for param in model_instance.parameters():
                param.requires_grad_(True)
            model_instance.eval()

            target_layers = [model_instance.layer4[-1]]  # Last residual block of ResNet-18
            cam_algorithm = GradCAMPlusPlus(model=model_instance, target_layers=target_layers)
            
            current_input_tensor_for_cam = input_tensor[i].unsqueeze(0).clone().detach().requires_grad_(True)
            targets_for_cam = [ClassifierOutputTarget(predicted_indices[i].item())]
            
            grayscale_cam = cam_algorithm(input_tensor=current_input_tensor_for_cam, targets=targets_for_cam)
            
            if grayscale_cam is not None:
                grad_cam_map_np = grayscale_cam[0, :]
                grad_cam_map_serialized = grad_cam_map_np.tolist()

        except Exception as e:
            print(f"B艂膮d podczas generowania Grad-CAM w API: {e}")

        results.append({
            "birads": birads_category,
            "confidence": confidence,
            "interpretation": interpretation,
            "class_probabilities": class_probabilities,
            "grad_cam_map": grad_cam_map_serialized
        })
    
    return JSONResponse(content=results)

@app.get("/")
async def root():
    return {"message": "Witaj w BI-RADS Classification API! U偶yj endpointu /predict/ do wysy艂ania obraz贸w."}

# To run locally: uvicorn api_app:app --reload
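#
# Example client call (illustrative sketch, not part of the API itself; it assumes the
# server is running at http://localhost:8000 and that "example_mammogram.png" is a
# local file you provide):
#
#   import requests
#
#   with open("example_mammogram.png", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/predict/",
#           files={"file": ("example_mammogram.png", f, "image/png")},
#       )
#   resp.raise_for_status()
#   for result in resp.json():
#       print(result["birads"], result["confidence"], result["interpretation"])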