In [None]:
!hf download SaladTechnologies/fiction-ner-750m --quiet --repo-type=dataset --local-dir .
!unzip -q data.zip

In [None]:
import string
import random

def get_random_string(length=8):
    """Return ``length`` characters drawn independently from ASCII letters.

    Each character is one ``random.choice`` over ``string.ascii_letters``
    (a-z, A-Z), so the output depends on the global ``random`` state.
    A non-positive ``length`` gives an empty string.
    """
    alphabet = string.ascii_letters
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return "".join(picked)


# Run name used for W&B reporting, e.g. "ner-aBcDeFgH".
run_name = f"ner-{get_random_string(length=8)}"

In [None]:
from accelerate import notebook_launcher
import os


cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "0")
num_devices = len(cuda_visible_devices.split(","))


def train_fn():
    """Fine-tune microsoft/deberta-v3-base for fiction NER with a class-weighted focal loss.

    Passed to ``notebook_launcher`` at the end of this cell, so every third-party
    import happens inside the function. All hyper-parameters come from
    environment variables; see the "config" section below for names and defaults.

    Inputs read from the working directory:
        data/*.parquet     dataset shards, concatenated and then split into train/eval
        label_counts.csv   per-label count columns plus a "total" column, used to
                           derive class weights

    Also reads the module-level ``run_name``, ``cuda_visible_devices`` and
    ``num_devices`` defined in earlier cells.

    Side effects: trains the model, saves checkpoints under ``OUTPUT_DIR``, pushes
    them to the Hugging Face Hub (``MODEL_ID``, private repo) and reports to
    Weights & Biases. Resumes from the newest local checkpoint when one exists.

    Raises:
        FileNotFoundError: no ``data/*.parquet`` shards were found.
        ValueError: a label has zero/undefined frequency in label_counts.csv.
    """
    global num_processes

    import os
    import random
    from pathlib import Path

    import numpy as np
    import pandas as pd
    import torch
    import wandb
    from datasets import Dataset, concatenate_datasets
    from sklearn.metrics import precision_recall_fscore_support
    from transformers import AutoModelForTokenClassification, AutoTokenizer
    from transformers.data.data_collator import DataCollatorForTokenClassification
    from transformers.trainer import Trainer
    from transformers.trainer_callback import TrainerCallback
    from transformers.trainer_utils import get_last_checkpoint
    from transformers.training_args import TrainingArguments

    def env_number(name, default):
        """Read env var ``name`` as a float, returning an ``int`` when it is integral.

        Integral values become ints (step / row counts); non-integral values stay
        floats (presumably interpreted as ratios by TrainingArguments and
        ``train_test_split`` -- see their docs).
        """
        value = float(os.getenv(name, default))
        return int(value) if value.is_integer() else value

    # ------------------------------------------------------------------ config
    num_epochs = int(os.getenv("NUM_EPOCHS", 1))
    output_dir = os.getenv("OUTPUT_DIR", "./model")
    seed = int(os.getenv("RANDOM_SEED", 42))
    model_id = os.getenv("MODEL_ID")
    hub_token = os.getenv("HF_TOKEN")
    save_steps = env_number("SAVE_STEPS", 100)
    train_size = env_number("TRAIN_SIZE", 12_000_000)
    test_size = env_number("TEST_SIZE", 1_200_000)
    hidden_dropout_prob = float(os.getenv("HIDDEN_DROPOUT_PROB", 0.14))
    attention_probs_dropout_prob = float(os.getenv("ATTENTION_PROBS_DROPOUT_PROB", 0.14))
    frequency_exponent = float(os.getenv("FREQUENCY_EXPONENT", 0.35))
    gamma = float(os.getenv("GAMMA", 2.1))
    learning_rate = float(os.getenv("LEARNING_RATE", 2.5e-5))
    lr_scheduler_type = os.getenv("LR_SCHEDULER_TYPE", "cosine")
    weight_decay = float(os.getenv("WEIGHT_DECAY", 0.007))
    warmup_ratio = float(os.getenv("WARMUP_RATIO", 0.03))
    per_device_train_batch_size = int(os.getenv("PER_DEVICE_TRAIN_BATCH_SIZE", 256))
    max_saved_checkpoints = int(os.getenv("MAX_SAVED_CHECKPOINTS", 8))
    patience = max_saved_checkpoints - 1
    # One variable feeds both the early-stopping callback and the W&B config, so
    # the logged value can no longer disagree with the one actually used.
    early_stopping_threshold = float(os.getenv("EARLY_STOPPING_THRESHOLD", 0.0001))
    base_model = "microsoft/deberta-v3-base"

    # Publishes this worker's visible GPU count as a module global (kept for
    # compatibility; nothing in this cell reads it).
    num_processes = torch.cuda.device_count()

    tokenizer = AutoTokenizer.from_pretrained(base_model)

    data_dir = Path("data")
    output = Path(output_dir)
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # ------------------------------------------------------------------ labels
    label_list = [
        "O",
        "B-CHA", "I-CHA",
        "B-LOC", "I-LOC",
        "B-FAC", "I-FAC",
        "B-OBJ", "I-OBJ",
        "B-EVT", "I-EVT",
        "B-ORG", "I-ORG",
        "B-MISC", "I-MISC",
    ]
    label_to_id = {label: i for i, label in enumerate(label_list)}
    id_to_label = {i: label for i, label in enumerate(label_list)}
    # Metrics are reported over entity tags only ("O" excluded).
    entity_label_ids = [label_to_id[label] for label in label_list if label != "O"]

    # -------------------------------------------------------------------- data
    shards = [
        Dataset.from_parquet(str(parquet_file))
        for parquet_file in sorted(data_dir.glob("*.parquet"))
    ]
    if not shards:
        raise FileNotFoundError(f"no *.parquet files found in {data_dir.resolve()}")
    full_ds = concatenate_datasets(shards)
    splits = full_ds.train_test_split(train_size=train_size, test_size=test_size, seed=seed)
    train_ds = splits["train"]
    eval_ds = splits["test"]

    # ----------------------------------------------------------- class weights
    # weight(label) = 1 / frequency(label) ** frequency_exponent
    stats_df = pd.read_csv("label_counts.csv")
    total_count = stats_df["total"].sum()
    label_frequencies = {
        label: stats_df[label].sum() / total_count for label in label_list
    }
    # A zero (or NaN) frequency would turn the weight formula into 1/0.
    unusable = [label for label, freq in label_frequencies.items() if not freq > 0]
    if unusable:
        raise ValueError(f"label_counts.csv gives no usable frequency for labels {unusable}")
    weight_tensor = torch.tensor(
        [1.0 / label_frequencies[label] ** frequency_exponent for label in label_list],
        dtype=torch.float32,
    )

    # ------------------------------------------------------------------- model
    model = AutoModelForTokenClassification.from_pretrained(
        base_model,
        num_labels=len(label_list),
        id2label=id_to_label,
        label2id=label_to_id,
        ignore_mismatched_sizes=True,
        hidden_dropout_prob=hidden_dropout_prob,
        attention_probs_dropout_prob=attention_probs_dropout_prob,
    )

    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True)

    # ----------------------------------------------------------------- metrics
    def create_compute_metrics_fn(eval_dataset):
        """Build the Trainer ``compute_metrics`` callback.

        ``eval_dataset`` is accepted for signature compatibility but is unused.
        The returned function reports token-level precision / recall / F1 over the
        entity tags only, averaged with support weighting ("O" excluded).
        """

        def compute_metrics(eval_pred):
            predictions, labels = eval_pred
            predictions = np.asarray(predictions)
            labels = np.asarray(labels)
            # Accept raw logits (one extra trailing class axis) as well as ids
            # already reduced by preprocess_logits_for_metrics.
            if predictions.ndim == labels.ndim + 1:
                predictions = predictions.argmax(axis=-1)
            # Drop positions labelled -100 (the ignore index FocalLoss also skips).
            keep = labels != -100
            precision, recall, f1, _ = precision_recall_fscore_support(
                labels[keep],
                predictions[keep],
                labels=entity_label_ids,
                average="weighted",
                zero_division=0,
            )
            return {
                "entity_precision": precision,
                "entity_recall": recall,
                "entity_f1": f1,
            }

        return compute_metrics

    compute_metrics = create_compute_metrics_fn(eval_ds)

    def preprocess_logits_for_metrics(logits, labels):
        """Reduce logits to predicted class ids on-device before the Trainer accumulates them.

        Keeps eval memory proportional to the number of tokens instead of
        tokens x num_labels; ``compute_metrics`` yields the same values either way.
        """
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    # -------------------------------------------------------------------- loss
    class FocalLoss(torch.nn.Module):
        """Class-weighted focal loss for token classification.

        Per token: ``alpha[y] * (1 - p) ** gamma * CE`` with ``p = exp(-CE)``, the
        predicted probability of the true class. Positions labelled
        ``ignore_index`` contribute zero.

        Args:
            alpha: optional per-class weight tensor indexed by label id.
            gamma: focusing parameter (higher = more focus on hard examples).
            reduction: "mean" (averaged over non-ignored tokens), "sum", or any
                other value for the unreduced per-token loss.
            ignore_index: label value to skip (e.g. padding).
        """

        def __init__(self, alpha=None, gamma=2.0, reduction="mean", ignore_index=-100):
            super().__init__()
            self.alpha = alpha
            self.gamma = gamma
            self.reduction = reduction
            self.ignore_index = ignore_index

        def forward(self, logits, labels):
            # Flatten (..., num_classes) logits and matching labels to one row per token.
            logits_flat = logits.reshape(-1, logits.size(-1))
            labels_flat = labels.reshape(-1)
            valid_mask = labels_flat != self.ignore_index

            # Unreduced cross entropy; ignored positions come back as 0.
            ce_loss = torch.nn.functional.cross_entropy(
                logits_flat,
                labels_flat,
                reduction="none",
                ignore_index=self.ignore_index,
            )
            p = torch.exp(-ce_loss)
            focal_loss = (1 - p) ** self.gamma * ce_loss

            if self.alpha is not None:
                # Swap ignored labels for 0 so gather stays in range, then zero
                # their weight through the mask.
                safe_labels = labels_flat.masked_fill(~valid_mask, 0)
                alpha_t = self.alpha.gather(0, safe_labels) * valid_mask.float()
                focal_loss = alpha_t * focal_loss

            if self.reduction == "mean":
                return focal_loss.sum() / valid_mask.sum().clamp(min=1)
            if self.reduction == "sum":
                return focal_loss.sum()
            return focal_loss

    class FocalLossTrainer(Trainer):
        """Trainer whose loss is :class:`FocalLoss` with optional class weights."""

        def __init__(self, *args, class_weights=None, gamma=2.0, **kwargs):
            super().__init__(*args, **kwargs)
            self.class_weights = class_weights
            self.gamma = gamma

        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            """Compute focal loss on the model's logits.

            ``num_items_in_batch`` is accepted for compatibility with newer
            transformers versions and is unused.
            """
            labels = inputs.get("labels")
            # Forward without labels so the model does not also compute its own
            # (discarded) cross-entropy loss; ``inputs`` itself is not mutated.
            model_inputs = {key: value for key, value in inputs.items() if key != "labels"}
            outputs = model(**model_inputs)
            logits = outputs.get("logits")

            weights = self.class_weights
            if weights is not None and weights.device != logits.device:
                # Move once and cache instead of copying to the device every step.
                weights = self.class_weights = weights.to(logits.device)

            loss_fct = FocalLoss(alpha=weights, gamma=self.gamma, ignore_index=-100)
            loss = loss_fct(logits, labels)
            return (loss, outputs) if return_outputs else loss

    # --------------------------------------------------------------- arguments
    training_args = TrainingArguments(
        output_dir=str(output),
        # Trainer re-seeds from TrainingArguments.seed, so pass RANDOM_SEED through
        # instead of relying on the manual seeding above.
        seed=seed,
        learning_rate=learning_rate,
        lr_scheduler_type=lr_scheduler_type,
        per_device_train_batch_size=per_device_train_batch_size,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        gradient_accumulation_steps=1,
        logging_steps=50,
        num_train_epochs=num_epochs,
        save_strategy="steps",
        save_steps=save_steps,
        # NOTE(review): MAX_SAVED_CHECKPOINTS only sets early-stopping patience;
        # the on-disk checkpoint limit stays fixed at 3 -- confirm this is intended.
        save_total_limit=3,
        eval_strategy="steps",
        eval_steps=save_steps,
        load_best_model_at_end=True,
        metric_for_best_model="eval_entity_f1",
        greater_is_better=True,
        bf16=True,
        tf32=True,
        report_to="wandb",
        run_name=run_name,
        push_to_hub=True,
        hub_strategy="checkpoint",
        hub_token=hub_token,
        dataloader_persistent_workers=True,
        dataloader_num_workers=2,
        dataloader_pin_memory=True,
        ddp_find_unused_parameters=False,
        gradient_checkpointing=False,
        hub_model_id=model_id,
        hub_private_repo=True,
    )

    # --------------------------------------------------------------- callbacks
    class CustomEarlyStoppingCallback(TrainerCallback):
        """Stop training after ``patience`` evaluations without an eval_entity_f1 gain.

        An evaluation counts as an improvement only if it beats the best F1 seen
        so far by more than ``threshold``. Evaluations without the metric are ignored.
        """

        def __init__(self, patience=2, threshold=0.001):
            self.patience = patience
            self.threshold = threshold
            self.best_metric = None
            self.wait = 0

        def on_evaluate(self, args, state, control, metrics=None, **kwargs):
            if metrics is None or "eval_entity_f1" not in metrics:
                return control
            metric_value = metrics["eval_entity_f1"]

            if self.best_metric is None:
                self.best_metric = metric_value
            elif metric_value > self.best_metric + self.threshold:
                self.best_metric = metric_value
                self.wait = 0
            else:
                self.wait += 1
                if self.wait >= self.patience:
                    control.should_training_stop = True
                    print(f"Early stopping triggered. Best F1: {self.best_metric:.4f}")
            return control

    class WandbExtraConfigCallback(TrainerCallback):
        """Merge extra run metadata into ``wandb.config`` at the start of training.

        With ``report_to="wandb"`` the W&B run is normally created by the Trainer's
        own W&B callback during ``on_train_begin`` (which is registered before
        user callbacks), so ``wandb.run`` is typically still ``None`` right after
        the Trainer is constructed. Deferring the update to this hook ensures it
        is not silently skipped. Only the main process logs.
        """

        def __init__(self, extra_config):
            self.extra_config = extra_config

        def on_train_begin(self, args, state, control, **kwargs):
            if state.is_world_process_zero and wandb.run is not None:
                # Some keys (e.g. dropout, num_labels) may already be present from
                # the model config; allow_val_change avoids a conflict error.
                wandb.config.update(self.extra_config, allow_val_change=True)
            return control

    extra_wandb_config = {
        # Data configuration
        "train_samples": len(train_ds),
        "eval_samples": len(eval_ds),
        "train_size_requested": train_size,
        "test_size_requested": test_size,
        "actual_train_size": len(train_ds),
        "actual_eval_size": len(eval_ds),
        # Model architecture details
        "model_architecture": "deberta-v3-base",
        "num_labels": len(label_list),
        "label_list": label_list,
        # Loss function configuration
        "loss_function": "focal_loss",
        "focal_gamma": gamma,
        "focal_alpha": "weighted",
        "frequency_exponent": frequency_exponent,
        # Dropout configuration
        "hidden_dropout_prob": hidden_dropout_prob,
        "attention_probs_dropout_prob": attention_probs_dropout_prob,
        # Training configuration not in TrainingArguments
        "max_saved_checkpoints": max_saved_checkpoints,
        "early_stopping_patience": patience,
        "early_stopping_threshold": early_stopping_threshold,
        # Environment info
        "cuda_devices": cuda_visible_devices,
        "num_gpus": num_devices,
        # Data processing
        "tokenizer": base_model,
        # Experiment metadata
        "experiment_type": "ner_fiction",
        "data_source": "gutenberg_ao3_mixed",
        "random_seed": seed,
    }

    # ----------------------------------------------------------------- trainer
    trainer = FocalLossTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        class_weights=weight_tensor,
        gamma=gamma,
        callbacks=[
            CustomEarlyStoppingCallback(patience=patience, threshold=early_stopping_threshold),
            WandbExtraConfigCallback(extra_wandb_config),
        ],
    )

    # Resume from the newest "checkpoint-<step>" directory when one exists. The
    # output directory may not exist yet on a fresh run, so check before listing
    # it. If only a "last-checkpoint" folder is present (the name used by
    # hub_strategy="checkpoint"), resume from that instead.
    last_checkpoint = None
    if os.path.isdir(output_dir):
        last_checkpoint = get_last_checkpoint(output_dir)
        hub_copy = os.path.join(output_dir, "last-checkpoint")
        if last_checkpoint is None and os.path.isdir(hub_copy):
            last_checkpoint = hub_copy
    trainer.train(resume_from_checkpoint=last_checkpoint)

# Start train_fn in `num_devices` worker processes (one per CUDA_VISIBLE_DEVICES entry).
notebook_launcher(train_fn, args=(), num_processes=num_devices)