BYOL Mammogram Classification Model
A self-supervised learning model for mammogram analysis using Bootstrap Your Own Latent (BYOL) pre-training with a ResNet50 backbone.
Model Description
This model implements BYOL (Bootstrap Your Own Latent) self-supervised pre-training on mammogram breast tissue tiles, followed by fine-tuning for classification tasks. The model is designed specifically for medical imaging applications with aggressive background rejection and intelligent tissue segmentation.
Key Features
- Self-supervised pre-training: Uses BYOL to learn meaningful representations from unlabeled mammogram data
- Aggressive background rejection: Multi-level filtering eliminates empty space and background tiles
- Medical-optimized augmentations: Preserves anatomical details while providing effective augmentation
- High-quality tile extraction: Intelligent breast tissue segmentation with frequency-based selection
- A100 GPU optimized: Mixed precision training with advanced optimizations
Model Architecture
- Backbone: ResNet50 (ImageNet pre-trained → BYOL fine-tuned)
- Input dimension: 2048 (ResNet50 features)
- Hidden dimension: 4096
- Projection dimension: 256
- Tile size: 512x512 pixels
- Input format: RGB (grayscale mammograms converted to RGB)
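The dimensions listed above can be read as a 2048 → 4096 → 256 projection MLP on top of pooled ResNet50 features. The following is a minimal sketch of that shape, assuming the Linear, BatchNorm, ReLU, Linear layout described in the BYOL paper; `make_mlp` is a hypothetical helper, the predictor dimensions are an assumption, and the heads defined in train_byol_mammo.py may differ.

```python
import torch
import torch.nn as nn


def make_mlp(input_dim: int, hidden_dim: int, output_dim: int) -> nn.Sequential:
    """Two-layer MLP: Linear -> BatchNorm -> ReLU -> Linear (hypothetical helper)."""
    return nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, output_dim),
    )


# Projection head with the dimensions from the list above.
projector = make_mlp(input_dim=2048, hidden_dim=4096, output_dim=256)
# Online-network predictor; these dimensions are assumed, not taken from the source.
predictor = make_mlp(input_dim=256, hidden_dim=4096, output_dim=256)

# Stand-in for a batch of four pooled ResNet50 feature vectors.
features = torch.randn(4, 2048)
projection = projector(features)
prediction = predictor(projection)
print(projection.shape, prediction.shape)
```

A batch larger than one is used because BatchNorm1d in training mode cannot normalize a single sample.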
Training Details
BYOL Pre-training
- Epochs: 100
- Batch size: 32 (A100 optimized)
- Learning rate: 2e-3 with warmup
- Optimizer: AdamW with cosine annealing
- Mixed precision: Enabled for A100 optimization
- Momentum updates: Per-step momentum scheduling (0.996 → 1.0)
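The schedules named above can be sketched in plain Python. This is an illustration under stated assumptions, not the code from train_byol_mammo.py: the momentum follows the cosine schedule from the BYOL paper, the warmup is assumed to be linear, and `warmup_steps` and `total_steps` are hypothetical values because the source does not give them.

```python
import math


def byol_momentum(step: int, total_steps: int,
                  base: float = 0.996, final: float = 1.0) -> float:
    """Cosine schedule that moves the target-network momentum from base to final."""
    progress = step / max(1, total_steps)
    return final - (final - base) * (math.cos(math.pi * progress) + 1.0) / 2.0


def ema_update(target: list[float], online: list[float], momentum: float) -> list[float]:
    """Exponential moving average of online weights into the target weights."""
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target, online)]


def learning_rate(step: int, total_steps: int, warmup_steps: int,
                  peak_lr: float = 2e-3) -> float:
    """Linear warmup to peak_lr (assumed), then cosine annealing to zero."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


total_steps = 1000   # hypothetical; in practice epochs * batches per epoch
warmup_steps = 100   # hypothetical warmup length

for step in (0, 100, 500, 1000):
    print(step,
          round(byol_momentum(step, total_steps), 5),
          round(learning_rate(step, total_steps, warmup_steps), 6))
```

As the momentum approaches 1.0, `ema_update` leaves the target network almost unchanged, so the target stabilizes toward the end of training.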
Data Processing
- Tile extraction: 512x512 pixels with 50% overlap
- Background rejection: Multiple criteria including intensity, frequency energy, and tissue ratio
- Minimum breast ratio: 15% (increased from typical 30%)
- Minimum frequency energy: 0.03 (aggressive threshold)
- Augmentations: Medical-safe rotations, flips, color jittering, perspective transforms
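To make the tiling numbers concrete: 512-pixel tiles with 50% overlap imply a stride of 256 pixels. The sketch below computes the tile grid and applies the two thresholds listed above. The function names are hypothetical, the intensity criterion is omitted because its threshold is not stated, and how the extraction script computes the breast ratio and frequency energy or handles image borders is not specified here.

```python
def tile_origins(length: int, tile: int = 512, overlap: float = 0.5) -> list[int]:
    """Start offsets along one axis for tiles of size `tile` with fractional overlap."""
    stride = int(tile * (1.0 - overlap))  # 256 px for 512 px tiles at 50% overlap
    if length < tile:
        return []
    return list(range(0, length - tile + 1, stride))


def keep_tile(breast_ratio: float, frequency_energy: float,
              min_breast_ratio: float = 0.15,
              min_frequency_energy: float = 0.03) -> bool:
    """Keep a tile only if it clears both thresholds from the list above."""
    return (breast_ratio >= min_breast_ratio
            and frequency_energy >= min_frequency_energy)


# Hypothetical 2048 x 1024 image: 7 columns x 3 rows = 21 candidate tiles.
xs = tile_origins(2048)
ys = tile_origins(1024)
candidates = [(x, y) for y in ys for x in xs]
print(len(candidates), "candidate tiles")

# Made-up per-tile statistics, only to show the filter.
print(keep_tile(breast_ratio=0.40, frequency_energy=0.05))  # clears both thresholds
print(keep_tile(breast_ratio=0.10, frequency_energy=0.05))  # too little tissue
```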
Usage
Loading the Model
import torch
from train_byol_mammo import MammogramBYOL
from torchvision import models
import torch.nn as nn
# Load the pre-trained BYOL model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create ResNet50 backbone
resnet = models.resnet50(weights=None)
backbone = nn.Sequential(*list(resnet.children())[:-1])
# Initialize BYOL model
model = MammogramBYOL(
    backbone=backbone,
    input_dim=2048,
    hidden_dim=4096,
    proj_dim=256
).to(device)
# Load pre-trained weights
checkpoint = torch.load('mammogram_byol_best.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
Feature Extraction
# Extract features from mammogram tiles
def extract_features(image_tensor):
    with torch.no_grad():
        features = model.get_features(image_tensor)
    return features
# Example usage
image = torch.randn(1, 3, 512, 512).to(device) # Example input
features = extract_features(image) # Returns 2048-dim features
Classification Fine-tuning
Use the provided train_classification.py script for downstream classification tasks:
python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles/ \
    --output_dir ./classification_results/
File Structure
BYOL_Mammogram/
├── mammogram_byol_best.pth       # Best BYOL checkpoint
├── mammogram_byol_final.pth      # Final BYOL checkpoint
├── train_byol_mammo.py           # BYOL pre-training script
├── train_classification.py       # Classification fine-tuning
├── inference_classification.py   # Inference script
├── classification_config.json    # Classification configuration
├── CLASSIFICATION_GUIDE.md       # Detailed training guide
└── requirements.txt              # Dependencies
Performance
Pre-training Results
- Dataset: High-quality breast tissue tiles with aggressive background rejection
- Efficiency: ~15-20% tile selection rate (quality over quantity)
- Background contamination: 0% (eliminated during extraction)
- Training length: ~100 epochs on an A100 GPU
Key Metrics
- Average breast tissue per tile: >15%
- Average frequency energy: >0.03
- Tile quality: Medical-grade with preserved anatomical details
Technical Specifications
Hardware Requirements
- GPU: A100 (40GB/80GB) recommended
- Memory: 35-40GB GPU memory for training
- CPU: 16+ cores for data loading
Dependencies
torch>=2.0.0
torchvision>=0.15.0
lightly>=1.4.0
opencv-python>=4.8.0
scipy>=1.10.0
numpy>=1.24.0
Pillow>=9.5.0
tqdm>=4.65.0
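A quick way to confirm that an environment meets these minimums is to compare installed versions with the standard library. This is a minimal sketch: `parse_version`, `meets_minimum`, and `check_requirements` are hypothetical helpers, and the version parsing is deliberately simple (leading numeric components only), so it is not a substitute for a real resolver such as pip.

```python
import re
from importlib import metadata

# Minimum versions copied from the dependency list above.
REQUIREMENTS = {
    "torch": "2.0.0",
    "torchvision": "0.15.0",
    "lightly": "1.4.0",
    "opencv-python": "4.8.0",
    "scipy": "1.10.0",
    "numpy": "1.24.0",
    "Pillow": "9.5.0",
    "tqdm": "4.65.0",
}


def parse_version(text: str) -> tuple[int, ...]:
    """Leading numeric components, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    numbers = []
    for piece in text.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        numbers.append(int(match.group()))
    return tuple(numbers)


def meets_minimum(installed: str, minimum: str) -> bool:
    """Compare versions after padding the shorter one with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_requirements() -> list[str]:
    """One status line per dependency: ok, too old, or missing."""
    report = []
    for name, minimum in REQUIREMENTS.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report.append(f"{name}: missing (need >= {minimum})")
            continue
        status = "ok" if meets_minimum(installed, minimum) else "too old"
        report.append(f"{name}: {installed} ({status}, need >= {minimum})")
    return report


for line in check_requirements():
    print(line)
```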
Medical Imaging Considerations
Data Safety
- Augmentation strategy: Preserves medical accuracy while providing diversity
- Background rejection: Prevents training on non-diagnostic regions
- Tissue focus: Emphasizes clinically relevant breast tissue areas
Clinical Applications
- Screening support: Potential for computer-aided detection
- Research tool: Feature extraction for medical AI research
- Educational: Understanding mammogram image analysis
Limitations
- Domain specific: Trained specifically on mammogram data
- Preprocessing required: Requires proper tissue segmentation
- Computationally intensive: Large model requiring substantial GPU resources
- Medical supervision: Requires clinical validation for any medical application
Citation
If you use this model in your research, please cite:
@model{byol_mammogram_2024,
title={BYOL Mammogram Classification Model},
author={PranayPalem},
year={2024},
url={https://huggingface.co/PranayPalem/BYOL_Mammogram}
}
License
MIT License - See LICENSE file for details.
Disclaimer
This model is for research purposes only and should not be used for clinical diagnosis without proper validation and medical supervision. Always consult healthcare professionals for medical decisions.