# tox21-classifier/train.py
import argparse
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
import numpy as np
import logging
from tqdm import tqdm
from data import load_data, Tox21Dataset, TASKS
from model import DMPNN, MolGraph, BatchMolGraph

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
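
# Expected repository layout, inferred from the imports and paths used in this
# script (the contents of the other files are assumptions, not verified here):
#   data.py                  - load_data(), Tox21Dataset, TASKS
#   model.py                 - DMPNN, MolGraph, BatchMolGraph
#   config/model_config.json - optional hyperparameter overrides (see load_settings)
#   checkpoints/             - per-seed weights written during training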


def load_settings():
    """
    Leaderboard Requirement: Load hyperparameters from a config folder.
    """
    config_path = os.path.join("config", "model_config.json")
    if os.path.exists(config_path):
        with open(config_path, "r") as f:
            logger.info(f"Loading configuration from {config_path}")
            return json.load(f)
    logger.warning("Config file not found in config/ folder. Using default values.")
    return {}
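
# For reference, a config/model_config.json consistent with the keys read via
# settings.get() in this script could look like the sketch below. The values
# shown are simply the in-code defaults, not tuned leaderboard settings:
#
#   {
#       "hidden_size": 300,
#       "depth": 3,
#       "learning_rate": 0.001,
#       "epochs": 30,
#       "batch_size": 50,
#       "ensemble_seeds": [42, 43, 44, 45, 46]
#   }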


def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def collate_fn(batch):
    """Build a single BatchMolGraph from the SMILES in a batch and stack labels/features."""
    smiles = [item['smiles'] for item in batch]
    labels = torch.stack([item['labels'] for item in batch])
    global_features = torch.stack([item['global_features'] for item in batch])
    mol_graphs = [MolGraph(s) for s in smiles]
    batch_graph = BatchMolGraph(mol_graphs)
    return batch_graph, global_features, labels


def calculate_pos_weights(train_data):
    """Per-task positive-class weight = #negatives / #positives, ignoring missing labels (-1).
    For example, a task with 5,700 negatives and 300 positives gets a weight of 19."""
    labels = torch.stack([item['labels'] for item in train_data])
    weights = []
    for i in range(labels.shape[1]):
        task_labels = labels[:, i]
        valid_labels = task_labels[task_labels != -1]
        num_pos = (valid_labels == 1).sum().item()
        num_neg = (valid_labels == 0).sum().item()
        weight = num_neg / max(num_pos, 1)
        weights.append(weight)
    return torch.tensor(weights, dtype=torch.float32).to(DEVICE)


def evaluate(model, loader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch_graph, global_feats, labels in loader:
            logits = model(batch_graph, global_feats.to(DEVICE))
            preds = torch.sigmoid(logits)
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.numpy())
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    aucs = []
    for i in range(len(TASKS)):
        # Score a task only on molecules with a known label, and only if both classes are present
        valid_idx = all_labels[:, i] != -1
        task_labels = all_labels[valid_idx, i]
        if 0 < task_labels.sum() < len(task_labels):
            aucs.append(roc_auc_score(task_labels, all_preds[valid_idx, i]))
    return np.mean(aucs) if aucs else 0.5


def train_model(train_loader, val_loader, pos_weights, settings, seed, global_feats_size):
    set_seed(seed)
    model_path = f"checkpoints/model_seed_{seed}.pt"
    os.makedirs("checkpoints", exist_ok=True)

    # Hyperparams from config
    hidden_size = settings.get("hidden_size", 300)
    depth = settings.get("depth", 3)
    lr = settings.get("learning_rate", 1e-3)
    epochs = settings.get("epochs", 30)

    model = DMPNN(
        hidden_size=hidden_size,
        depth=depth,
        global_feats_size=global_feats_size,
        n_tasks=len(TASKS)
    ).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights, reduction='none')

    best_val_auc = 0
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch_graph, global_feats, labels in train_loader:
            optimizer.zero_grad()
            logits = model(batch_graph, global_feats.to(DEVICE))
            # Mask missing values (-1) so they contribute nothing to the loss
            mask = (labels != -1).float().to(DEVICE)
            labels_fixed = torch.clamp(labels, min=0).to(DEVICE)
            loss = criterion(logits, labels_fixed)
            loss = (loss * mask).sum() / mask.sum().clamp(min=1)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        val_auc = evaluate(model, val_loader)
        logger.info(f"Seed {seed} | Epoch {epoch + 1}/{epochs} | "
                    f"Train loss: {train_loss / max(len(train_loader), 1):.4f} | Val AUC: {val_auc:.4f}")
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            torch.save(model.state_dict(), model_path)
    return best_val_auc


def train_ensemble():
    settings = load_settings()
    logger.info("Loading data from HF Hub...")
    train_data, val_data, _ = load_data()
    global_feats_size = len(train_data[0]['global_features'])
    pos_weights = calculate_pos_weights(train_data)

    batch_size = settings.get("batch_size", 50)
    seeds = settings.get("ensemble_seeds", [42, 43, 44, 45, 46])

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    for seed in seeds:
        logger.info(f"Starting training for seed {seed}...")
        best_auc = train_model(train_loader, val_loader, pos_weights, settings, seed, global_feats_size)
        logger.info(f"Seed {seed} complete. Best Val AUC: {best_auc:.4f}")


if __name__ == "__main__":
    train_ensemble()