# 🎯 Classification Training Guide
Complete guide for fine-tuning the BYOL pre-trained model for multi-label classification.
## 📋 Overview

After BYOL pre-training completes, you can fine-tune the model for classification using the `train_classification.py` script. This approach:
- Loads the BYOL checkpoint with learned representations
- Freezes the backbone initially (optional) to prevent overwriting good features
- Fine-tunes the classification head with a higher learning rate
- Gradually unfreezes the backbone for end-to-end fine-tuning
## 🗂️ Data Preparation

### CSV Format
Create train/validation CSV files with this format:
```csv
tile_path,mass,calcification,architectural_distortion,asymmetry,normal,benign,malignant,birads_2,birads_3,birads_4
patient1_tile_001.png,1,0,0,0,0,1,0,0,1,0
patient1_tile_002.png,0,1,0,0,0,0,1,0,0,1
patient2_tile_001.png,0,0,0,0,1,1,0,1,0,0
...
```
Requirements:
- `tile_path`: Relative path to the tile image
- Class columns: Binary labels (0/1) for each class
- Multi-label support: Each image can have multiple classes set to 1
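
Before launching a run, it is worth sanity-checking the CSVs. A minimal sketch with pandas, using the file and directory names from this guide:

```python
import os

import pandas as pd

# Sanity-check the label CSV before training (paths follow this guide's layout).
df = pd.read_csv("train_labels.csv")
class_names = [c for c in df.columns if c != "tile_path"]

# Labels must be binary 0/1 for multi-label training.
assert df[class_names].isin([0, 1]).all().all(), "non-binary label found"

# Every tile path should resolve inside the tiles directory.
missing = [p for p in df["tile_path"] if not os.path.exists(os.path.join("tiles", p))]
print(f"{len(missing)} missing tiles")
print(df[class_names].sum().rename("positives per class"))
```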
### Directory Structure
```
your_project/
├── tiles/                      # Directory containing tile images
│   ├── patient1_tile_001.png
│   ├── patient1_tile_002.png
│   └── ...
├── train_labels.csv            # Training labels
├── val_labels.csv              # Validation labels
└── mammogram_byol_best.pth     # BYOL checkpoint
```
## 🚀 Quick Start

### 1. Basic Classification Training

```bash
python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles \
    --class_names mass calcification architectural_distortion asymmetry normal benign malignant birads_2 birads_3 birads_4 \
    --output_dir ./classification_results
```
### 2. With Custom Configuration

```bash
python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles \
    --class_names mass calcification normal \
    --config ./classification_config.json \
    --output_dir ./classification_results \
    --wandb_project my-mammogram-classification
```
### 3. Quick Testing (Limited Dataset)

```bash
python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles \
    --class_names mass calcification normal \
    --max_samples 1000 \
    --output_dir ./test_results
```
## ⚙️ Configuration Options

### Key Parameters

| Parameter | Default | Description |
|---|---|---|
| `batch_size` | 32 | Batch size for training |
| `epochs` | 50 | Number of training epochs |
| `lr_backbone` | 1e-5 | Learning rate for the pre-trained backbone |
| `lr_head` | 1e-3 | Learning rate for the classification head |
| `freeze_backbone_epochs` | 10 | Epochs to keep the backbone frozen (0 = never freeze) |
| `label_smoothing` | 0.1 | Label smoothing for regularization |
| `gradient_clip` | 1.0 | Max norm for gradient clipping |
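
Note that `BCEWithLogitsLoss` has no built-in `label_smoothing` argument, so for multi-label targets smoothing is typically applied to the labels themselves. A sketch of that idea (the script may implement it differently):

```python
import torch
import torch.nn as nn

def smooth_targets(targets: torch.Tensor, smoothing: float = 0.1) -> torch.Tensor:
    # Pull hard 0/1 labels toward 0.5: 1 -> 1 - smoothing/2, 0 -> smoothing/2.
    return targets * (1.0 - smoothing) + 0.5 * smoothing

criterion = nn.BCEWithLogitsLoss()
targets = torch.tensor([[1.0, 0.0, 1.0]])
logits = torch.randn(1, 3)
loss = criterion(logits, smooth_targets(targets, smoothing=0.1))
```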
### Custom Configuration File

Create `my_config.json`:
```json
{
  "batch_size": 64,
  "epochs": 100,
  "lr_backbone": 5e-6,
  "lr_head": 2e-3,
  "freeze_backbone_epochs": 20,
  "label_smoothing": 0.2,
  "weight_decay": 1e-3
}
```
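
These values are expected to override the script's defaults; a hypothetical sketch of that merge pattern (the actual logic lives in `train_classification.py`):

```python
import json

defaults = {"batch_size": 32, "epochs": 50, "lr_backbone": 1e-5, "lr_head": 1e-3}

with open("my_config.json") as f:
    config = {**defaults, **json.load(f)}  # file values override the defaults
```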
## 📈 Expected Training Process

### Phase 1: Backbone Frozen (Epochs 1-10)

```
🧊 Epoch 1: Backbone frozen (training only classification head)
Epoch 1/50:
  Train Loss: 0.6234
  Val Loss: 0.5891
  Mean AUC: 0.7123
  Mean AP: 0.6894
  Exact Match: 0.4512
✅ New best model saved (AUC: 0.7123)
```
### Phase 2: End-to-End Fine-tuning (Epochs 11-50)

```
Epoch 15/50:
  Train Loss: 0.3456
  Val Loss: 0.3891
  Mean AUC: 0.8567
  Mean AP: 0.8234
  Exact Match: 0.6789
✅ New best model saved (AUC: 0.8567)
```
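
Under the hood, this two-phase schedule amounts to toggling `requires_grad` on the backbone at a given epoch. A minimal sketch of the idea (the class and attribute names here are illustrative, not the script's actual API):

```python
import torch.nn as nn

# Toy stand-in for the classifier; the real model wraps the BYOL backbone.
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
        self.head = nn.Linear(8, 10)

def set_backbone_frozen(model: nn.Module, frozen: bool) -> None:
    # Freezing blocks gradient updates to the pre-trained BYOL features.
    for p in model.backbone.parameters():
        p.requires_grad = not frozen

model, freeze_backbone_epochs, epochs = Classifier(), 10, 50
for epoch in range(epochs):
    set_backbone_frozen(model, frozen=epoch < freeze_backbone_epochs)
    # ... run one training epoch, validate, save checkpoints ...
```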
## 🔍 Making Predictions

### Single Image Inference

```bash
python inference_classification.py \
    --model_path ./classification_results/best_classification_model.pth \
    --image_path ./test_image.png \
    --threshold 0.5
```
Output:

```
📸 Image 1: test_image.png
🔝 Top prediction: mass (0.847)
📊 All probabilities:
  ✅ mass                    : 0.847
  ❌ calcification           : 0.234
  ❌ normal                  : 0.123
  ❌ architectural_distortion: 0.089
```
### Batch Inference

```bash
python inference_classification.py \
    --model_path ./classification_results/best_classification_model.pth \
    --images_dir ./test_images \
    --output_json ./predictions.json \
    --batch_size 64
```
### Programmatic Usage

```python
import torch
from PIL import Image

from train_byol_mammo import MammogramBYOL  # imported so the checkpoint's model class resolves
from inference_classification import load_classification_model, create_inference_transform

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, class_names, config = load_classification_model(
    "./classification_results/best_classification_model.pth", device
)
model.eval()

# Make prediction
transform = create_inference_transform()
image = Image.open("test.png").convert("RGB")
input_tensor = transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    logits = model.classify(input_tensor)
    probabilities = torch.sigmoid(logits).cpu().numpy()[0]

# Get results
for class_name, prob in zip(class_names, probabilities):
    print(f"{class_name}: {prob:.3f}")
```
## 📊 Monitoring Training

### Weights & Biases Integration
The script automatically logs to W&B:
- Training/validation loss curves
- Per-class AUC and Average Precision
- Learning rate schedules
- Model hyperparameters
### Metrics Explained

- AUC (Area Under the ROC Curve): Measures ranking quality (0-1, higher is better)
- AP (Average Precision): Summarizes the precision-recall curve (0-1, higher is better)
- Exact Match Accuracy: Fraction of samples where ALL labels are predicted correctly
- Per-Class Accuracy: Binary accuracy for each individual class
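
The same numbers can be reproduced offline with scikit-learn; a minimal sketch where `y_true` holds the 0/1 labels and `y_prob` the sigmoid outputs, both shaped `(n_samples, n_classes)`:

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score

y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
y_prob = np.array([[0.9, 0.2, 0.7], [0.3, 0.8, 0.1], [0.6, 0.7, 0.4]])

mean_auc = roc_auc_score(y_true, y_prob, average="macro")
mean_ap = average_precision_score(y_true, y_prob, average="macro")
exact_match = ((y_prob >= 0.5).astype(int) == y_true).all(axis=1).mean()
print(f"Mean AUC: {mean_auc:.4f}  Mean AP: {mean_ap:.4f}  Exact Match: {exact_match:.4f}")
```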
## 💾 Output Files
Training creates:
```
classification_results/
├── best_classification_model.pth    # Best model by validation AUC
├── final_classification_model.pth   # Final model after all epochs
├── classification_epoch_10.pth      # Periodic checkpoints
├── classification_epoch_20.pth
└── ...
```
Each checkpoint contains:
- Model state dict
- Optimizer state
- Training configuration
- Class names
- Validation metrics
## 🛠️ Advanced Usage

### Custom Loss Functions
For imbalanced datasets, modify the loss function:
```python
import pandas as pd
import torch
import torch.nn as nn

df = pd.read_csv("train_labels.csv")  # training labels, one 0/1 column per class

# Calculate positive weights for each class: up-weight rare positives.
pos_counts = df[class_names].sum()
neg_counts = len(df) - pos_counts
pos_weight = torch.tensor((neg_counts / pos_counts).values, dtype=torch.float32).to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```
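
For classes with very few positives the resulting weights can be extreme; clamping `pos_weight` (for example to a maximum of 20-30) can help keep training stable.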
### Transfer Learning Strategies

Conservative: Freeze the backbone for many epochs, low backbone LR

```
freeze_backbone_epochs = 20
lr_backbone = 1e-6
```

Aggressive: Unfreeze early, higher backbone LR

```
freeze_backbone_epochs = 5
lr_backbone = 1e-4
```

Progressive: Gradually unfreeze layers (requires a code modification; see the sketch below)
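
A possible progressive-unfreezing sketch for a ResNet-style backbone (layer names assume a torchvision ResNet; the stem and head are left out for brevity):

```python
import torchvision.models as models

backbone = models.resnet50()
# Unfreeze from the deepest stage outward: layer4 first, layer1 last.
stages = [backbone.layer4, backbone.layer3, backbone.layer2, backbone.layer1]

def unfreeze_progressively(epoch: int, epochs_per_stage: int = 5) -> None:
    # Stage i becomes trainable once epoch >= (i + 1) * epochs_per_stage.
    for i, stage in enumerate(stages):
        frozen = epoch < (i + 1) * epochs_per_stage
        for p in stage.parameters():
            p.requires_grad = not frozen
```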
### Multi-GPU Training

For multiple GPUs, wrap the model:

```python
import torch
import torch.nn as nn

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
```
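
For larger jobs, `torch.nn.parallel.DistributedDataParallel` generally scales better than `DataParallel`, at the cost of launching one process per GPU.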
## ⚠️ Troubleshooting

### Common Issues
Low Validation Performance:
- Increase `freeze_backbone_epochs` to 15-20
- Reduce `lr_backbone` to 5e-6 or 1e-6
- Check for data leakage between train/val sets
Overfitting:
- Increase `label_smoothing` to 0.2-0.3
- Add more dropout (modify the model architecture)
- Reduce learning rates
- Use early stopping
Memory Issues:
- Reduce `batch_size` to 16 or 8
- Reduce `num_workers` to 4
- Use gradient checkpointing (requires a code modification; see the sketch below)
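
Gradient checkpointing trades compute for memory by recomputing activations during the backward pass. A sketch with `torch.utils.checkpoint`, assuming the backbone is (or can be wrapped as) an `nn.Sequential`:

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

# Re-run chunks of the backbone forward pass during backprop to save memory.
def forward_checkpointed(backbone: torch.nn.Sequential, x: torch.Tensor) -> torch.Tensor:
    return checkpoint_sequential(backbone, segments=4, input=x, use_reentrant=False)
```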
Class Imbalance:
- Use `pos_weight` in the loss function
- Focus on per-class AUC rather than accuracy
- Consider focal loss for extreme imbalance (see the sketch below)
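
A minimal binary focal loss that can stand in for `BCEWithLogitsLoss` on multi-label logits (illustrative; `gamma` and `alpha` usually need tuning):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, targets: torch.Tensor,
               gamma: float = 2.0, alpha: float = 0.25) -> torch.Tensor:
    # Standard BCE per element (targets must be float 0/1), then
    # down-weight easy examples by (1 - p_t)^gamma.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # probability the model assigns to the true label
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()
```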
## 🎯 Best Practices
- Start Conservative: Use default settings first
- Monitor Per-Class Metrics: Some classes may need special attention
- Validate Data: Ensure no train/val overlap
- Checkpoint Often: Training can be interrupted
- Use Multiple Runs: Average results across random seeds
- Test Thoroughly: Use held-out test set for final evaluation
## 📝 Complete Example
Here's a full workflow from BYOL training to classification:
```bash
# 1. Train BYOL (takes roughly 4-5 hours on an A100)
python train_byol_mammo.py

# 2. Prepare classification data (create CSVs with labels)
# ... prepare train_labels.csv and val_labels.csv ...

# 3. Fine-tune for classification (1-2 hours)
python train_classification.py \
    --byol_checkpoint ./mammogram_byol_best.pth \
    --train_csv ./train_labels.csv \
    --val_csv ./val_labels.csv \
    --tiles_dir ./tiles \
    --class_names mass calcification architectural_distortion asymmetry normal \
    --output_dir ./classification_results

# 4. Run inference on new images
python inference_classification.py \
    --model_path ./classification_results/best_classification_model.pth \
    --images_dir ./new_patient_tiles \
    --output_json ./patient_predictions.json
```
This gives you a complete pipeline from self-supervised pre-training to production-ready classification! 🎉