from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.responses import JSONResponse import torch import torchvision.models as models import torchvision.transforms as transforms from PIL import Image import torch.nn as nn import io import numpy as np import os from typing import List, Dict, Any # Importy dla Grad-CAM from pytorch_grad_cam import GradCAMPlusPlus from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget from huggingface_hub import hf_hub_download # Do pobierania modelu z Huba # --- Konfiguracja --- # Upewnij się, że te wartości są zgodne z Twoim repozytorium modelu HF_MODEL_REPO_ID = "Enterwar99/MODEL_MAMMOGRAFII" MODEL_FILENAME = "best_model.pth" # Nazwa pliku modelu w repozytorium DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] # Globalne zmienne dla modelu i transformacji model_instance = None transform_pipeline = None interpretations_dict = { 1: "Wynik negatywny - brak zmian nowotworowych", 2: "Zmiana łagodna", 3: "Prawdopodobnie zmiana łagodna - zalecana kontrola", 4: "Podejrzenie zmiany złośliwej - zalecana biopsja", 5: "Wysoka podejrzliwość złośliwości - wymagana biopsja" } # --- Inicjalizacja modelu --- def initialize_model(): global model_instance, transform_pipeline if model_instance is not None: return print(f"Pobieranie modelu {MODEL_FILENAME} z repozytorium {HF_MODEL_REPO_ID}...") try: # Odczytaj token z sekretów, jeśli jest dostępny # Nazwa zmiennej środowiskowej musi być taka sama jak nazwa sekretu w ustawieniach Space hf_auth_token = os.environ.get("HF_TOKEN_MODEL_READ") if hf_auth_token: print("Używam tokenu HF_TOKEN_MODEL_READ do pobrania modelu.") else: print("OSTRZEŻENIE: Sekret HF_TOKEN_MODEL_READ nie został znaleziony. Próba pobrania modelu bez tokenu (może się nie udać dla prywatnych repozytoriów).") model_pt_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename=MODEL_FILENAME, token=hf_auth_token) except Exception as e: print(f"Błąd podczas pobierania modelu z Hugging Face Hub: {e}") raise RuntimeError(f"Nie można pobrać modelu: {e}") print(f"Inicjalizacja architektury modelu ResNet-18...") model_arch = models.resnet18(weights=None) num_feats = model_arch.fc.in_features model_arch.fc = nn.Sequential( nn.Dropout(0.5), nn.Linear(num_feats, 5) ) print(f"Ładowanie wag modelu z {model_pt_path}...") model_arch.load_state_dict(torch.load(model_pt_path, map_location=DEVICE)) model_arch.to(DEVICE) model_arch.eval() model_instance = model_arch transform_pipeline = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) ]) print(f"Model BI-RADS classifier initialized successfully on device: {DEVICE}") # --- Aplikacja FastAPI --- app = FastAPI(title="BI-RADS Mammography Classification API") @app.on_event("startup") async def startup_event(): """Wywoływane przy starcie aplikacji FastAPI.""" initialize_model() @app.post("/predict/", response_model=List[Dict[str, Any]]) async def predict_image(file: UploadFile = File(...)): """ Endpoint do klasyfikacji obrazu mammograficznego. Oczekuje pliku obrazu (JPG, PNG). Zwraca listę z wynikami (nawet jeśli tylko jeden obraz). """ if model_instance is None or transform_pipeline is None: raise HTTPException(status_code=503, detail="Model nie jest zainicjalizowany. Spróbuj ponownie za chwilę.") try: contents = await file.read() image = Image.open(io.BytesIO(contents)).convert("RGB") except Exception as e: raise HTTPException(status_code=400, detail=f"Nie można odczytać pliku obrazu: {e}") # Preprocessing input_tensor = transform_pipeline(image).unsqueeze(0).to(DEVICE) # Inference with torch.no_grad(): model_outputs = model_instance(input_tensor) # Postprocessing probs = torch.nn.functional.softmax(model_outputs, dim=1) confidences, predicted_indices = torch.max(probs, 1) results = [] for i in range(len(predicted_indices)): # Pętla na wypadek przyszłego batch processingu birads_category = predicted_indices[i].item() + 1 confidence = confidences[i].item() interpretation = interpretations_dict.get(birads_category, "Nieznana klasyfikacja") all_class_probs_tensor = probs[i].cpu().numpy() class_probabilities = {str(j+1): float(all_class_probs_tensor[j]) for j in range(len(all_class_probs_tensor))} # Generowanie Grad-CAM grad_cam_map_serialized = None try: for param in model_instance.parameters(): param.requires_grad_(True) model_instance.eval() target_layers = [model_instance.layer4[-1]] # Dla ResNet-18 cam_algorithm = GradCAMPlusPlus(model=model_instance, target_layers=target_layers) current_input_tensor_for_cam = input_tensor[i].unsqueeze(0).clone().detach().requires_grad_(True) targets_for_cam = [ClassifierOutputTarget(predicted_indices[i].item())] grayscale_cam = cam_algorithm(input_tensor=current_input_tensor_for_cam, targets=targets_for_cam) if grayscale_cam is not None: grad_cam_map_np = grayscale_cam[0, :] grad_cam_map_serialized = grad_cam_map_np.tolist() except Exception as e: print(f"Błąd podczas generowania Grad-CAM w API: {e}") results.append({ "birads": birads_category, "confidence": confidence, "interpretation": interpretation, "class_probabilities": class_probabilities, "grad_cam_map": grad_cam_map_serialized }) return JSONResponse(content=results) @app.get("/") async def root(): return {"message": "Witaj w BI-RADS Classification API! Użyj endpointu /predict/ do wysyłania obrazów."} # Do uruchomienia lokalnie: uvicorn api_app:app --reload