import json

import optuna
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner
import pandas as pd
import torch
import wandb

from model.train import train, init_model, create_dataloaders, ToxicDataset
from model.training_config import TrainingConfig
from transformers import XLMRobertaTokenizer


def load_dataset(file_path: str) -> ToxicDataset:
    """Load a CSV split and wrap it in a ToxicDataset."""
    df = pd.read_csv(file_path)
    tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
    config = TrainingConfig()
    return ToxicDataset(df, tokenizer, config)


class HyperparameterTuner:
    def __init__(self, train_dataset, val_dataset, n_trials=10):
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.n_trials = n_trials

        # TPE sampler with a fixed seed for reproducibility; the median pruner
        # ignores the first 2 trials and first 2 reported steps before pruning.
        self.study = optuna.create_study(
            direction="maximize",
            sampler=TPESampler(seed=42),
            pruner=MedianPruner(
                n_startup_trials=2,
                n_warmup_steps=2,
                interval_steps=1
            )
        )

    def objective(self, trial):
        """Train one sampled configuration and return its validation AUC."""
        config_params = {
            # Fixed architecture (matches xlm-roberta-large)
            "model_name": "xlm-roberta-large",
            "hidden_size": 1024,
            "num_attention_heads": 16,

            # Searched hyperparameters
            "lr": trial.suggest_float("lr", 1e-5, 5e-5, log=True),
            "batch_size": trial.suggest_categorical("batch_size", [32, 64]),
            "model_dropout": trial.suggest_float("model_dropout", 0.3, 0.45),
            "weight_decay": trial.suggest_float("weight_decay", 0.01, 0.03),
            "grad_accum_steps": trial.suggest_int("grad_accum_steps", 1, 4),

            # Fixed training settings
            "epochs": 2,
            "mixed_precision": "bf16",
            "max_length": 128,
            "fp16": False,
            "distributed": False,
            "world_size": 1,
            "num_workers": 12,
            "activation_checkpointing": True,
            "tensor_float_32": True,
            "gc_frequency": 500
        }

        config = TrainingConfig(**config_params)

        # One W&B run per trial; reinit=True lets successive trials start
        # fresh runs within the same process.
        wandb.init(
            project="toxic-classification-hparam-tuning",
            name=f"trial-{trial.number}",
            config={
                **config_params,
                'trial_number': trial.number,
                'pruner': str(trial.study.pruner),
                'sampler': str(trial.study.sampler)
            },
            reinit=True,
            tags=['hyperparameter-optimization', f'trial-{trial.number}']
        )

        try:
            model = init_model(config)
            train_loader, val_loader = create_dataloaders(
                self.train_dataset,
                self.val_dataset,
                config
            )

            metrics = train(model, train_loader, val_loader, config)

            wandb.log({
                'final_val_auc': metrics['val/auc'],
                'final_val_loss': metrics['val/loss'],
                'final_train_loss': metrics['train/loss'],
                'peak_gpu_memory': torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0,
                'trial_completed': True
            })

            # NOTE: reporting only once, after training has finished, means
            # MedianPruner can never stop a trial mid-training; see the
            # epoch-level hook sketched after this class.
            trial.report(metrics['val/auc'], step=config.epochs)
            if trial.should_prune():
                wandb.log({'pruned': True})
                raise optuna.TrialPruned()

            return metrics['val/auc']

        except optuna.TrialPruned:
            # Re-raise so Optuna records the trial as pruned rather than
            # failed (TrialPruned subclasses Exception, so it would otherwise
            # be swallowed by the handler below).
            raise

        except Exception as e:
            wandb.log({
                'error': str(e),
                'trial_failed': True
            })
            print(f"Trial failed: {e}")
            # Convert failures into pruned trials so the study keeps going.
            raise optuna.TrialPruned()

        finally:
            # Free GPU memory between trials and close the per-trial W&B run.
            if 'model' in locals():
                del model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            wandb.finish()

    def run_optimization(self):
        """Run the hyperparameter optimization."""
        print("Starting hyperparameter optimization...")
        print("Search space:")
        print("  - Learning rate: 1e-5 to 5e-5 (log scale)")
        print("  - Batch size: [32, 64]")
        print("  - Dropout: 0.3 to 0.45")
        print("  - Weight decay: 0.01 to 0.03")
        print("  - Gradient accumulation steps: 1 to 4")
        print("\nFixed parameters:")
        print("  - Hidden size: 1024 (original)")
        print("  - Attention heads: 16 (original)")

        try:
            self.study.optimize(
                self.objective,
                n_trials=self.n_trials,
                timeout=None,
                callbacks=[self._log_trial]
            )

            print("\nBest trial:")
            best_trial = self.study.best_trial
            print(f"  Value: {best_trial.value:.4f}")
            print("  Params:")
            for key, value in best_trial.params.items():
                print(f"    {key}: {value}")

            self._save_study_results()

        except KeyboardInterrupt:
            # Save whatever has completed so far before exiting.
            print("\nOptimization interrupted by user.")
            self._save_study_results()
        except Exception as e:
            print(f"Optimization failed: {e}")
            raise

    def _log_trial(self, study, trial):
        """Optuna callback: log per-trial results after each trial ends."""
        if trial.value is None:
            return

        metrics = {
            "best_auc": study.best_value,
            "trial_auc": trial.value,
            "trial_number": trial.number,
            **trial.params
        }

        # Compare against the first *completed* trial; pruned/failed trials
        # have value None and would break the improvement computation.
        completed = [t for t in study.trials if t.value is not None]
        if len(completed) > 1:
            metrics["optimization_progress"] = {
                "trials_completed": len(completed),
                "improvement_rate": (study.best_value - completed[0].value) / len(completed),
                "best_trial_number": study.best_trial.number
            }

        # The per-trial W&B run was already closed in objective()'s finally
        # block, so only log if a run is still active; otherwise print.
        if wandb.run is not None:
            wandb.log(metrics)
        else:
            print(f"Trial {trial.number}: AUC={trial.value:.4f} (best={study.best_value:.4f})")

    def _save_study_results(self):
        """Save the study pickle and a JSON summary with metadata."""
        import joblib
        from pathlib import Path
        from datetime import datetime

        results_dir = Path("optimization_results")
        results_dir.mkdir(exist_ok=True)

        # Pickle the full study so it can be reloaded and analyzed later.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        study_path = results_dir / f"hparam_optimization_study_{timestamp}.pkl"
        joblib.dump(self.study, study_path)

        # Guard against saving before any trial has completed (e.g. after an
        # immediate KeyboardInterrupt): best_trial would raise otherwise.
        completed = [t for t in self.study.trials if t.state == optuna.trial.TrialState.COMPLETE]
        results = {
            "best_trial": {
                "number": self.study.best_trial.number,
                "value": self.study.best_value,
                "params": self.study.best_trial.params
            } if completed else None,
            "study_statistics": {
                "n_trials": len(self.study.trials),
                "n_completed": len(completed),
                "n_pruned": len([t for t in self.study.trials if t.state == optuna.trial.TrialState.PRUNED]),
                "datetime_start": self.study.trials[0].datetime_start.isoformat() if self.study.trials else None,
                "datetime_complete": datetime.now().isoformat()
            },
            "search_space": {
                "lr": {"low": 1e-5, "high": 5e-5},
                "batch_size": [32, 64],
                "model_dropout": {"low": 0.3, "high": 0.45},
                "weight_decay": {"low": 0.01, "high": 0.03},
                "grad_accum_steps": {"low": 1, "high": 4}
            },
            "trial_history": [
                {
                    "number": t.number,
                    "value": t.value,
                    "state": str(t.state),
                    "params": t.params
                }
                for t in self.study.trials
            ]
        }

        results_path = results_dir / f"optimization_results_{timestamp}.json"
        with open(results_path, "w") as f:
            json.dump(results, f, indent=4)

        print("\nResults saved to:")
        print(f"  - Study: {study_path}")
        print(f"  - Results: {results_path}")


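# The objective above reports to Optuna only once, after training has already
# finished, so MedianPruner never gets a chance to stop a weak trial early.
# Below is a minimal sketch of an epoch-level reporting hook. It assumes a
# hypothetical on_epoch_end(epoch, val_auc) callback that model.train does not
# currently expose, so treat it as an illustration, not a drop-in fix.
def make_pruning_hook(trial: optuna.Trial):
    """Return an epoch-end hook that reports val AUC and prunes weak trials."""
    def on_epoch_end(epoch: int, val_auc: float):
        # Report the intermediate value, then let the pruner decide.
        trial.report(val_auc, step=epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
    return on_epoch_end

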
def main():
    """Load the data splits and run hyperparameter optimization."""
    train_dataset = load_dataset("dataset/split/train.csv")
    val_dataset = load_dataset("dataset/split/val.csv")

    tuner = HyperparameterTuner(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        n_trials=10
    )

    tuner.run_optimization()


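# Offline-analysis helper (a sketch): loads a study pickled by
# _save_study_results and prints per-parameter importances via Optuna's
# built-in fANOVA-based importance evaluator. The path argument is whatever
# .pkl file a previous run produced; get_param_importances needs at least two
# completed trials.
def inspect_saved_study(study_path: str):
    import joblib
    study = joblib.load(study_path)
    print(f"Best value: {study.best_value:.4f}")
    print(f"Best params: {study.best_trial.params}")
    for name, importance in optuna.importance.get_param_importances(study).items():
        print(f"  {name}: {importance:.3f}")

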
if __name__ == "__main__":
    main()