{ "cells": [
 { "cell_type": "code", "execution_count": null, "id": "b00e4cd9", "metadata": { "scrolled": true }, "outputs": [], "source": [
  "!hf download SaladTechnologies/fiction-ner-750m --quiet --repo-type=dataset --local-dir .\n",
  "!unzip -q data.zip"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "b1be4895", "metadata": {}, "outputs": [], "source": [
  "import string\n",
  "import random\n",
  "\n",
  "def get_random_string(length=8):\n",
  "    \"\"\"Generate a random string of fixed length.\"\"\"\n",
  "    letters = string.ascii_letters\n",
  "    return ''.join(random.choice(letters) for _ in range(length))\n",
  "\n",
  "run_name = f\"ner-{get_random_string(8)}\""
 ] },
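 { "cell_type": "markdown", "id": "a7e31c02", "metadata": {}, "source": [
  "Optionally, set the environment variables that `train_fn` reads before launching training. The cell below is only a sketch: every value shown is the script's own default, except `MODEL_ID`, which is a placeholder you should replace with the Hub repo you want to push to. `HF_TOKEN` and `CUDA_VISIBLE_DEVICES` are also picked up from the environment if set."
 ] },
 { "cell_type": "code", "execution_count": null, "id": "a7e31c03", "metadata": {}, "outputs": [], "source": [
  "import os\n",
  "\n",
  "# Optional: all of these fall back to the defaults hard-coded in train_fn below.\n",
  "# MODEL_ID is a placeholder -- point it at the Hub repo the model should be pushed to.\n",
  "os.environ.setdefault(\"MODEL_ID\", \"your-username/fiction-ner-deberta\")\n",
  "os.environ.setdefault(\"OUTPUT_DIR\", \"./model\")\n",
  "os.environ.setdefault(\"NUM_EPOCHS\", \"1\")\n",
  "os.environ.setdefault(\"PER_DEVICE_TRAIN_BATCH_SIZE\", \"256\")\n",
  "os.environ.setdefault(\"LEARNING_RATE\", \"2.5e-5\")\n",
  "os.environ.setdefault(\"GAMMA\", \"2.1\")\n",
  "os.environ.setdefault(\"FREQUENCY_EXPONENT\", \"0.35\")"
 ] },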
 { "cell_type": "code", "execution_count": null, "id": "f21e8995", "metadata": {}, "outputs": [], "source": [
  "from accelerate import notebook_launcher\n",
  "import os\n",
  "\n",
  "\n",
  "cuda_visible_devices = os.getenv(\"CUDA_VISIBLE_DEVICES\", \"0\")\n",
  "num_devices = len(cuda_visible_devices.split(\",\"))\n",
  "\n",
  "\n",
  "def train_fn():\n",
  "    from datasets import Dataset, concatenate_datasets\n",
  "    import pandas as pd\n",
  "    from pathlib import Path\n",
  "    import random\n",
  "    from transformers import AutoTokenizer\n",
  "    import torch\n",
  "    import numpy as np\n",
  "    from transformers import AutoModelForTokenClassification\n",
  "    from transformers.data.data_collator import DataCollatorForTokenClassification\n",
  "    from transformers.training_args import TrainingArguments\n",
  "    from transformers.trainer import Trainer\n",
  "    from transformers.trainer_callback import TrainerCallback\n",
  "    from sklearn.metrics import precision_recall_fscore_support\n",
  "    import os\n",
  "    import wandb\n",
  "\n",
  "    num_epochs = int(os.getenv(\"NUM_EPOCHS\", 1))\n",
  "    output_dir = os.getenv(\"OUTPUT_DIR\", \"./model\")\n",
  "    seed = int(os.getenv(\"RANDOM_SEED\", 42))\n",
  "    model_id = os.getenv(\"MODEL_ID\")\n",
  "    hub_token = os.getenv(\"HF_TOKEN\")\n",
  "    save_steps = float(os.getenv(\"SAVE_STEPS\", 100))\n",
  "    if save_steps.is_integer():\n",
  "        save_steps = int(save_steps)\n",
  "    train_size = float(os.getenv(\"TRAIN_SIZE\", 12_000_000))\n",
  "    test_size = float(os.getenv(\"TEST_SIZE\", 1_200_000))\n",
  "    if train_size.is_integer():\n",
  "        train_size = int(train_size)\n",
  "    if test_size.is_integer():\n",
  "        test_size = int(test_size)\n",
  "    hidden_dropout_prob = float(os.getenv(\"HIDDEN_DROPOUT_PROB\", 0.14))\n",
  "    attention_probs_dropout_prob = float(os.getenv(\"ATTENTION_PROBS_DROPOUT_PROB\", 0.14))\n",
  "    frequency_exponent = float(os.getenv(\"FREQUENCY_EXPONENT\", 0.35))\n",
  "    gamma = float(os.getenv(\"GAMMA\", 2.1))\n",
  "    learning_rate = float(os.getenv(\"LEARNING_RATE\", 2.5e-5))\n",
  "    lr_scheduler_type = os.getenv(\"LR_SCHEDULER_TYPE\", \"cosine\")\n",
  "    weight_decay = float(os.getenv(\"WEIGHT_DECAY\", 0.007))\n",
  "    warmup_ratio = float(os.getenv(\"WARMUP_RATIO\", 0.03))\n",
  "    per_device_train_batch_size = int(os.getenv(\"PER_DEVICE_TRAIN_BATCH_SIZE\", 256))\n",
  "    max_saved_checkpoints = int(os.getenv(\"MAX_SAVED_CHECKPOINTS\", 8))\n",
  "    patience = max_saved_checkpoints - 1\n",
  "\n",
  "    tokenizer = AutoTokenizer.from_pretrained(\"microsoft/deberta-v3-base\")\n",
  "\n",
  "    data_dir = Path(\"data\")\n",
  "    output = Path(output_dir)\n",
  "    random.seed(seed)\n",
  "    torch.manual_seed(seed)\n",
  "    np.random.seed(seed)\n",
  "\n",
  "    label_list = [\n",
  "        \"O\",\n",
  "        \"B-CHA\",\n",
  "        \"I-CHA\",\n",
  "        \"B-LOC\",\n",
  "        \"I-LOC\",\n",
  "        \"B-FAC\",\n",
  "        \"I-FAC\",\n",
  "        \"B-OBJ\",\n",
  "        \"I-OBJ\",\n",
  "        \"B-EVT\",\n",
  "        \"I-EVT\",\n",
  "        \"B-ORG\",\n",
  "        \"I-ORG\",\n",
  "        \"B-MISC\",\n",
  "        \"I-MISC\"\n",
  "    ]\n",
  "    label_to_id = {label: i for i, label in enumerate(label_list)}\n",
  "    id_to_label = {i: label for i, label in enumerate(label_list)}\n",
  "\n",
  "    datasets = []\n",
  "    for parquet_file in sorted(data_dir.glob(\"*.parquet\")):\n",
  "        ds = Dataset.from_parquet(str(parquet_file))\n",
  "        datasets.append(ds)\n",
  "\n",
  "    full_ds = concatenate_datasets(datasets)\n",
  "    splits = full_ds.train_test_split(train_size=train_size, test_size=test_size, seed=seed)\n",
  "\n",
  "    train_ds = splits['train']\n",
  "    eval_ds = splits['test']\n",
  "\n",
  "    stats_file = \"label_counts.csv\"\n",
  "    stats_df = pd.read_csv(stats_file)\n",
  "\n",
  "    total_count = stats_df[\"total\"].sum()\n",
  "    label_frequencies = {\n",
  "        label: stats_df[label].sum() / total_count for label in label_list\n",
  "    }\n",
  "\n",
  "    # Inverse-frequency class weights, tempered by frequency_exponent\n",
  "    label_weights = {}\n",
  "    for label, freq in label_frequencies.items():\n",
  "        label_weights[label] = 1.0 / freq ** frequency_exponent\n",
  "\n",
  "    weight_tensor = torch.tensor([label_weights[label] for label in label_list], dtype=torch.float32)\n",
  "\n",
  "    model = AutoModelForTokenClassification.from_pretrained(\n",
  "        \"microsoft/deberta-v3-base\",\n",
  "        num_labels=len(label_list),\n",
  "        id2label=id_to_label,\n",
  "        label2id=label_to_id,\n",
  "        ignore_mismatched_sizes=True,\n",
  "        hidden_dropout_prob=hidden_dropout_prob,\n",
  "        attention_probs_dropout_prob=attention_probs_dropout_prob\n",
  "    )\n",
  "\n",
  "    data_collator = DataCollatorForTokenClassification(\n",
  "        tokenizer=tokenizer,\n",
  "        padding=True\n",
  "    )\n",
  "\n",
  "    def create_compute_metrics_fn(eval_dataset):\n",
  "        \"\"\"\n",
  "        Factory function that creates a compute_metrics function with access to eval_dataset.\n",
  "        \"\"\"\n",
  "        def compute_metrics(eval_pred):\n",
  "            predictions, labels = eval_pred\n",
  "            predictions = np.argmax(predictions, axis=2)\n",
  "\n",
  "            # Remove ignored indices\n",
  "            true_predictions = [\n",
  "                [id_to_label[p] for (p, l) in zip(pred, label) if l != -100]\n",
  "                for pred, label in zip(predictions, labels)\n",
  "            ]\n",
  "            true_labels = [\n",
  "                [id_to_label[l] for (p, l) in zip(pred, label) if l != -100]\n",
  "                for pred, label in zip(predictions, labels)\n",
  "            ]\n",
  "\n",
  "            # Flatten\n",
  "            all_predictions = [item for sublist in true_predictions for item in sublist]\n",
  "            all_labels = [item for sublist in true_labels for item in sublist]\n",
  "\n",
  "            # Calculate metrics excluding 'O' class\n",
  "            entity_labels = [l for l in label_list if l != 'O']\n",
  "\n",
  "            precision, recall, f1, support = precision_recall_fscore_support(\n",
  "                all_labels,\n",
  "                all_predictions,\n",
  "                labels=entity_labels,\n",
  "                average='weighted',\n",
  "                zero_division=0\n",
  "            )\n",
  "\n",
  "            return {\n",
  "                'entity_precision': precision,\n",
  "                'entity_recall': recall,\n",
  "                'entity_f1': f1,\n",
  "            }\n",
  "\n",
  "        return compute_metrics\n",
  "\n",
  "    # Create the compute_metrics function with access to eval_ds\n",
  "    compute_metrics = create_compute_metrics_fn(eval_ds)\n",
  "\n",
  "    class FocalLoss(torch.nn.Module):\n",
  "        \"\"\"Focal loss (Lin et al., 2017): FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t).\n",
  "\n",
  "        The (1 - p_t)**gamma factor shrinks the loss on tokens the model already\n",
  "        classifies confidently (mostly 'O'), so rare entity labels get more gradient.\n",
  "        \"\"\"\n",
  "        def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):\n",
  "            \"\"\"\n",
  "            alpha: class weights tensor\n",
  "            gamma: focusing parameter (higher = more focus on hard examples)\n",
  "            ignore_index: label to ignore (for padding tokens)\n",
  "            \"\"\"\n",
  "            super().__init__()\n",
  "            self.alpha = alpha\n",
  "            self.gamma = gamma\n",
  "            self.reduction = reduction\n",
  "            self.ignore_index = ignore_index\n",
  "\n",
  "        def forward(self, logits, labels):\n",
  "            # logits shape: (batch_size, seq_len, num_classes)\n",
  "            # labels shape: (batch_size, seq_len)\n",
  "\n",
  "            # Reshape for loss calculation\n",
  "            logits_flat = logits.view(-1, logits.size(-1))  # (batch*seq_len, num_classes)\n",
  "            labels_flat = labels.view(-1)  # (batch*seq_len)\n",
  "\n",
  "            # Calculate cross entropy (without reduction)\n",
  "            ce_loss = torch.nn.functional.cross_entropy(\n",
  "                logits_flat,\n",
  "                labels_flat,\n",
  "                reduction='none',\n",
  "                ignore_index=self.ignore_index\n",
  "            )\n",
  "\n",
  "            # Get the probabilities for the correct class\n",
  "            p = torch.exp(-ce_loss)\n",
  "\n",
  "            # Calculate focal term: (1 - p)^gamma\n",
  "            focal_term = (1 - p) ** self.gamma\n",
  "\n",
  "            # Apply focal term to loss\n",
  "            focal_loss = focal_term * ce_loss\n",
  "\n",
  "            # Apply class weights if provided\n",
  "            if self.alpha is not None:\n",
  "                # Create a mask for valid (non-ignored) tokens\n",
  "                valid_mask = labels_flat != self.ignore_index\n",
  "\n",
  "                # Gather the weights for each sample's true class\n",
  "                # Only for valid labels to avoid index errors\n",
  "                valid_labels = labels_flat.clone()\n",
  "                valid_labels[~valid_mask] = 0  # Set ignored labels to 0 to avoid index errors\n",
  "\n",
  "                alpha_t = self.alpha.gather(0, valid_labels)\n",
  "                # Apply mask to weights\n",
  "                alpha_t = alpha_t * valid_mask.float()\n",
  "\n",
  "                focal_loss = alpha_t * focal_loss\n",
  "\n",
  "            # Apply reduction\n",
  "            if self.reduction == 'mean':\n",
  "                # Only average over non-ignored tokens\n",
  "                valid_tokens = (labels_flat != self.ignore_index).sum()\n",
  "                return focal_loss.sum() / valid_tokens.clamp(min=1)\n",
  "            elif self.reduction == 'sum':\n",
  "                return focal_loss.sum()\n",
  "            else:\n",
  "                return focal_loss\n",
  "\n",
  "    class FocalLossTrainer(Trainer):\n",
  "        def __init__(self, *args, class_weights=None, gamma=2.0, **kwargs):\n",
  "            super().__init__(*args, **kwargs)\n",
  "            self.class_weights = class_weights\n",
  "            self.gamma = gamma\n",
  "\n",
  "        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
  "            \"\"\"\n",
  "            Override compute_loss to use focal loss.\n",
  "            num_items_in_batch parameter added for compatibility with newer transformers versions.\n",
  "            \"\"\"\n",
  "            labels = inputs.get(\"labels\")\n",
  "            outputs = model(**inputs)\n",
  "            logits = outputs.get(\"logits\")\n",
  "\n",
  "            # Move weights to the same device as logits\n",
  "            if self.class_weights is not None:\n",
  "                weights = self.class_weights.to(logits.device)\n",
  "            else:\n",
  "                weights = None\n",
  "\n",
  "            # Initialize focal loss\n",
  "            loss_fct = FocalLoss(\n",
  "                alpha=weights,\n",
  "                gamma=self.gamma,\n",
  "                ignore_index=-100\n",
  "            )\n",
  "\n",
  "            # Calculate loss\n",
  "            loss = loss_fct(logits, labels)\n",
  "\n",
  "            return (loss, outputs) if return_outputs else loss\n",
  "\n",
  "    training_args = TrainingArguments(\n",
  "        output_dir=str(output),\n",
  "        learning_rate=learning_rate,\n",
  "        lr_scheduler_type=lr_scheduler_type,\n",
  "        per_device_train_batch_size=per_device_train_batch_size,\n",
  "        weight_decay=weight_decay,\n",
  "        warmup_ratio=warmup_ratio,\n",
  "        gradient_accumulation_steps=1,\n",
  "        logging_steps=50,\n",
  "        num_train_epochs=num_epochs,\n",
  "        save_strategy=\"steps\",\n",
  "        save_steps=save_steps,\n",
  "        save_total_limit=max_saved_checkpoints,\n",
  "        eval_strategy=\"steps\",\n",
  "        eval_steps=save_steps,\n",
  "        load_best_model_at_end=True,\n",
  "        metric_for_best_model=\"eval_entity_f1\",\n",
  "        greater_is_better=True,\n",
  "        # bf16/tf32 assume an Ampere-or-newer NVIDIA GPU; drop both flags on older hardware.\n",
  "        bf16=True,\n",
  "        tf32=True,\n",
  "        report_to='wandb',\n",
  "        run_name=run_name,\n",
  "        push_to_hub=True,\n",
  "        hub_strategy=\"checkpoint\",\n",
  "        hub_token=hub_token,\n",
  "        dataloader_persistent_workers=True,\n",
  "        dataloader_num_workers=2,\n",
  "        dataloader_pin_memory=True,\n",
  "        ddp_find_unused_parameters=False,\n",
  "        gradient_checkpointing=False,\n",
  "        hub_model_id=model_id,\n",
  "        hub_private_repo=True\n",
  "    )\n",
  "\n",
  "    class CustomEarlyStoppingCallback(TrainerCallback):\n",
  "        def __init__(self, patience=2, threshold=0.001):\n",
  "            self.patience = patience\n",
  "            self.threshold = threshold\n",
  "            self.best_metric = None\n",
  "            self.wait = 0\n",
  "\n",
  "        def on_evaluate(self, args, state, control, metrics=None, **kwargs):\n",
  "            if metrics is None or \"eval_entity_f1\" not in metrics:\n",
  "                return control\n",
  "            metric_value = metrics.get(\"eval_entity_f1\")\n",
  "\n",
  "            if self.best_metric is None:\n",
  "                self.best_metric = metric_value\n",
  "            elif metric_value > self.best_metric + self.threshold:\n",
  "                self.best_metric = metric_value\n",
  "                self.wait = 0\n",
  "            else:\n",
  "                self.wait += 1\n",
  "                if self.wait >= self.patience:\n",
  "                    control.should_training_stop = True\n",
  "                    print(f\"Early stopping triggered. Best F1: {self.best_metric:.4f}\")\n",
  "\n",
  "            return control\n",
  "\n",
  "    trainer = FocalLossTrainer(\n",
  "        model=model,\n",
  "        args=training_args,\n",
  "        train_dataset=train_ds,\n",
  "        eval_dataset=eval_ds,\n",
  "        processing_class=tokenizer,\n",
  "        data_collator=data_collator,\n",
  "        compute_metrics=compute_metrics,\n",
  "        class_weights=weight_tensor,\n",
  "        gamma=gamma,\n",
  "        callbacks=[CustomEarlyStoppingCallback(patience=patience, threshold=0.0001)]\n",
  "    )\n",
  "\n",
  "    # Only runs if a wandb run was already started manually (e.g. via wandb.init());\n",
  "    # otherwise the Trainer's wandb integration creates the run when training starts.\n",
  "    if wandb.run is not None:\n",
  "        # Add custom config values\n",
  "        wandb.config.update({\n",
  "            # Data configuration\n",
  "            \"train_samples\": len(train_ds),\n",
  "            \"eval_samples\": len(eval_ds),\n",
  "            \"train_size_requested\": train_size,\n",
  "            \"test_size_requested\": test_size,\n",
  "            \"actual_train_size\": len(train_ds),\n",
  "            \"actual_eval_size\": len(eval_ds),\n",
  "\n",
  "            # Model architecture details\n",
  "            \"model_architecture\": \"deberta-v3-base\",\n",
  "            \"num_labels\": len(label_list),\n",
  "            \"label_list\": label_list,\n",
  "\n",
  "            # Loss function configuration\n",
  "            \"loss_function\": \"focal_loss\",\n",
  "            \"focal_gamma\": gamma,\n",
  "            \"focal_alpha\": \"weighted\",\n",
  "            \"frequency_exponent\": frequency_exponent,\n",
  "\n",
  "            # Dropout configuration\n",
  "            \"hidden_dropout_prob\": hidden_dropout_prob,\n",
  "            \"attention_probs_dropout_prob\": attention_probs_dropout_prob,\n",
  "\n",
  "            # Training configuration not in TrainingArguments\n",
  "            \"max_saved_checkpoints\": max_saved_checkpoints,\n",
  "            \"early_stopping_patience\": patience,\n",
  "            \"early_stopping_threshold\": 0.0001,\n",
  "\n",
  "            # Environment info\n",
  "            \"cuda_devices\": cuda_visible_devices,\n",
  "            \"num_gpus\": num_devices,\n",
  "\n",
  "            # Data processing\n",
  "            \"tokenizer\": \"microsoft/deberta-v3-base\",\n",
  "\n",
  "            # Experiment metadata\n",
  "            \"experiment_type\": \"ner_fiction\",\n",
  "            \"data_source\": \"gutenberg_ao3_mixed\",\n",
  "            \"random_seed\": seed,\n",
  "        })\n",
  "\n",
  "    has_checkpoints = os.path.isdir(output_dir) and any(\n",
  "        f.is_dir() and \"checkpoint\" in f.name for f in os.scandir(output_dir)\n",
  "    )\n",
  "    if has_checkpoints:\n",
  "        trainer.train(resume_from_checkpoint=True)\n",
  "    else:\n",
  "        trainer.train()\n",
  "\n",
  "notebook_launcher(train_fn, num_processes=num_devices)"
 ] }
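,
 { "cell_type": "markdown", "id": "a7e31c04", "metadata": {}, "source": [
  "Optional sanity check after training: load the saved weights with a `token-classification` pipeline and tag a sample sentence. This is only a sketch -- it assumes the final model (or a `checkpoint-*` folder) ended up under the default `OUTPUT_DIR` of `./model`; adjust the path, or point at your Hub repo, if yours differs."
 ] },
 { "cell_type": "code", "execution_count": null, "id": "a7e31c05", "metadata": {}, "outputs": [], "source": [
  "from transformers import pipeline\n",
  "\n",
  "# Assumes the trained model was saved under ./model (the default OUTPUT_DIR);\n",
  "# point this at a specific checkpoint-* directory or your Hub repo if needed.\n",
  "ner = pipeline(\"token-classification\", model=\"./model\", aggregation_strategy=\"simple\")\n",
  "ner(\"Elizabeth walked through the gardens of Pemberley at dusk.\")"
 ] }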
 ],
 "metadata": {
  "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" },
  "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}