|
|
|
|
|
""" |
|
|
train_classification.py |
|
|
|
|
|
Fine-tune the BYOL pre-trained model for multi-label classification on mammogram tiles. |
|
|
This script loads the BYOL checkpoint and trains only the classification head while |
|
|
optionally fine-tuning the backbone with a lower learning rate. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torch.cuda.amp import autocast, GradScaler |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from PIL import Image |
|
|
import torchvision.transforms as T |
|
|
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score |
|
|
from tqdm import tqdm |
|
|
import wandb |
|
|
import argparse |
|
|
from typing import Dict, List, Tuple |
|
|
import json |
|
|
|
|
|
|
|
|
from train_byol_mammo import MammogramBYOL |
|
|
|
|
|
|
|
|
# Run on GPU when available; tensors and model weights are moved here explicitly.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Side length (pixels) every tile is resized to before entering the model.
TILE_SIZE = 512

# Hyperparameters used when no --config JSON is supplied; values loaded from
# --config are merged on top of these in main().
DEFAULT_CONFIG = {
    'batch_size': 32,
    'num_workers': 8,               # DataLoader worker processes
    'epochs': 50,
    'lr_backbone': 1e-5,            # low LR for the pre-trained backbone
    'lr_head': 1e-3,                # higher LR for the fresh classification head
    'weight_decay': 1e-4,
    'warmup_epochs': 5,             # NOTE(review): appears unused in this script — confirm
    'freeze_backbone_epochs': 10,   # train head only for the first N epochs
    'label_smoothing': 0.1,
    'dropout_rate': 0.3,            # NOTE(review): head hard-codes 0.3; key appears unread — confirm
    'gradient_clip': 1.0,           # max grad-norm applied each optimizer step
}
|
|
|
|
|
|
|
|
class MammogramClassificationDataset(Dataset):
    """Dataset for mammogram tile classification with multi-label support.

    Each row of the CSV names one tile image plus one 0/1 column per class;
    items are returned as (image, float32 multi-hot label vector).
    """

    def __init__(self, csv_path: str, tiles_dir: str, class_names: List[str],
                 transform=None, max_samples: int = None):
        """
        Args:
            csv_path: Path to CSV with columns ['tile_path', 'class1', 'class2', ...]
            tiles_dir: Directory containing tile images
            class_names: List of class names (e.g., ['mass', 'calcification', 'normal', etc.])
            transform: Image transformations
            max_samples: Limit dataset size for testing
        """
        self.tiles_dir = Path(tiles_dir)
        self.class_names = class_names
        self.num_classes = len(class_names)
        self.transform = transform

        frame = pd.read_csv(csv_path)
        if max_samples:
            frame = frame.head(max_samples)
        self.df = frame

        print(f"π Loaded {len(self.df)} samples for classification training")
        print(f"π·οΈ Classes: {class_names}")

        # Fail fast when the CSV schema does not match the requested classes.
        expected = ['tile_path'] + class_names
        missing_cols = [name for name in expected if name not in self.df.columns]
        if missing_cols:
            raise ValueError(f"Missing columns in CSV: {missing_cols}")

    def __len__(self):
        """Number of tiles listed in the (possibly truncated) CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return (image, labels) for row *idx*; labels is a float32 multi-hot vector."""
        record = self.df.iloc[idx]

        tile = Image.open(self.tiles_dir / record['tile_path']).convert('RGB')
        if self.transform is not None:
            tile = self.transform(tile)

        target = torch.tensor([record[name] for name in self.class_names],
                              dtype=torch.float32)
        return tile, target
|
|
|
|
|
|
|
|
def create_classification_transforms(tile_size: int, is_training: bool = True):
    """Build the torchvision pipeline for classification tiles.

    Training adds light geometric/photometric augmentation; both modes resize
    to ``tile_size`` and normalize channels to the [-1, 1] range.
    """
    # Shared tail: tensor conversion + symmetric normalization.
    tail = [
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]

    if is_training:
        head = [
            T.Resize((tile_size, tile_size)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.2),
            T.RandomRotation(degrees=10, fill=0),
            T.ColorJitter(brightness=0.2, contrast=0.2),
        ]
    else:
        head = [T.Resize((tile_size, tile_size))]

    return T.Compose(head + tail)
|
|
|
|
|
|
|
|
class ClassificationModel(nn.Module):
    """Classification model that wraps BYOL backbone with classification head.

    The BYOL model is kept whole (so its checkpointed weights load unchanged)
    and only its feature extractor is used; a small MLP head maps the 2048-d
    pooled ResNet-50 features to per-class logits.
    """

    def __init__(self, byol_model: "MammogramBYOL", num_classes: int,
                 hidden_dim: int = 2048, dropout_rate: float = 0.3):
        """
        Args:
            byol_model: Pre-trained BYOL wrapper exposing ``get_features``.
            num_classes: Number of output classes (one logit each).
            hidden_dim: Width of the hidden layer in the classification head.
            dropout_rate: Dropout probability in the head. Previously
                hard-coded to 0.3; now a parameter (default keeps old behavior)
                so config values such as DEFAULT_CONFIG['dropout_rate'] can be
                passed through.
        """
        super().__init__()
        self.byol_model = byol_model

        # 2048 matches the ResNet-50 pooled-feature width produced upstream.
        self.classification_head = nn.Sequential(
            nn.Linear(2048, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        """Return raw per-class logits (apply sigmoid outside for probabilities)."""
        features = self.byol_model.get_features(x)
        return self.classification_head(features)

    def get_features(self, x):
        """Expose the backbone features (e.g., for embedding analysis)."""
        return self.byol_model.get_features(x)
|
|
|
|
|
|
|
|
def load_byol_model(checkpoint_path: str, num_classes: int, device: torch.device):
    """Load BYOL pre-trained model and prepare for classification.

    Rebuilds the ResNet-50 BYOL wrapper, restores weights from the checkpoint,
    and attaches a fresh classification head.

    Args:
        checkpoint_path: Path to the BYOL .pth checkpoint (must contain
            'model_state_dict'; 'epoch' and 'loss' are optional metadata).
        num_classes: Number of output classes for the new head.
        device: Device the model is placed on.

    Returns:
        A ClassificationModel on ``device``.
    """
    print(f"π₯ Loading BYOL checkpoint: {checkpoint_path}")

    # NOTE(review): torch.load unpickles arbitrary objects — only point this
    # at trusted checkpoint files.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # ResNet-50 with the final fc layer stripped -> 2048-d pooled features.
    from torchvision import models
    resnet = models.resnet50(weights=None)
    backbone = nn.Sequential(*list(resnet.children())[:-1])

    byol_model = MammogramBYOL(
        backbone=backbone,
        input_dim=2048,
        hidden_dim=4096,
        proj_dim=256
    ).to(device)

    byol_model.load_state_dict(checkpoint['model_state_dict'])

    model = ClassificationModel(byol_model, num_classes).to(device)

    # Bug fix: the old code formatted checkpoint.get('loss', 'unknown') with
    # ':.4f', which raises ValueError whenever 'loss' is absent (strings do
    # not accept the 'f' format code). Format only when a number is present.
    loss = checkpoint.get('loss')
    loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else "unknown"

    print(f"β Loaded BYOL model from epoch {checkpoint.get('epoch', 'unknown')}")
    print(f"π BYOL training loss: {loss_str}")
    print(f"π― Added classification head: 2048 β {2048} β {num_classes}")

    return model
|
|
|
|
|
|
|
|
def calculate_metrics(predictions: np.ndarray, targets: np.ndarray,
                      class_names: List[str]) -> Dict[str, float]:
    """Calculate comprehensive metrics for multi-label classification.

    Computes per-class ROC-AUC, average precision and binary accuracy at a
    0.5 threshold, plus their means and the exact-match (subset) accuracy.
    Classes whose AUC/AP cannot be computed (e.g., a single label value in
    the batch) fall back to 0.0 for AUC/AP.
    """
    metrics: Dict[str, float] = {}

    # Hard predictions at the fixed 0.5 threshold.
    pred_binary = (predictions > 0.5).astype(int)

    for col, name in enumerate(class_names):
        hard_acc = accuracy_score(targets[:, col], pred_binary[:, col])
        try:
            metrics[f'auc_{name}'] = roc_auc_score(targets[:, col], predictions[:, col])
            metrics[f'ap_{name}'] = average_precision_score(targets[:, col], predictions[:, col])
            metrics[f'acc_{name}'] = hard_acc
        except ValueError:
            # Degenerate column (single class present): AUC/AP undefined.
            metrics[f'auc_{name}'] = 0.0
            metrics[f'ap_{name}'] = 0.0
            metrics[f'acc_{name}'] = hard_acc

    # Macro averages across classes.
    for prefix in ('auc', 'ap', 'acc'):
        metrics[f'mean_{prefix}'] = np.mean(
            [metrics[f'{prefix}_{name}'] for name in class_names]
        )

    # Fraction of samples where every class is predicted correctly.
    metrics['exact_match_acc'] = np.all(pred_binary == targets, axis=1).mean()

    return metrics
|
|
|
|
|
|
|
|
def train_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
                optimizer: optim.Optimizer, scaler: GradScaler, epoch: int,
                config: dict, freeze_backbone: bool = False) -> Dict[str, float]:
    """Train for one epoch.

    Runs a mixed-precision training pass with gradient clipping; optionally
    freezes the backbone so only the classification head receives gradients.

    Args:
        model: ClassificationModel (must expose ``byol_model.backbone``).
        dataloader: Training batches of (images, multi-hot labels).
        criterion: Loss on (logits, labels).
        optimizer: Optimizer over the model's parameters.
        scaler: AMP GradScaler shared across epochs.
        epoch: 1-based epoch number (progress-bar display only).
        config: Needs 'epochs' and 'gradient_clip'.
        freeze_backbone: When True, backbone params get requires_grad=False.

    Returns:
        Dict with 'train_loss' = mean loss over all batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = len(dataloader)

    # Toggle backbone trainability for this epoch.
    # NOTE(review): the momentum backbone is frozen here but never re-enabled
    # in the else-branch; presumably it is an EMA copy excluded from the
    # optimizer anyway — confirm against MammogramBYOL.
    if freeze_backbone:
        for param in model.byol_model.backbone.parameters():
            param.requires_grad = False
        for param in model.byol_model.backbone_momentum.parameters():
            param.requires_grad = False
    else:
        for param in model.byol_model.backbone.parameters():
            param.requires_grad = True

    pbar = tqdm(dataloader, desc=f"Epoch {epoch:3d}/{config['epochs']} [Train]",
                ncols=100, leave=False)

    for batch_idx, (images, labels) in enumerate(pbar):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()

        # Forward + loss under autocast (mixed precision).
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        # AMP sequence (order matters): scale+backward, then unscale so the
        # clip threshold applies to true gradient magnitudes, then step/update.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clip'])
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Avg': f'{total_loss/(batch_idx+1):.4f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}'
        })

    return {'train_loss': total_loss / num_batches}
|
|
|
|
|
|
|
|
def validate_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
                   class_names: List[str]) -> Dict[str, float]:
    """Validate for one epoch.

    Runs a no-grad mixed-precision pass, collects per-class sigmoid
    probabilities, and returns calculate_metrics() output plus 'val_loss'.
    """
    model.eval()
    running_loss = 0.0
    prob_batches = []
    target_batches = []

    with torch.no_grad():
        progress = tqdm(dataloader, desc="Validation", ncols=100, leave=False)

        for images, labels in progress:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            with autocast():
                logits = model(images)
                batch_loss = criterion(logits, labels)

            running_loss += batch_loss.item()

            # Independent per-class probabilities for the multi-label setup.
            prob_batches.append(torch.sigmoid(logits).cpu().numpy())
            target_batches.append(labels.cpu().numpy())

            progress.set_postfix({'Loss': f'{batch_loss.item():.4f}'})

    all_probs = np.concatenate(prob_batches, axis=0)
    all_labels = np.concatenate(target_batches, axis=0)

    metrics = calculate_metrics(all_probs, all_labels, class_names)
    metrics['val_loss'] = running_loss / len(dataloader)

    return metrics
|
|
|
|
|
|
|
|
class _SmoothedBCEWithLogitsLoss(nn.Module):
    """Binary cross-entropy with logits plus label smoothing.

    nn.BCEWithLogitsLoss has no ``label_smoothing`` argument (that feature
    belongs to CrossEntropyLoss), so smoothing is applied to the multi-hot
    targets instead: y -> y * (1 - eps) + 0.5 * eps.
    """

    def __init__(self, smoothing: float = 0.0, pos_weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    def forward(self, logits, targets):
        if self.smoothing > 0:
            targets = targets * (1.0 - self.smoothing) + 0.5 * self.smoothing
        return self.bce(logits, targets)


def _make_checkpoint(epoch, model, optimizer, scheduler, val_metrics, config, class_names):
    """Build the serializable training-state dict (was duplicated three times)."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_metrics': val_metrics,
        'config': config,
        'class_names': class_names
    }


def main():
    """CLI entry point: load the BYOL backbone and fine-tune the classifier."""
    parser = argparse.ArgumentParser(description='Fine-tune BYOL model for classification')
    parser.add_argument('--byol_checkpoint', type=str, required=True,
                        help='Path to BYOL checkpoint (.pth file)')
    parser.add_argument('--train_csv', type=str, required=True,
                        help='Path to training CSV file')
    parser.add_argument('--val_csv', type=str, required=True,
                        help='Path to validation CSV file')
    parser.add_argument('--tiles_dir', type=str, required=True,
                        help='Directory containing tile images')
    parser.add_argument('--class_names', type=str, nargs='+', required=True,
                        help='List of class names (e.g., mass calcification normal)')
    parser.add_argument('--output_dir', type=str, default='./classification_results',
                        help='Output directory for checkpoints and logs')
    parser.add_argument('--config', type=str, default=None,
                        help='JSON config file (overrides defaults)')
    parser.add_argument('--wandb_project', type=str, default='mammogram-classification',
                        help='Weights & Biases project name')
    parser.add_argument('--max_samples', type=int, default=None,
                        help='Limit dataset size for testing')
    args = parser.parse_args()

    # Defaults first; any JSON config is merged on top.
    config = DEFAULT_CONFIG.copy()
    if args.config:
        with open(args.config, 'r') as f:
            config.update(json.load(f))

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # WandB is best-effort: training proceeds without logging if init fails.
    try:
        wandb.init(
            project=args.wandb_project,
            config=config,
            name=f"classification_fine_tune_{len(args.class_names)}classes"
        )
        wandb_enabled = True
    except Exception as e:
        print(f"β οΈ WandB not configured: {e}")
        wandb_enabled = False

    print("π¬ BYOL Classification Fine-Tuning")
    print("=" * 50)
    print(f"Device: {DEVICE}")
    print(f"Classes: {args.class_names}")
    print(f"Batch size: {config['batch_size']}")
    print(f"Epochs: {config['epochs']}")
    print(f"Output directory: {output_dir}")

    model = load_byol_model(args.byol_checkpoint, len(args.class_names), DEVICE)

    train_transform = create_classification_transforms(TILE_SIZE, is_training=True)
    val_transform = create_classification_transforms(TILE_SIZE, is_training=False)

    train_dataset = MammogramClassificationDataset(
        args.train_csv, args.tiles_dir, args.class_names,
        train_transform, max_samples=args.max_samples
    )
    val_dataset = MammogramClassificationDataset(
        args.val_csv, args.tiles_dir, args.class_names,
        val_transform, max_samples=args.max_samples
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True  # avoid a tiny trailing batch during training
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True
    )

    print(f"π Dataset sizes: Train={len(train_dataset)}, Val={len(val_dataset)}")

    # Bug fix: the original passed label_smoothing= to nn.BCEWithLogitsLoss,
    # which has no such parameter and raised TypeError at construction.
    # _SmoothedBCEWithLogitsLoss applies the smoothing to the targets instead.
    pos_weight = None  # hook for class-imbalance weighting; not computed yet
    criterion = _SmoothedBCEWithLogitsLoss(
        smoothing=config['label_smoothing'],
        pos_weight=pos_weight
    )

    # Discriminative learning rates: tiny for the pre-trained backbone,
    # larger for the randomly initialised head.
    backbone_params = list(model.byol_model.backbone.parameters())
    head_params = list(model.classification_head.parameters())
    optimizer = optim.AdamW([
        {'params': backbone_params, 'lr': config['lr_backbone']},
        {'params': head_params, 'lr': config['lr_head']}
    ], weight_decay=config['weight_decay'])

    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config['epochs'], eta_min=1e-6
    )

    scaler = GradScaler()

    best_metric = 0.0
    val_metrics = {}  # robustness: defined even if config['epochs'] == 0

    for epoch in range(1, config['epochs'] + 1):
        # Head-only training for the first freeze_backbone_epochs epochs.
        freeze_backbone = epoch <= config['freeze_backbone_epochs']
        if freeze_backbone:
            print(f"π§ Epoch {epoch}: Backbone frozen (training only classification head)")

        train_metrics = train_epoch(
            model, train_loader, criterion, optimizer, scaler,
            epoch, config, freeze_backbone
        )

        val_metrics = validate_epoch(model, val_loader, criterion, args.class_names)

        scheduler.step()

        print(f"\nEpoch {epoch:3d}/{config['epochs']}:")
        print(f"  Train Loss: {train_metrics['train_loss']:.4f}")
        print(f"  Val Loss: {val_metrics['val_loss']:.4f}")
        print(f"  Mean AUC: {val_metrics['mean_auc']:.4f}")
        print(f"  Mean AP: {val_metrics['mean_ap']:.4f}")
        print(f"  Exact Match: {val_metrics['exact_match_acc']:.4f}")

        if wandb_enabled:
            wandb.log({**train_metrics, **val_metrics, 'epoch': epoch})

        # Track the best model by mean validation ROC-AUC.
        current_metric = val_metrics['mean_auc']
        if current_metric > best_metric:
            best_metric = current_metric
            torch.save(
                _make_checkpoint(epoch, model, optimizer, scheduler,
                                 val_metrics, config, args.class_names),
                output_dir / 'best_classification_model.pth'
            )
            print(f"  β New best model saved (AUC: {best_metric:.4f})")

        # Periodic snapshot for resuming / post-hoc analysis.
        if epoch % 10 == 0:
            torch.save(
                _make_checkpoint(epoch, model, optimizer, scheduler,
                                 val_metrics, config, args.class_names),
                output_dir / f'classification_epoch_{epoch}.pth'
            )

    torch.save(
        _make_checkpoint(config['epochs'], model, optimizer, scheduler,
                         val_metrics, config, args.class_names),
        output_dir / 'final_classification_model.pth'
    )

    print(f"\nπ Classification training completed!")
    print(f"π Best validation AUC: {best_metric:.4f}")
    print(f"πΎ Models saved to: {output_dir}")

    if wandb_enabled:
        wandb.finish()


if __name__ == "__main__":
    main()