"""
example_usage.py

Demonstrates how to use the BYOL Mammogram model for feature extraction
and classification tasks.
"""
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import models, transforms |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
from train_byol_mammo import MammogramBYOL |
|
|
|
|
|
|
|
|
def load_byol_model(checkpoint_path: str, device: torch.device):
    """Load the pre-trained BYOL model for feature extraction.

    Builds a ResNet-50 backbone without its last child module, wraps it in
    ``MammogramBYOL``, restores the weights stored under the checkpoint's
    ``'model_state_dict'`` key and switches the model to evaluation mode.
    Progress and checkpoint metadata are printed to stdout.

    Args:
        checkpoint_path: Path of the checkpoint file handed to ``torch.load``.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The ``MammogramBYOL`` instance, on ``device`` and in eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    print(f"📥 Loading BYOL model from: {checkpoint_path}")

    # weights=None: the network starts untrained; every weight is expected to
    # come from the checkpoint. Slicing off the last child drops the final
    # layer of torchvision's ResNet (its fully connected classifier).
    resnet = models.resnet50(weights=None)
    backbone = nn.Sequential(*list(resnet.children())[:-1])

    model = MammogramBYOL(
        backbone=backbone,
        input_dim=2048,
        hidden_dim=4096,
        proj_dim=256,
    ).to(device)

    # SECURITY NOTE: torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # The loss is optional metadata. Applying the ".4f" format spec to a
    # missing value (or to any non-numeric one) would raise, so fall back to
    # a plain textual representation in those cases.
    loss = checkpoint.get('loss')
    try:
        loss_text = 'Unknown' if loss is None else f"{loss:.4f}"
    except (TypeError, ValueError):
        loss_text = str(loss)

    print("✅ Model loaded successfully!")
    print(f" Epoch: {checkpoint.get('epoch', 'Unknown')}")
    print(f" Final loss: {loss_text}")

    return model
|
|
|
|
|
|
|
|
def create_inference_transform(tile_size: int = 512):
    """Build the preprocessing pipeline applied to images before inference.

    The pipeline resizes the input to a ``tile_size`` x ``tile_size`` square
    (with antialiasing), converts it to a tensor and normalizes each of the
    three channels with mean 0.5 and standard deviation 0.5.

    Args:
        tile_size: Edge length, in pixels, of the square the image is resized to.

    Returns:
        A ``transforms.Compose`` chaining the three steps in that order.
    """
    target_shape = (tile_size, tile_size)
    pipeline_steps = [
        transforms.Resize(target_shape, antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(pipeline_steps)
|
|
|
|
|
|
|
|
def extract_features(model, image_tensor, device):
    """Run ``model.get_features`` on a tensor and return the result as NumPy.

    The tensor is moved to ``device`` first, and the whole computation runs
    with gradient tracking disabled.

    Args:
        model: Object exposing a ``get_features(tensor)`` method.
        image_tensor: Input tensor passed to ``get_features``
            (presumably a batch of preprocessed tiles; confirm with callers).
        device: Device the input is moved to before the forward pass.

    Returns:
        The features copied to the CPU as a ``numpy.ndarray``.
    """
    with torch.no_grad():
        inputs_on_device = image_tensor.to(device)
        embeddings = model.get_features(inputs_on_device)
        # Conversion stays inside the no-grad scope, as in the original flow.
        host_embeddings = embeddings.cpu()
        return host_embeddings.numpy()
|
|
|
|
|
|
|
|
def _make_dummy_image(size: int = 512):
    """Create a random ``size`` x ``size`` grayscale image converted to RGB.

    Used as a stand-in input so the examples can run without real data.
    """
    # randint's upper bound is exclusive, so 256 is required to cover the
    # full uint8 range (0..255).
    pixels = np.random.randint(0, 256, (size, size), dtype=np.uint8)
    return Image.fromarray(pixels).convert('RGB')


def main(checkpoint_path: str = "mammogram_byol_best.pth"):
    """Demonstrate model usage.

    Loads the BYOL checkpoint and runs three examples on random dummy
    images: single-image feature extraction, batch feature extraction, and
    cosine similarity between two extracted feature vectors. All results
    are printed to stdout.

    Args:
        checkpoint_path: Checkpoint file handed to ``load_byol_model``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🖥️ Using device: {device}")

    model = load_byol_model(checkpoint_path, device)
    transform = create_inference_transform(tile_size=512)

    # NOTE(review): the bare "π" that prefixes some headings below looks like
    # a mis-decoded emoji whose original bytes are not recoverable from this
    # file; it is kept verbatim. Confirm against the upstream source.

    # --- Example 1: one image -------------------------------------------
    print("\nπ Example 1: Feature Extraction")
    print("-" * 40)

    dummy_image = _make_dummy_image()
    # unsqueeze(0) adds the leading batch dimension.
    image_tensor = transform(dummy_image).unsqueeze(0)
    features = extract_features(model, image_tensor, device)

    print(f"✅ Input shape: {image_tensor.shape}")
    print(f"✅ Feature shape: {features.shape}")
    print(f"✅ Feature vector (first 10 values): {features[0][:10]}")

    # --- Example 2: a batch of images -----------------------------------
    print("\nπ Example 2: Batch Feature Extraction")
    print("-" * 40)

    batch_size = 4
    dummy_batch = torch.stack(
        [transform(_make_dummy_image()) for _ in range(batch_size)]
    )
    batch_features = extract_features(model, dummy_batch, device)

    print(f"✅ Batch input shape: {dummy_batch.shape}")
    print(f"✅ Batch features shape: {batch_features.shape}")
    print(f"✅ Features per image: {batch_features.shape[1]} dimensions")

    # --- Example 3: similarity between two feature vectors ---------------
    print("\nπ Example 3: Feature Similarity")
    print("-" * 40)

    # Imported here so scikit-learn is only needed when this example runs.
    from sklearn.metrics.pairwise import cosine_similarity

    # Slices (0:1, 1:2) keep the row dimension, giving one row per operand;
    # [0][0] picks the single entry of the resulting similarity matrix.
    similarity = cosine_similarity(
        batch_features[0:1],
        batch_features[1:2],
    )[0][0]

    print(f"✅ Cosine similarity between image 1 and 2: {similarity:.4f}")

    print("\n🎯 Next Steps:")
    print("- Use these 2048D features for downstream classification")
    print("- Train a classifier using train_classification.py")
    print("- Fine-tune the entire model for specific tasks")
    print("- Use for similarity search or clustering")

    print("\nπ Model Summary:")
    print("- Architecture: ResNet50 + BYOL")
    print("- Input: 512x512 RGB mammogram tiles")
    print("- Output: 2048-dimensional feature vectors")
    print("- Training: Self-supervised on breast tissue tiles")
    print("- Use case: Medical image analysis and classification")
|
|
|
|
|
|
|
|
# Script entry point: run the demonstration only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()