#!/usr/bin/env python3
"""
example_usage.py

Demonstrates how to use the BYOL Mammogram model for feature extraction
and classification tasks.
"""

import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
from pathlib import Path

# Import the BYOL model classes
from train_byol_mammo import MammogramBYOL


def _format_loss(loss) -> str:
    """Render a checkpoint loss for display without ever raising.

    Args:
        loss: Value stored under the checkpoint's ``'loss'`` key, or ``None``
            when the key is absent.

    Returns:
        The loss formatted with four decimals when it can be converted to a
        float, ``"Unknown"`` when it is ``None``, otherwise ``str(loss)``.
    """
    if loss is None:
        return "Unknown"
    try:
        # float() also accepts numpy scalars and single-element tensors.
        return f"{float(loss):.4f}"
    except (TypeError, ValueError):
        return str(loss)


def _make_dummy_tile(tile_size: int = 512) -> Image.Image:
    """Build a random grayscale tile converted to RGB, standing in for a real image."""
    # randint's upper bound is exclusive, so 256 covers the full uint8 range.
    pixels = np.random.randint(0, 256, (tile_size, tile_size), dtype=np.uint8)
    return Image.fromarray(pixels).convert('RGB')  # Convert to RGB as expected


def load_byol_model(checkpoint_path: str, device: torch.device):
    """Load the pre-trained BYOL model for feature extraction.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            contains a ``'model_state_dict'`` entry; ``'epoch'`` and ``'loss'``
            entries are reported when present.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The ``MammogramBYOL`` model with weights loaded, switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    print(f"📥 Loading BYOL model from: {checkpoint_path}")

    # Create ResNet50 backbone (same as training): every child module except
    # the last one (the classification layer) is kept.
    resnet = models.resnet50(weights=None)
    backbone = nn.Sequential(*list(resnet.children())[:-1])

    # Initialize BYOL model with same architecture
    model = MammogramBYOL(
        backbone=backbone,
        input_dim=2048,   # ResNet50 feature dimension
        hidden_dim=4096,  # BYOL projection head hidden dim
        proj_dim=256,     # BYOL projection dimension
    ).to(device)

    # Load the trained weights.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source (or pass weights_only=True if the checkpoint
    # contents allow it).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print("✅ Model loaded successfully!")
    print(f" Epoch: {checkpoint.get('epoch', 'Unknown')}")
    # The loss may be missing from the checkpoint; formatting the fallback
    # string with ':.4f' would raise ValueError, so format it defensively.
    print(f" Final loss: {_format_loss(checkpoint.get('loss'))}")

    return model


def create_inference_transform(tile_size: int = 512):
    """Create the preprocessing transform for inference.

    Args:
        tile_size: Side length, in pixels, that images are resized to.

    Returns:
        A ``transforms.Compose`` that resizes to ``tile_size`` x ``tile_size``,
        converts to a tensor, and normalizes each of the three channels with
        mean 0.5 and std 0.5.
    """
    return transforms.Compose([
        transforms.Resize((tile_size, tile_size), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])


def extract_features(model, image_tensor: torch.Tensor, device: torch.device) -> np.ndarray:
    """Extract 2048-dimensional features from mammogram tiles.

    Args:
        model: Object exposing ``get_features``; presumably the model returned
            by ``load_byol_model`` (TODO confirm against train_byol_mammo).
        image_tensor: Batch of preprocessed images; moved to ``device`` here.
        device: Device on which the forward pass runs.

    Returns:
        The output of ``model.get_features`` as a NumPy array on the CPU.
    """
    # Gradients are not needed for inference.
    with torch.no_grad():
        image_tensor = image_tensor.to(device)
        features = model.get_features(image_tensor)
    return features.cpu().numpy()


def main(checkpoint_path: str = "mammogram_byol_best.pth"):
    """Demonstrate model usage.

    Args:
        checkpoint_path: Checkpoint to load; defaults to the file name the
            example has always used.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🖥️ Using device: {device}")

    # Load the pre-trained BYOL model
    model = load_byol_model(checkpoint_path, device)

    # Create preprocessing transform
    transform = create_inference_transform(tile_size=512)

    # Example 1: Feature extraction from a single image
    print("\n📊 Example 1: Feature Extraction")
    print("-" * 40)

    # Create a dummy mammogram tile (replace with actual image loading)
    dummy_image = _make_dummy_tile(512)

    # Preprocess the image
    image_tensor = transform(dummy_image).unsqueeze(0)  # Add batch dimension

    # Extract features
    features = extract_features(model, image_tensor, device)

    print(f"✅ Input shape: {image_tensor.shape}")
    print(f"✅ Feature shape: {features.shape}")
    print(f"✅ Feature vector (first 10 values): {features[0][:10]}")

    # Example 2: Batch processing multiple images
    print("\n📊 Example 2: Batch Feature Extraction")
    print("-" * 40)

    # Create a batch of dummy images
    batch_size = 4
    dummy_batch = torch.stack([
        transform(_make_dummy_tile(512)) for _ in range(batch_size)
    ])

    # Extract features for the entire batch
    batch_features = extract_features(model, dummy_batch, device)

    print(f"✅ Batch input shape: {dummy_batch.shape}")
    print(f"✅ Batch features shape: {batch_features.shape}")
    print(f"✅ Features per image: {batch_features.shape[1]} dimensions")

    # Example 3: Similarity computation
    print("\n📊 Example 3: Feature Similarity")
    print("-" * 40)

    # Compute cosine similarity between first two images.
    # Imported here so scikit-learn is only required for this example.
    from sklearn.metrics.pairwise import cosine_similarity

    similarity = cosine_similarity(
        batch_features[0:1],
        batch_features[1:2],
    )[0][0]

    print(f"✅ Cosine similarity between image 1 and 2: {similarity:.4f}")

    print("\n🎯 Next Steps:")
    print("- Use these 2048D features for downstream classification")
    print("- Train a classifier using train_classification.py")
    print("- Fine-tune the entire model for specific tasks")
    print("- Use for similarity search or clustering")

    print("\n📚 Model Summary:")
    print("- Architecture: ResNet50 + BYOL")
    print("- Input: 512x512 RGB mammogram tiles")
    print("- Output: 2048-dimensional feature vectors")
    print("- Training: Self-supervised on breast tissue tiles")
    print("- Use case: Medical image analysis and classification")


if __name__ == "__main__":
    main()