BYOL Mammogram Classification Model

A self-supervised learning model for mammogram analysis using Bootstrap Your Own Latent (BYOL) pre-training with ResNet50 backbone.

Model Description

This model implements BYOL (Bootstrap Your Own Latent) self-supervised pre-training on mammogram breast tissue tiles, followed by fine-tuning for classification tasks. The model is designed specifically for medical imaging applications with aggressive background rejection and intelligent tissue segmentation.

Key Features

  • Self-supervised pre-training: Uses BYOL to learn meaningful representations from unlabeled mammogram data
  • Aggressive background rejection: Multi-level filtering eliminates empty space and background tiles
  • Medical-optimized augmentations: Preserves anatomical details while providing effective augmentation
  • High-quality tile extraction: Intelligent breast tissue segmentation with frequency-based selection
  • A100 GPU optimized: Mixed precision training with advanced optimizations

Model Architecture

  • Backbone: ResNet50 (ImageNet pre-trained → BYOL fine-tuned)
  • Input dimension: 2048 (ResNet50 features)
  • Hidden dimension: 4096
  • Projection dimension: 256
  • Tile size: 512x512 pixels
  • Input format: RGB (grayscale mammograms converted to RGB)
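
For reference, the projector and predictor heads implied by these dimensions follow the standard BYOL MLP design (Linear → BatchNorm → ReLU → Linear). The sketch below is an assumption based on that conventional layout; the authoritative definition lives in train_byol_mammo.py.

import torch.nn as nn

class BYOLMLPHead(nn.Module):
    """Standard BYOL MLP head, e.g. the projector: 2048 -> 4096 -> 256."""
    def __init__(self, input_dim=2048, hidden_dim=4096, out_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)

# projector: BYOLMLPHead(2048, 4096, 256)
# predictor: BYOLMLPHead(256, 4096, 256), mapping projections back through hidden_dim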

Training Details

BYOL Pre-training

  • Epochs: 100
  • Batch size: 32 (A100 optimized)
  • Learning rate: 2e-3 with warmup
  • Optimizer: AdamW with cosine annealing
  • Mixed precision: Enabled for A100 optimization
  • Momentum updates: Per-step momentum scheduling (0.996 → 1.0)
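
The per-step schedule ramps the target-network momentum from 0.996 to 1.0 with the cosine rule from the original BYOL paper. A minimal sketch of how such a schedule is computed (the exact implementation is in train_byol_mammo.py):

import math

def byol_momentum(step, total_steps, base_tau=0.996):
    """Cosine ramp of the target-network momentum from base_tau to 1.0."""
    return 1.0 - (1.0 - base_tau) * (math.cos(math.pi * step / total_steps) + 1.0) / 2.0

# Step 0 gives 0.996; the final step gives 1.0. Each training step then updates
# the target network as: target_param = tau * target_param + (1 - tau) * online_param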

Data Processing

  • Tile extraction: 512x512 pixels with 50% overlap
  • Background rejection: Multiple criteria including intensity, frequency energy, and tissue ratio (see the sketch below)
  • Minimum breast ratio: 15% of tile area
  • Minimum frequency energy: 0.03 (aggressive threshold)
  • Augmentations: Medical-safe rotations, flips, color jittering, perspective transforms
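
To make the rejection criteria above concrete, here is a minimal sketch of tile extraction with the listed thresholds. The helper functions (breast_ratio, frequency_energy) and their exact formulas are illustrative assumptions; the released pipeline in train_byol_mammo.py is authoritative.

import numpy as np

TILE, STRIDE = 512, 256          # 512x512 tiles, 50% overlap
MIN_BREAST_RATIO = 0.15          # minimum fraction of tissue pixels
MIN_FREQ_ENERGY = 0.03           # minimum high-frequency energy

def breast_ratio(tile, intensity_thresh=20):
    # Fraction of pixels brighter than an assumed background level.
    return (tile > intensity_thresh).mean()

def frequency_energy(tile):
    # Share of FFT magnitude outside the low-frequency center (illustrative measure).
    spectrum = np.abs(np.fft.fftshift(np.fft.fft2(tile.astype(np.float32))))
    h, w = spectrum.shape
    low = spectrum[h // 4 : 3 * h // 4, w // 4 : 3 * w // 4].sum()
    return (spectrum.sum() - low) / (spectrum.sum() + 1e-8)

def extract_tiles(image):
    """Slide a 512x512 window with 50% overlap, keeping only tissue tiles."""
    tiles = []
    for y in range(0, image.shape[0] - TILE + 1, STRIDE):
        for x in range(0, image.shape[1] - TILE + 1, STRIDE):
            tile = image[y:y + TILE, x:x + TILE]
            if (breast_ratio(tile) >= MIN_BREAST_RATIO
                    and frequency_energy(tile) >= MIN_FREQ_ENERGY):
                tiles.append(tile)
    return tiles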

Usage

Loading the Model

import torch
from train_byol_mammo import MammogramBYOL
from torchvision import models
import torch.nn as nn

# Load the pre-trained BYOL model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create ResNet50 backbone
resnet = models.resnet50(weights=None)
backbone = nn.Sequential(*list(resnet.children())[:-1])

# Initialize BYOL model
model = MammogramBYOL(
    backbone=backbone,
    input_dim=2048,
    hidden_dim=4096,
    proj_dim=256
).to(device)

# Load pre-trained weights
checkpoint = torch.load('mammogram_byol_best.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

Feature Extraction

# Extract features from mammogram tiles
def extract_features(image_tensor):
    with torch.no_grad():
        features = model.get_features(image_tensor)
    return features

# Example usage
image = torch.randn(1, 3, 512, 512).to(device)  # Example input
features = extract_features(image)  # Returns 2048-dim features

Classification Fine-tuning

Use the provided train_classification.py script for downstream classification tasks:

python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles/ \
    --output_dir ./classification_results/
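
Under the hood, downstream fine-tuning amounts to attaching a classification head to the pre-trained encoder. A minimal sketch, assuming a single linear head over the 2048-dim features (train_classification.py may use a different head or an unfreezing schedule):

import torch.nn as nn

class MammogramClassifier(nn.Module):
    def __init__(self, byol_model, num_classes=2, freeze_backbone=True):
        super().__init__()
        self.encoder = byol_model          # pre-trained BYOL model loaded above
        if freeze_backbone:
            for p in self.encoder.parameters():
                p.requires_grad = False    # linear probing: train only the head
        self.head = nn.Linear(2048, num_classes)

    def forward(self, x):
        features = self.encoder.get_features(x)   # (B, 2048)
        return self.head(features.flatten(1))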

File Structure

BYOL_Mammogram/
├── mammogram_byol_best.pth          # Best BYOL checkpoint
├── mammogram_byol_final.pth         # Final BYOL checkpoint
├── train_byol_mammo.py              # BYOL pre-training script
├── train_classification.py          # Classification fine-tuning
├── inference_classification.py      # Inference script
├── classification_config.json       # Classification configuration
├── CLASSIFICATION_GUIDE.md          # Detailed training guide
└── requirements.txt                 # Dependencies

Performance

Pre-training Results

  • Dataset: High-quality breast tissue tiles with aggressive background rejection
  • Efficiency: ~15-20% tile selection rate (quality over quantity)
  • Background contamination: 0% (eliminated during extraction)
  • Training run: 100 epochs on a single A100 GPU

Key Metrics

  • Average breast tissue per tile: >15%
  • Average frequency energy: >0.03
  • Tile quality: Medical-grade with preserved anatomical details

Technical Specifications

Hardware Requirements

  • GPU: A100 (40GB/80GB) recommended
  • Memory: 35-40GB GPU memory for training
  • CPU: 16+ cores for data loading

Dependencies

torch>=2.0.0
torchvision>=0.15.0
lightly>=1.4.0
opencv-python>=4.8.0
scipy>=1.10.0
numpy>=1.24.0
Pillow>=9.5.0
tqdm>=4.65.0
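
All of the above are pinned in requirements.txt, so setup is a single command:

pip install -r requirements.txt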

Medical Imaging Considerations

Data Safety

  • Augmentation strategy: Preserves medical accuracy while providing diversity
  • Background rejection: Prevents training on non-diagnostic regions
  • Tissue focus: Emphasizes clinically relevant breast tissue areas
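
As an illustration of such an augmentation pipeline, here is a hedged sketch using torchvision; the exact angles, jitter strengths, and perspective parameters used in training are defined in train_byol_mammo.py, and the values below are assumptions:

from torchvision import transforms

# Conservative augmentations: flips and small rotations keep anatomy intact,
# while mild photometric jitter adds diversity (parameter values illustrative).
medical_safe_augment = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.RandomPerspective(distortion_scale=0.1, p=0.3),
    transforms.ToTensor(),
])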

Clinical Applications

  • Screening support: Potential for computer-aided detection
  • Research tool: Feature extraction for medical AI research
  • Educational: Understanding mammogram image analysis

Limitations

  • Domain specific: Trained specifically on mammogram data
  • Preprocessing required: Inputs must be segmented and tiled into breast-tissue patches before use
  • Computationally intensive: Large model requiring substantial GPU resources
  • Medical supervision: Requires clinical validation for any medical application

Citation

If you use this model in your research, please cite:

@misc{byol_mammogram_2024,
  title={BYOL Mammogram Classification Model},
  author={PranayPalem},
  year={2024},
  url={https://huggingface.co/PranayPalem/BYOL_Mammogram}
}

License

MIT License - See LICENSE file for details.

Disclaimer

This model is for research purposes only and should not be used for clinical diagnosis without proper validation and medical supervision. Always consult healthcare professionals for medical decisions.
