# Source: Hugging Face file view of api_app.py
# Author: Enterwar99 | Commit: bc70f50 ("Update api_app.py", verified) | Size: 6.31 kB
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn
import io
import numpy as np
import os
from typing import List, Dict, Any
# Importy dla Grad-CAM
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from huggingface_hub import hf_hub_download # Do pobierania modelu z Huba
# --- Configuration ---
# Make sure these values match your model repository on the Hugging Face Hub.
HF_MODEL_REPO_ID = "Enterwar99/MODEL_MAMMOGRAFII"
MODEL_FILENAME = "best_model.pth"  # Name of the weights file inside the repository
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Per-channel RGB statistics passed to transforms.Normalize in the preprocessing pipeline.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Module-level model and preprocessing pipeline; both stay None until initialize_model() runs.
model_instance = None
transform_pipeline = None

# BI-RADS category (1-5) -> human-readable interpretation returned to API clients (in Polish).
interpretations_dict: Dict[int, str] = {
    1: "Wynik negatywny - brak zmian nowotworowych",
    2: "Zmiana łagodna",
    3: "Prawdopodobnie zmiana łagodna - zalecana kontrola",
    4: "Podejrzenie zmiany złośliwej - zalecana biopsja",
    5: "Wysoka podejrzliwość złośliwości - wymagana biopsja",
}
# --- Model initialization ---
def _download_model_weights():
    """Download MODEL_FILENAME from the Hub repository HF_MODEL_REPO_ID and return its local path.

    The HF_TOKEN_MODEL_READ environment variable, when set, is passed to
    hf_hub_download as the auth token.

    Raises:
        RuntimeError: if the download fails; the original exception is chained.
    """
    # The environment variable name must match the secret name in the Space settings.
    hf_auth_token = os.environ.get("HF_TOKEN_MODEL_READ")
    if hf_auth_token:
        print("Używam tokenu HF_TOKEN_MODEL_READ do pobrania modelu.")
    else:
        print(
            "OSTRZEŻENIE: Sekret HF_TOKEN_MODEL_READ nie został znaleziony. "
            "Próba pobrania modelu bez tokenu (może się nie udać dla prywatnych repozytoriów)."
        )
    try:
        return hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=MODEL_FILENAME, token=hf_auth_token)
    except Exception as e:
        print(f"Błąd podczas pobierania modelu z Hugging Face Hub: {e}")
        raise RuntimeError(f"Nie można pobrać modelu: {e}") from e


def _load_state_dict(path):
    """Load the checkpoint file at *path* onto DEVICE (it is consumed as a state_dict)."""
    try:
        # weights_only=True restricts unpickling to tensors and primitive containers, so a
        # tampered checkpoint downloaded from the network cannot run code when loaded.
        return torch.load(path, map_location=DEVICE, weights_only=True)
    except TypeError:
        # torch releases older than 1.13 do not accept the weights_only argument.
        return torch.load(path, map_location=DEVICE)


def _build_resnet18_classifier(num_classes=5):
    """Return a randomly initialized ResNet-18 whose head is Dropout(0.5) + Linear(num_classes)."""
    model = models.resnet18(weights=None)
    num_feats = model.fc.in_features
    # The head layout must match the saved checkpoint exactly; load_state_dict is strict
    # by default and rejects missing/unexpected keys or shape mismatches.
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_feats, num_classes),
    )
    return model


def initialize_model():
    """Download the weights, build the classifier and preprocessing, and store them globally.

    Populates the module-level ``model_instance`` and ``transform_pipeline``.
    Once initialization has succeeded, further calls return immediately.

    Raises:
        RuntimeError: if the model weights cannot be downloaded.
    """
    global model_instance, transform_pipeline
    if model_instance is not None:
        return
    print(f"Pobieranie modelu {MODEL_FILENAME} z repozytorium {HF_MODEL_REPO_ID}...")
    model_pt_path = _download_model_weights()
    print("Inicjalizacja architektury modelu ResNet-18...")
    model_arch = _build_resnet18_classifier()
    print(f"Ładowanie wag modelu z {model_pt_path}...")
    model_arch.load_state_dict(_load_state_dict(model_pt_path))
    model_arch.to(DEVICE)
    model_arch.eval()  # inference mode: disables Dropout, BatchNorm uses running statistics
    pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    # Publish the transform before the model: predict_image needs both, while the
    # early-return guard above checks only model_instance.
    transform_pipeline = pipeline
    model_instance = model_arch
    print(f"Model BI-RADS classifier initialized successfully on device: {DEVICE}")
# --- FastAPI application ---
app = FastAPI(title="BI-RADS Mammography Classification API")


async def startup_event():
    """Startup hook: load the model and preprocessing pipeline when the app starts."""
    initialize_model()


# Explicit registration; equivalent to decorating startup_event with @app.on_event("startup").
app.add_event_handler("startup", startup_event)
def _compute_grad_cam(image_tensor, class_idx):
    """Compute a Grad-CAM++ heatmap for one preprocessed image, on a best-effort basis.

    Args:
        image_tensor: a single preprocessed image (one element of the batch built in
            predict_image), already on DEVICE.
        class_idx: 0-based model output index whose activation is explained.

    Returns:
        The heatmap of the first (only) image as nested Python lists, or None when it
        could not be computed. Errors are printed rather than raised so the
        classification result is still returned to the client.
    """
    grayscale_cam = None
    try:
        target_layers = [model_instance.layer4[-1]]  # last residual block of ResNet-18
        # cam_input requires grad, so gradients reach the target layer's activations.
        cam_input = image_tensor.unsqueeze(0).clone().detach().requires_grad_(True)
        targets = [ClassifierOutputTarget(class_idx)]
        # The context manager removes the hooks GradCAMPlusPlus attaches to the shared
        # model as soon as the map is computed, instead of relying on garbage collection.
        with GradCAMPlusPlus(model=model_instance, target_layers=target_layers) as cam:
            grayscale_cam = cam(input_tensor=cam_input, targets=targets)
        # grayscale_cam stays None if the CAM context suppressed an error on exit
        # (some pytorch_grad_cam versions do this for certain exceptions).
        return None if grayscale_cam is None else grayscale_cam[0, :].tolist()
    except Exception as e:
        print(f"Błąd podczas generowania Grad-CAM w API: {e}")
        return None


@app.post("/predict/", response_model=List[Dict[str, Any]])
async def predict_image(file: UploadFile = File(...)):
    """
    Classify an uploaded mammography image into a BI-RADS category.

    Expects an image file (JPG, PNG). Returns a JSON list with one result per image
    (currently always a single image), each containing:
      - "birads": predicted category (1-5),
      - "confidence": softmax probability of the predicted category,
      - "interpretation": text from interpretations_dict ("Nieznana klasyfikacja" if missing),
      - "class_probabilities": {"1": p1, ..., "5": p5},
      - "grad_cam_map": Grad-CAM++ heatmap as nested lists, or None if it failed.

    Raises:
        HTTPException: 503 if the model is not initialized yet,
            400 if the upload cannot be decoded as an image.
    """
    if model_instance is None or transform_pipeline is None:
        raise HTTPException(status_code=503, detail="Model nie jest zainicjalizowany. Spróbuj ponownie za chwilę.")
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Nie można odczytać pliku obrazu: {e}") from e

    # NOTE: everything below is synchronous and runs on the event loop, so requests are
    # processed one at a time. Grad-CAM attaches hooks to the shared model, so moving
    # this work to a thread pool would require a lock around the model.
    input_tensor = transform_pipeline(image).unsqueeze(0).to(DEVICE)  # batch of one
    with torch.no_grad():
        model_outputs = model_instance(input_tensor)
    probs = torch.nn.functional.softmax(model_outputs, dim=1)
    confidences, predicted_indices = torch.max(probs, 1)
    probs_per_image = probs.cpu().tolist()

    results = []
    # Iterate over the batch dimension so batched inputs can be supported later.
    for i, (confidence, class_idx) in enumerate(zip(confidences.tolist(), predicted_indices.tolist())):
        birads_category = class_idx + 1  # output index 0..4 -> BI-RADS category 1..5
        results.append({
            "birads": birads_category,
            "confidence": confidence,
            "interpretation": interpretations_dict.get(birads_category, "Nieznana klasyfikacja"),
            "class_probabilities": {str(j + 1): p for j, p in enumerate(probs_per_image[i])},
            "grad_cam_map": _compute_grad_cam(input_tensor[i], class_idx),
        })
    return JSONResponse(content=results)
@app.get("/")
async def root():
    """Return a welcome message that points clients to the /predict/ endpoint."""
    return {"message": "Witaj w BI-RADS Classification API! Użyj endpointu /predict/ do wysyłania obrazów."}


# To run locally: uvicorn api_app:app --reload