🎯 Classification Training Guide

Complete guide for fine-tuning the BYOL pre-trained model for multi-label classification.

πŸ“‹ Overview

After BYOL pre-training completes, you can fine-tune the model for classification using the train_classification.py script. This approach:

  1. Loads the BYOL checkpoint with learned representations
  2. Freezes the backbone initially (optional) to prevent overwriting good features
  3. Fine-tunes the classification head with a higher learning rate
  4. Gradually unfreezes the backbone for end-to-end fine-tuning
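
The two-learning-rate, freeze-then-unfreeze setup can be reproduced in a few lines. The sketch below is illustrative only and assumes the model exposes backbone and classifier submodules; the actual attribute names used in train_classification.py may differ:

import torch

def build_optimizer(model, lr_backbone=1e-5, lr_head=1e-3, weight_decay=1e-4):
    # Two parameter groups: a small LR for the pre-trained backbone,
    # a larger LR for the randomly initialized classification head.
    return torch.optim.AdamW([
        {"params": model.backbone.parameters(), "lr": lr_backbone},
        {"params": model.classifier.parameters(), "lr": lr_head},
    ], weight_decay=weight_decay)

def set_backbone_frozen(model, frozen):
    # Freeze (or unfreeze) the backbone so early epochs train only the head.
    for p in model.backbone.parameters():
        p.requires_grad = not frozen

# In the epoch loop: frozen for the first freeze_backbone_epochs, then unfrozen.
# set_backbone_frozen(model, frozen=(epoch < freeze_backbone_epochs))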

πŸ—‚οΈ Data Preparation

CSV Format

Create train/validation CSV files with this format:

tile_path,mass,calcification,architectural_distortion,asymmetry,normal,benign,malignant,birads_2,birads_3,birads_4
patient1_tile_001.png,1,0,0,0,0,1,0,0,1,0
patient1_tile_002.png,0,1,0,0,0,0,1,0,0,1
patient2_tile_001.png,0,0,0,0,1,1,0,1,0,0
...

Requirements:

  • tile_path: Path to the tile image, relative to the tiles directory passed via --tiles_dir
  • Class columns: Binary labels (0/1) for each class
  • Multi-label support: Each image can have multiple classes set to 1
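
As a quick sanity check, the CSV can be loaded with pandas and turned into per-tile multi-label target vectors. This is just a sketch of the expected format, not the dataset class used by train_classification.py:

import pandas as pd
import torch

class_names = ["mass", "calcification", "architectural_distortion", "asymmetry",
               "normal", "benign", "malignant", "birads_2", "birads_3", "birads_4"]

df = pd.read_csv("train_labels.csv")

# Every tile path must be present and every class column must be binary.
assert df["tile_path"].notna().all()
assert df[class_names].isin([0, 1]).all().all()

# One float32 target vector per tile, as expected by BCE-style multi-label losses.
targets = torch.tensor(df[class_names].values, dtype=torch.float32)
print(targets.shape)  # (num_tiles, num_classes)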

Directory Structure

your_project/
β”œβ”€β”€ tiles/                    # Directory containing tile images
β”‚   β”œβ”€β”€ patient1_tile_001.png
β”‚   β”œβ”€β”€ patient1_tile_002.png
β”‚   └── ...
β”œβ”€β”€ train_labels.csv         # Training labels
β”œβ”€β”€ val_labels.csv          # Validation labels
└── mammogram_byol_best.pth # BYOL checkpoint

πŸš€ Quick Start

1. Basic Classification Training

python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles \
    --class_names mass calcification architectural_distortion asymmetry normal benign malignant birads_2 birads_3 birads_4 \
    --output_dir ./classification_results

2. With Custom Configuration

python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles \
    --class_names mass calcification normal \
    --config ./classification_config.json \
    --output_dir ./classification_results \
    --wandb_project my-mammogram-classification

3. Quick Testing (Limited Dataset)

python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles \
    --class_names mass calcification normal \
    --max_samples 1000 \
    --output_dir ./test_results

βš™οΈ Configuration Options

Key Parameters

Parameter              | Default | Description
-----------------------|---------|-------------------------------------------------------
batch_size             | 32      | Batch size for training
epochs                 | 50      | Number of training epochs
lr_backbone            | 1e-5    | Learning rate for the pre-trained backbone
lr_head                | 1e-3    | Learning rate for the classification head
freeze_backbone_epochs | 10      | Epochs to keep the backbone frozen (0 = never freeze)
label_smoothing        | 0.1     | Label smoothing for regularization
gradient_clip          | 1.0     | Gradient clipping max norm

Custom Configuration File

Create my_config.json:

{
  "batch_size": 64,
  "epochs": 100,
  "lr_backbone": 5e-6,
  "lr_head": 2e-3,
  "freeze_backbone_epochs": 20,
  "label_smoothing": 0.2,
  "weight_decay": 1e-3
}

πŸ“Š Expected Training Process

Phase 1: Backbone Frozen (Epochs 1-10)

🧊 Epoch 1: Backbone frozen (training only classification head)
Epoch   1/50:
  Train Loss: 0.6234
  Val Loss:   0.5891
  Mean AUC:   0.7123
  Mean AP:    0.6894
  Exact Match: 0.4512
  βœ… New best model saved (AUC: 0.7123)

Phase 2: End-to-End Fine-tuning (Epochs 11-50)

Epoch  15/50:
  Train Loss: 0.3456
  Val Loss:   0.3891
  Mean AUC:   0.8567
  Mean AP:    0.8234
  Exact Match: 0.6789
  βœ… New best model saved (AUC: 0.8567)

πŸ” Making Predictions

Single Image Inference

python inference_classification.py \
    --model_path ./classification_results/best_classification_model.pth \
    --image_path ./test_image.png \
    --threshold 0.5

Output:

πŸ“Έ Image 1: test_image.png
πŸ† Top prediction: mass (0.847)
πŸ“Š All probabilities:
   βœ… mass              : 0.847
   ❌ calcification     : 0.234
   ❌ normal            : 0.123
   ❌ architectural_distortion: 0.089

Batch Inference

python inference_classification.py \
    --model_path ./classification_results/best_classification_model.pth \
    --images_dir ./test_images \
    --output_json ./predictions.json \
    --batch_size 64

Programmatic Usage

import torch
from PIL import Image
from train_byol_mammo import MammogramBYOL
from inference_classification import load_classification_model, create_inference_transform

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, class_names, config = load_classification_model(
    "./classification_results/best_classification_model.pth", device
)

# Make prediction
transform = create_inference_transform()
image = Image.open("test.png").convert('RGB')
input_tensor = transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    logits = model.classify(input_tensor)
    probabilities = torch.sigmoid(logits).cpu().numpy()[0]

# Get results
for i, class_name in enumerate(class_names):
    print(f"{class_name}: {probabilities[i]:.3f}")

πŸ“ˆ Monitoring Training

Weights & Biases Integration

The script automatically logs to W&B:

  • Training/validation loss curves
  • Per-class AUC and Average Precision
  • Learning rate schedules
  • Model hyperparameters

Metrics Explained

  • AUC (Area Under the ROC Curve): Measures ranking quality (0-1, higher is better)
  • AP (Average Precision): Summarizes the precision-recall curve (0-1, higher is better)
  • Exact Match Accuracy: Percentage of samples where ALL labels are predicted correctly
  • Per-Class Accuracy: Binary accuracy for each individual class
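
All four metrics can be computed from the sigmoid probabilities with scikit-learn. This is a minimal reference sketch; the training script may compute them slightly differently:

import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

def multilabel_metrics(y_true, y_prob, threshold=0.5):
    # y_true, y_prob: arrays of shape (num_samples, num_classes)
    auc = roc_auc_score(y_true, y_prob, average=None)            # per-class AUC
    ap = average_precision_score(y_true, y_prob, average=None)   # per-class AP
    y_pred = (y_prob >= threshold).astype(int)
    exact_match = (y_pred == y_true).all(axis=1).mean()          # all labels correct
    per_class_acc = (y_pred == y_true).mean(axis=0)              # binary accuracy per class
    return {"mean_auc": float(np.mean(auc)),
            "mean_ap": float(np.mean(ap)),
            "exact_match": float(exact_match),
            "per_class_accuracy": per_class_acc}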

πŸ’Ύ Output Files

Training creates:

classification_results/
β”œβ”€β”€ best_classification_model.pth      # Best model by validation AUC
β”œβ”€β”€ final_classification_model.pth     # Final model after all epochs
β”œβ”€β”€ classification_epoch_10.pth        # Periodic checkpoints
β”œβ”€β”€ classification_epoch_20.pth
└── ...

Each checkpoint contains:

  • Model state dict
  • Optimizer state
  • Training configuration
  • Class names
  • Validation metrics
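
You can inspect any checkpoint by loading it and printing its keys. The key names in the commented lines below are assumptions; check them against what train_classification.py actually saves:

import torch

ckpt = torch.load("./classification_results/best_classification_model.pth", map_location="cpu")
print(list(ckpt.keys()))

# Hypothetical key names -- adjust to the actual checkpoint contents:
# model.load_state_dict(ckpt["model_state_dict"])
# print(ckpt["class_names"], ckpt["val_metrics"])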

πŸ› οΈ Advanced Usage

Custom Loss Functions

For imbalanced datasets, modify the loss function:

import torch
import torch.nn as nn

# Calculate positive weights for each class (df holds the training label DataFrame)
pos_counts = df[class_names].sum()
neg_counts = len(df) - pos_counts
pos_weight = torch.tensor((neg_counts / pos_counts).values, dtype=torch.float32).to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

Transfer Learning Strategies

  1. Conservative: Freeze backbone for many epochs, low backbone LR
    • freeze_backbone_epochs = 20
    • lr_backbone = 1e-6
  2. Aggressive: Unfreeze early, higher backbone LR
    • freeze_backbone_epochs = 5
    • lr_backbone = 1e-4
  3. Progressive: Gradually unfreeze layers (requires code modification; see the sketch below)
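
Progressive unfreezing is not built into the script. A minimal sketch for a torchvision ResNet-style backbone (the layer names and the model.backbone attribute are assumptions):

# Unfreeze deeper backbone stages first, shallower ones later.
unfreeze_schedule = {10: ["layer4"], 20: ["layer3"], 30: ["layer2", "layer1"]}

def progressive_unfreeze(model, epoch):
    for start_epoch, layer_names in unfreeze_schedule.items():
        if epoch >= start_epoch:
            for name in layer_names:
                for p in getattr(model.backbone, name).parameters():
                    p.requires_grad = True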

Multi-GPU Training

For multiple GPUs, wrap the model:

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

⚠️ Troubleshooting

Common Issues

Low Validation Performance:

  • Increase freeze_backbone_epochs to 15-20
  • Reduce lr_backbone to 5e-6 or 1e-6
  • Check for data leakage between train/val sets

Overfitting:

  • Increase label_smoothing to 0.2-0.3
  • Add more dropout (modify model architecture)
  • Reduce learning rates
  • Use early stopping
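
If you want early stopping, a simple patience counter on the validation AUC is enough. A minimal sketch, not part of the script:

class EarlyStopping:
    # Stop training when the monitored metric has not improved for `patience` epochs.
    def __init__(self, patience=10):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, metric):
        if metric > self.best:
            self.best, self.bad_epochs = metric, 0
            return False                          # improved, keep training
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience   # True => stop

# In the epoch loop: stopper = EarlyStopping(patience=10)
# if stopper.step(val_auc): break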

Memory Issues:

  • Reduce batch_size to 16 or 8
  • Reduce num_workers to 4
  • Use gradient checkpointing (requires code modification)

Class Imbalance:

  • Use pos_weight in loss function
  • Focus on per-class AUC rather than accuracy
  • Consider focal loss for extreme imbalance
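
If you want to try focal loss, a common binary focal-loss variant for multi-label logits is sketched below (the alpha/gamma values are typical defaults, not tuned for mammograms):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    # Down-weights easy examples so training focuses on hard, often minority-class, samples.
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, logits, targets):
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = torch.exp(-bce)  # probability assigned to the true label
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()

# criterion = FocalLoss()  # drop-in replacement for nn.BCEWithLogitsLoss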

🎯 Best Practices

  1. Start Conservative: Use default settings first
  2. Monitor Per-Class Metrics: Some classes may need special attention
  3. Validate Data: Ensure no train/val overlap
  4. Checkpoint Often: Training can be interrupted
  5. Use Multiple Runs: Average results across random seeds
  6. Test Thoroughly: Use held-out test set for final evaluation

πŸ“š Complete Example

Here's a full workflow from BYOL training to classification:

# 1. Train BYOL (takes roughly 4-5 hours on an A100)
python train_byol_mammo.py

# 2. Prepare classification data (create CSVs with labels)
# ... prepare train_labels.csv and val_labels.csv ...

# 3. Fine-tune for classification (1-2 hours)
python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles \
    --class_names mass calcification architectural_distortion asymmetry normal \
    --output_dir ./classification_results

# 4. Run inference on new images
python inference_classification.py \
    --model_path ./classification_results/best_classification_model.pth \
    --images_dir ./new_patient_tiles \
    --output_json ./patient_predictions.json

This gives you a complete pipeline from self-supervised pre-training to production-ready classification! πŸš€