BYOL_Mammogram / example_usage.py
PranayPalem's picture
📖 Add example usage script
853f08d
#!/usr/bin/env python3
"""
example_usage.py
Demonstrates how to use the BYOL Mammogram model for feature extraction
and classification tasks.
"""
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
from pathlib import Path
# Import the BYOL model classes
from train_byol_mammo import MammogramBYOL
def load_byol_model(checkpoint_path: str, device: torch.device):
    """Load the pre-trained BYOL model for feature extraction.

    Rebuilds a ResNet50 backbone wrapped in ``MammogramBYOL``, restores the
    weights stored under ``'model_state_dict'`` in the checkpoint, switches
    the model to eval mode and prints a short summary.

    Args:
        checkpoint_path: Path of the checkpoint file passed to ``torch.load``.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The ``MammogramBYOL`` instance, on ``device`` and in eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    print(f"📥 Loading BYOL model from: {checkpoint_path}")

    # Create ResNet50 backbone (same as training). Slicing off the last child
    # module drops torchvision's final classification layer.
    resnet = models.resnet50(weights=None)
    backbone = nn.Sequential(*list(resnet.children())[:-1])

    # Initialize BYOL model with same architecture
    model = MammogramBYOL(
        backbone=backbone,
        input_dim=2048,   # ResNet50 feature dimension
        hidden_dim=4096,  # BYOL projection head hidden dim
        proj_dim=256,     # BYOL projection dimension
    ).to(device)

    # Load the trained weights.
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # PyTorch version and the checkpoint contents allow it).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print("✅ Model loaded successfully!")
    print(f" Epoch: {checkpoint.get('epoch', 'Unknown')}")

    # The ".4f" spec only works for numeric values; a missing or non-numeric
    # loss used to raise here after the model had already loaded fine.
    loss = checkpoint.get('loss', 'Unknown')
    try:
        loss_text = f"{loss:.4f}"
    except (TypeError, ValueError):
        loss_text = str(loss)
    print(f" Final loss: {loss_text}")

    return model
def create_inference_transform(tile_size: int = 512):
    """Build the preprocessing pipeline applied to a tile before inference.

    Args:
        tile_size: Side length, in pixels, of the square the image is
            resized to.

    Returns:
        A ``transforms.Compose`` that resizes to ``tile_size x tile_size``,
        converts to a tensor, and normalises each of the three channels as
        ``(x - 0.5) / 0.5``.
    """
    target_size = (tile_size, tile_size)
    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.5, 0.5, 0.5]

    steps = [
        transforms.Resize(target_size, antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
def extract_features(model, image_tensor, device):
    """Run ``model.get_features`` on a batch and return the result as NumPy.

    Args:
        model: Object exposing ``get_features(tensor)``.
        image_tensor: Input tensor; it is moved to ``device`` before the call.
        device: Device on which the forward pass runs.

    Returns:
        The features as a ``numpy.ndarray`` on the host. With the ResNet50
        backbone used elsewhere in this script each row is expected to be
        2048-dimensional (not checked here).
    """
    # Gradient tracking is disabled: this is inference only.
    with torch.no_grad():
        on_device = image_tensor.to(device)
        embeddings = model.get_features(on_device)
        host_copy = embeddings.cpu()
    return host_copy.numpy()
def _make_dummy_tile(tile_size=512):
    """Return a random grayscale tile converted to RGB (placeholder input)."""
    # randint's upper bound is exclusive, so 256 is required to cover the
    # full uint8 range 0..255.
    pixels = np.random.randint(0, 256, (tile_size, tile_size), dtype=np.uint8)
    return Image.fromarray(pixels).convert('RGB')  # Convert to RGB as expected


def main(checkpoint_path: str = "mammogram_byol_best.pth"):
    """Demonstrate model usage.

    Loads the BYOL checkpoint, then runs three examples on randomly generated
    tiles: single-image feature extraction, batch feature extraction, and
    cosine similarity between two feature vectors.

    Args:
        checkpoint_path: Checkpoint file handed to ``load_byol_model``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🖥️ Using device: {device}")

    # Load the pre-trained BYOL model
    model = load_byol_model(checkpoint_path, device)

    # Create preprocessing transform
    transform = create_inference_transform(tile_size=512)

    # Example 1: Feature extraction from a single image
    print("\n📊 Example 1: Feature Extraction")
    print("-" * 40)

    # Create a dummy mammogram tile (replace with actual image loading)
    dummy_image = _make_dummy_tile()

    # Preprocess the image
    image_tensor = transform(dummy_image).unsqueeze(0)  # Add batch dimension

    # Extract features
    features = extract_features(model, image_tensor, device)
    print(f"✅ Input shape: {image_tensor.shape}")
    print(f"✅ Feature shape: {features.shape}")
    print(f"✅ Feature vector (first 10 values): {features[0][:10]}")

    # Example 2: Batch processing multiple images
    print("\n📊 Example 2: Batch Feature Extraction")
    print("-" * 40)

    # Create a batch of dummy images
    batch_size = 4
    dummy_batch = torch.stack(
        [transform(_make_dummy_tile()) for _ in range(batch_size)]
    )

    # Extract features for the entire batch
    batch_features = extract_features(model, dummy_batch, device)
    print(f"✅ Batch input shape: {dummy_batch.shape}")
    print(f"✅ Batch features shape: {batch_features.shape}")
    print(f"✅ Features per image: {batch_features.shape[1]} dimensions")

    # Example 3: Similarity computation
    print("\n📊 Example 3: Feature Similarity")
    print("-" * 40)

    # Compute cosine similarity between first two images.
    # The import is local, so scikit-learn is only loaded when main() runs.
    from sklearn.metrics.pairwise import cosine_similarity

    # Slices keep the 2-D (1, n_features) shape the function is given here.
    similarity = cosine_similarity(
        batch_features[0:1],
        batch_features[1:2],
    )[0][0]
    print(f"✅ Cosine similarity between image 1 and 2: {similarity:.4f}")

    print("\n🎯 Next Steps:")
    print("- Use these 2048D features for downstream classification")
    print("- Train a classifier using train_classification.py")
    print("- Fine-tune the entire model for specific tasks")
    print("- Use for similarity search or clustering")

    print("\n📚 Model Summary:")
    print("- Architecture: ResNet50 + BYOL")
    print("- Input: 512x512 RGB mammogram tiles")
    print("- Output: 2048-dimensional feature vectors")
    print("- Training: Self-supervised on breast tissue tiles")
    print("- Use case: Medical image analysis and classification")


if __name__ == "__main__":
    main()