Spaces:
Sleeping
Sleeping
| import argparse | |
| import os | |
| import json | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from sklearn.metrics import roc_auc_score | |
| import numpy as np | |
| import logging | |
| from tqdm import tqdm | |
| from data import load_data, Tox21Dataset, TASKS | |
| from model import DMPNN, MolGraph, BatchMolGraph | |
# Configure logging: timestamped INFO-level messages for training progress.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Prefer GPU when available; tensors and the model are moved here explicitly.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_settings():
    """
    Leaderboard Requirement: Load hyperparameters from a config folder.

    Returns the parsed dict from config/model_config.json, or an empty
    dict (callers fall back to their own defaults) when the file is absent.
    """
    config_path = os.path.join("config", "model_config.json")
    # Guard clause: no config file -> warn and use defaults everywhere.
    if not os.path.exists(config_path):
        logger.warning("Config file not found in config/ folder. Using default values.")
        return {}
    logger.info(f"Loading configuration from {config_path}")
    with open(config_path, "r") as f:
        return json.load(f)
def set_seed(seed):
    """Seed NumPy and torch (CPU and, if present, all CUDA devices) for reproducibility."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def collate_fn(batch):
    """
    Collate dataset items into a (BatchMolGraph, global-feature tensor, label tensor)
    triple for the DataLoader.
    """
    labels = torch.stack([item['labels'] for item in batch])
    global_features = torch.stack([item['global_features'] for item in batch])
    graphs = [MolGraph(item['smiles']) for item in batch]
    return BatchMolGraph(graphs), global_features, labels
def calculate_pos_weights(train_data):
    """
    Compute per-task positive-class weights as the negative/positive ratio,
    ignoring missing labels (encoded as -1). Returned tensor lives on DEVICE
    so it can feed BCEWithLogitsLoss directly.
    """
    labels = torch.stack([item['labels'] for item in train_data])
    ratios = []
    for task_labels in labels.t():
        valid = task_labels[task_labels != -1]
        n_pos = (valid == 1).sum().item()
        n_neg = (valid == 0).sum().item()
        # max(..., 1) avoids division by zero for tasks with no positives.
        ratios.append(n_neg / max(n_pos, 1))
    return torch.tensor(ratios, dtype=torch.float32).to(DEVICE)
def evaluate(model, loader):
    """
    Compute the mean ROC-AUC across tasks on `loader`.

    Tasks whose valid labels are all-positive or all-negative are skipped
    (AUC is undefined there); returns 0.5 if no task is scorable.
    Missing labels are encoded as -1 and excluded per task.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch_graph, global_feats, labels in loader:
            # BUG FIX: global features must be on the same device as the model
            # (the train loop moves labels to DEVICE; features were left on CPU,
            # which fails when CUDA is used). `.to` is a no-op if already there.
            logits = model(batch_graph, global_feats.to(DEVICE))
            preds = torch.sigmoid(logits)
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.numpy())
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    aucs = []
    for i in range(len(TASKS)):
        valid_idx = all_labels[:, i] != -1
        task_labels = all_labels[valid_idx, i]
        n_pos = task_labels.sum()
        # ROC-AUC is defined only when both classes are present.
        if 0 < n_pos < len(task_labels):
            aucs.append(roc_auc_score(task_labels, all_preds[valid_idx, i]))
    return np.mean(aucs) if aucs else 0.5
def train_model(train_loader, val_loader, pos_weights, settings, seed, global_feats_size):
    """
    Train a single DMPNN with the given seed, checkpointing the weights with
    the best validation AUC to checkpoints/model_seed_<seed>.pt.

    Returns the best validation AUC observed across epochs.
    """
    set_seed(seed)
    model_path = f"checkpoints/model_seed_{seed}.pt"
    os.makedirs("checkpoints", exist_ok=True)
    # Hyperparams from config (with defaults when keys are absent)
    hidden_size = settings.get("hidden_size", 300)
    depth = settings.get("depth", 3)
    lr = settings.get("learning_rate", 1e-3)
    epochs = settings.get("epochs", 30)
    model = DMPNN(
        hidden_size=hidden_size,
        depth=depth,
        global_feats_size=global_feats_size,
        n_tasks=len(TASKS)
    ).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # reduction='none' so missing labels can be masked out before averaging.
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights, reduction='none')
    best_val_auc = 0
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch_graph, global_feats, labels in train_loader:
            optimizer.zero_grad()
            # BUG FIX: move global features onto the model's device (labels/mask
            # were already moved, but features stayed on CPU — crashes on CUDA).
            logits = model(batch_graph, global_feats.to(DEVICE))
            # Mask missing values (-1) so they contribute no gradient.
            mask = (labels != -1).float().to(DEVICE)
            # Clamp -1 -> 0 to give BCE a valid target; masked out of the loss anyway.
            labels_fixed = torch.clamp(labels, min=0).to(DEVICE)
            loss = criterion(logits, labels_fixed)
            # Guard: a batch with no valid labels would make mask.sum() == 0 -> NaN.
            denom = mask.sum().clamp(min=1.0)
            loss = (loss * mask).sum() / denom
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        val_auc = evaluate(model, val_loader)
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            torch.save(model.state_dict(), model_path)
    return best_val_auc
def train_ensemble():
    """
    Train an ensemble of DMPNN models (one per configured seed) and log
    the best validation AUC achieved by each.
    """
    settings = load_settings()
    logger.info("Loading data from HF Hub...")
    train_data, val_data, _ = load_data()
    global_feats_size = len(train_data[0]['global_features'])
    pos_weights = calculate_pos_weights(train_data)
    batch_size = settings.get("batch_size", 50)
    loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn)
    train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)
    for seed in settings.get("ensemble_seeds", [42, 43, 44, 45, 46]):
        logger.info(f"Starting training for seed {seed}...")
        best_auc = train_model(train_loader, val_loader, pos_weights, settings, seed, global_feats_size)
        logger.info(f"Seed {seed} complete. Best Val AUC: {best_auc:.4f}")
| if __name__ == "__main__": | |
| train_ensemble() |