In [None]:
%pip install sacrebleu sentencepiece torch datasets==3.6.0 scipy tqdm numpy tensorboard optuna

In [None]:
# imports

from __future__ import annotations

import json
import math
import random
import time
from datetime import timedelta
from pathlib import Path
from typing import List, Tuple

import numpy as np
import sacrebleu
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as tud
from torch.optim.lr_scheduler import StepLR
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
from scipy import stats
from tqdm.auto import tqdm
import optuna

In [None]:
# Model definitions

class LuongAttention(nn.Module):
    """Parameter-free scaled dot-product attention (Luong "dot" scoring).

    Scores are ``query @ keys^T / sqrt(hidden_size)``; a softmax over the key
    axis turns them into weights used to average ``values``.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.scale = 1.0 / math.sqrt(hidden_size)

    def forward(self, query, keys, values, mask=None):
        """Attend from ``query`` over ``keys`` / ``values``.

        Args:
            query:  (B, Q, H) tensor of query vectors.
            keys:   (B, T, H) tensor scored against the queries.
            values: (B, T, H) tensor averaged with the attention weights.
            mask:   optional (B, T) bool tensor, True at positions that may
                    be attended to; False positions get (near) zero weight.

        Returns:
            ``(context, attn)`` where ``context`` is (B, Q, H) and ``attn`` is
            (B, Q, T), squeezed to (B, T) when Q == 1.
        """
        scores = torch.bmm(query, keys.transpose(1, 2)) * self.scale  # (B, Q, T)
        if mask is not None:
            # Use the most negative value representable in the score dtype
            # instead of a hard-coded -1e9, which cannot be stored in float16.
            # The fill is out-of-place so `scores` is never mutated in place.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(~mask[:, None, :], fill_value)
        attn = torch.softmax(scores, dim=-1)  # (B, Q, T)
        context = torch.bmm(attn, values)  # (B, Q, H)
        return context, attn.squeeze(1)

class BiLSTMTranslator(nn.Module):
    """
    Bidirectional LSTM encoder + unidirectional LSTM decoder with Luong
    global attention. Source and target share one embedding table.

    The final forward & backward encoder states of the top layer are
    concatenated, then replicated across decoder layers so the initial
    (h_0, c_0) have shape (num_layers, batch, hidden_size), as required by
    nn.LSTM.

    Args:
        vocab_size:  number of rows in the shared embedding / output layer.
        emb_size:    embedding dimension.
        hidden_size: decoder hidden size; each encoder direction uses
                     ``hidden_size // 2`` so it must be even.
        num_layers:  number of stacked LSTM layers in encoder and decoder.
        dropout:     dropout on embeddings and between stacked LSTM layers.
        **kwargs:    ignored; lets callers pass a full hyper-parameter dict.

    Raises:
        ValueError: if ``hidden_size`` is odd.
    """

    def __init__(
        self,
        # These arguments will be supplied by Optuna. Values here are placeholders
        vocab_size: int,
        emb_size: int = 512,
        hidden_size: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
        **kwargs: dict,
    ):
        super().__init__()
        if hidden_size % 2 != 0:
            # The two encoder directions are concatenated to form the decoder
            # state, which only lines up when hidden_size is even.
            raise ValueError(f"hidden_size must be even, got {hidden_size}")

        # nn.LSTM applies dropout only *between* stacked layers and warns when
        # a non-zero value is requested for a single-layer network.
        rnn_dropout = dropout if num_layers > 1 else 0.0

        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=0)
        self.encoder = nn.LSTM(
            input_size=emb_size,
            hidden_size=hidden_size // 2,
            num_layers=num_layers,
            dropout=rnn_dropout,
            bidirectional=True,
            batch_first=True,
        )

        self.decoder = nn.LSTM(
            input_size=emb_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=rnn_dropout,
            batch_first=True,
        )

        self.attn = LuongAttention(hidden_size)
        self.out = nn.Linear(hidden_size * 2, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_lens, tgt):
        """Compute next-token logits for every target position.

        Args:
            src:      (B, S) padded source token ids (0 = padding).
            src_lens: (B,) true source lengths.
            tgt:      (B, T) decoder input token ids.

        Returns:
            (B, T, vocab_size) logits.
        """
        # encoder
        emb_src = self.dropout(self.embedding(src))
        packed_src = nn.utils.rnn.pack_padded_sequence(
            emb_src, src_lens.cpu(), batch_first=True, enforce_sorted=False
        )
        enc_out, (h_enc, c_enc) = self.encoder(packed_src)
        enc_out, _ = nn.utils.rnn.pad_packed_sequence(enc_out, batch_first=True)
        # h_enc & c_enc: (num_layers*2, batch, hidden_size//2)

        # Concatenate last forward & backward states -> (batch, hidden_size)
        h_final = torch.cat([h_enc[-2], h_enc[-1]], dim=-1)
        c_final = torch.cat([c_enc[-2], c_enc[-1]], dim=-1)

        # Expand to match decoder layers: (num_layers, batch, hidden_size)
        num_dec_layers = self.decoder.num_layers
        h0 = h_final.unsqueeze(0).repeat(num_dec_layers, 1, 1)
        c0 = c_final.unsqueeze(0).repeat(num_dec_layers, 1, 1)

        # decoder
        emb_tgt = self.dropout(self.embedding(tgt))
        dec_out, _ = self.decoder(emb_tgt, (h0, c0))  # (B, T, H)

        # Padded encoder positions are all-zero vectors; without a mask they
        # would score 0 and still receive attention weight, making the output
        # depend on how much padding the batch happens to contain.
        positions = torch.arange(enc_out.size(1), device=enc_out.device)
        src_mask = positions[None, :] < src_lens.to(enc_out.device)[:, None]  # (B, S)

        # attention
        context, _ = self.attn(dec_out, enc_out, enc_out, mask=src_mask)  # (B, T, H)
        concat = torch.cat([dec_out, context], dim=-1)  # (B, T, 2H)
        logits = self.out(concat)  # (B, T, V)
        return logits


class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding as described in "Attention Is All You Need".

    Even feature columns hold ``sin(pos / 10000^(2i/d_model))`` and odd columns
    the matching ``cos``. The table is precomputed for ``max_len`` positions and
    added to the input, followed by dropout.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Constant "pe" matrix whose values depend only on position and column.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so the angle table is trimmed to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # shape (1, max_len, d_model)

        # Register as buffer so it's saved/loaded but not trained
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape (batch_size, seq_len, d_model); seq_len must not
               exceed ``max_len``.
        Returns:
            Tensor of the same shape with positional information added.
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class TransformerTranslator(nn.Module):
    """Encoder-decoder Transformer that shares one embedding table between
    source and target, with sinusoidal positions and a linear output layer."""

    def __init__(
        self,
        # Placeholder defaults; the tuning code supplies the real values.
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 4,
        dropout: float = 0.1,
        max_len: int = 5000,
        **kwargs
    ):
        super().__init__()
        self.d_model = d_model

        # Shared token embedding (id 0 = padding) and positional encoding.
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_enc = PositionalEncoding(d_model, dropout, max_len)

        # Encoder and decoder layers use the same layer hyper-parameters.
        layer_args = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_args), num_layers=num_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(**layer_args), num_layers=num_layers
        )

        # Projection from model space to vocabulary logits.
        self.out = nn.Linear(d_model, vocab_size)

    def forward(
        self,
        src: torch.Tensor,
        src_lens,
        tgt: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            src: (batch_size, src_seq_len) token ids, 0 = padding
            src_lens: unused; accepted so both translators share a signature
            tgt: (batch_size, tgt_seq_len) decoder input ids, 0 = padding
        Returns:
            logits: (batch_size, tgt_seq_len, vocab_size)
        """
        embed_scale = math.sqrt(self.d_model)

        # True marks padding positions that attention must ignore.
        src_pad = src == 0
        tgt_pad = tgt == 0

        src_states = self.pos_enc(self.embedding(src) * embed_scale)
        tgt_states = self.pos_enc(self.embedding(tgt) * embed_scale)

        memory = self.encoder(src_states, src_key_padding_mask=src_pad)

        # Boolean causal mask: True strictly above the diagonal, i.e. a target
        # position may not attend to later positions.
        steps = torch.arange(tgt.size(1), device=src.device)
        future = steps[None, :] > steps[:, None]

        decoded = self.decoder(
            tgt_states,
            memory,
            tgt_mask=future,
            tgt_key_padding_mask=tgt_pad,
            memory_key_padding_mask=src_pad,
        )
        return self.out(decoded)


In [None]:
# Data loading & utilities

def set_seed(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and switch
    cuDNN to deterministic, non-benchmarking mode."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


class Timer:
    """Context manager recording wall-clock ``start``, ``end`` and ``elapsed``
    (seconds); ``end`` and ``elapsed`` are only set once the block exits."""

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        stopped_at = time.time()
        self.end = stopped_at
        self.elapsed = stopped_at - self.start


# Special-token piece names. These are SentencePiece's default names for the
# ids configured in train_sentencepiece (pad=0, unk=1, bos=2, eos=3).
# NOTE(review): the previous values were four identical empty strings, which
# looks like the angle-bracket names were lost — confirm against the original.
BOS, EOS, PAD, UNK = "<s>", "</s>", "<pad>", "<unk>"


def download_iwslt17_de_en(data_dir: Path) -> Tuple[Path, Path, Path]:
    """Load IWSLT-2017 de-en and write each split as a ``de<TAB>en`` TSV file.

    Args:
        data_dir: directory that receives ``train.tsv``, ``validation.tsv``
                  and ``test.tsv`` (created if missing).

    Returns:
        Paths of the train, validation and test files, in that order.
    """
    # The files hold one pair per line with a tab between the two sides, so a
    # tab or line break inside a sentence would corrupt the format.
    to_single_line = str.maketrans({"\t": " ", "\n": " ", "\r": " "})

    data_dir.mkdir(parents=True, exist_ok=True)
    dataset = load_dataset("iwslt2017", "iwslt2017-de-en")
    splits = {}
    for split in ("train", "validation", "test"):
        lines = [
            f"{ex['translation']['de'].translate(to_single_line)}"
            f"\t{ex['translation']['en'].translate(to_single_line)}"
            for ex in dataset[split]
        ]
        out_path = data_dir / f"{split}.tsv"
        out_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
        splits[split] = out_path
    return splits["train"], splits["validation"], splits["test"]


def train_sentencepiece(input_paths: List[Path], model_prefix: str, vocab_size: int = 8000) -> Path:
    """Train a BPE SentencePiece model on the concatenation of ``input_paths``.

    The inputs are copied into a temporary ``<model_prefix>_corpus.txt`` that
    is removed afterwards, including when training raises.

    Args:
        input_paths:  UTF-8 text files to train on.
        model_prefix: output prefix passed to SentencePiece.
        vocab_size:   target vocabulary size.

    Returns:
        Path of the form ``<model_prefix>.model``.
    """
    tmp = Path(f"{model_prefix}_corpus.txt")
    try:
        # Stream one file at a time (newline-separated, as before) instead of
        # building the whole corpus as a single in-memory string.
        with tmp.open("w", encoding="utf-8") as corpus:
            for index, path in enumerate(input_paths):
                if index:
                    corpus.write("\n")
                corpus.write(path.read_text(encoding="utf-8"))
        spm.SentencePieceTrainer.train(
            input=str(tmp), model_prefix=model_prefix, vocab_size=vocab_size,
            character_coverage=1.0, model_type="bpe",
            pad_id=0, unk_id=1, bos_id=2, eos_id=3, user_defined_symbols=""  # PAD, UNK, BOS, EOS
        )
    finally:
        # cleanup, even if training failed part-way
        if tmp.exists():
            tmp.unlink()
    return Path(f"{model_prefix}.model")


def encode_file(sp: "spm.SentencePieceProcessor", in_path: Path, out_path: Path) -> None:
    """Tokenize a ``source<TAB>target`` file into space-joined subword pieces.

    Each output line is ``src pieces<TAB>tgt pieces``. Lines without a tab
    (blank or malformed) are skipped. A pair with an empty target is kept, and
    any further tabs inside the target are treated as spaces.

    Args:
        sp:       tokenizer providing ``encode(text, out_type=str)``.
        in_path:  UTF-8 TSV file to read.
        out_path: UTF-8 TSV file to write (overwritten).
    """
    with in_path.open("r", encoding="utf-8") as fi, out_path.open("w", encoding="utf-8") as fo:
        for line in fi:
            # Strip only the line terminator: a bare rstrip() would also eat
            # the separating tab when the target side is empty.
            src, sep, tgt = line.rstrip("\r\n").partition("\t")
            if not sep:
                continue
            tgt = tgt.replace("\t", " ").rstrip()
            pieces_src = sp.encode(src, out_type=str)
            pieces_tgt = sp.encode(tgt, out_type=str)
            fo.write(" ".join(pieces_src) + "\t" + " ".join(pieces_tgt) + "\n")


class ParallelDataset(tud.Dataset):
    """In-memory dataset of (source_ids, target_ids) LongTensor pairs.

    Reads a pre-tokenized TSV (space-separated pieces, tab between source and
    target), maps pieces to ids and wraps each side in BOS ... EOS. Pairs whose
    source or target exceeds ``max_len`` ids (BOS/EOS included) are dropped,
    as are lines without a tab.
    """

    def __init__(self, path: Path, sp: "spm.SentencePieceProcessor", max_len: int = 100):
        self.samples = []
        BOS_ID, EOS_ID = sp.bos_id(), sp.eos_id()

        with path.open("r", encoding="utf-8") as fh:
            for ln in fh:
                # Strip only the line terminator: a bare rstrip() would also
                # remove the tab of a pair with an empty target and make the
                # unpacking below fail.
                src_txt, sep, tgt_txt = ln.rstrip("\r\n").partition("\t")
                if not sep:
                    continue

                # Tokens already split, just convert to IDs directly
                src_ids = [BOS_ID] + sp.piece_to_id(src_txt.split()) + [EOS_ID]
                tgt_ids = [BOS_ID] + sp.piece_to_id(tgt_txt.split()) + [EOS_ID]

                if len(src_ids) <= max_len and len(tgt_ids) <= max_len:
                    self.samples.append(
                        (torch.LongTensor(src_ids), torch.LongTensor(tgt_ids))
                    )

    def __len__(self):
        """Number of retained sentence pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(source_ids, target_ids)`` pair at ``idx``."""
        return self.samples[idx]


def collate_fn(batch):
    """Right-pad a list of (src, tgt) id tensors with zeros.

    Returns ``(src_padded, src_lengths, tgt_padded, tgt_lengths)``; the padded
    tensors are int64 with shape (batch, longest sequence).
    """

    def _pad(sequences):
        lengths = [len(seq) for seq in sequences]
        padded = torch.zeros(len(batch), max(lengths), dtype=torch.long)
        for row, (seq, length) in enumerate(zip(sequences, lengths)):
            padded[row, :length] = seq
        return padded, torch.tensor(lengths)

    sources, targets = zip(*batch)
    src_padded, src_lengths = _pad(sources)
    tgt_padded, tgt_lengths = _pad(targets)
    return src_padded, src_lengths, tgt_padded, tgt_lengths


def get_noam_scheduler(optimizer, d_model, warmup_steps, lr_scale=1.0):
    """Wrap ``optimizer`` in a LambdaLR using the inverse-square-root ("Noam")
    schedule.

    With ``t = step + 1`` the base learning rate is multiplied by
    ``lr_scale * d_model**-0.5 * min(t**-0.5, t * warmup_steps**-1.5)``.
    """
    model_scale = d_model ** -0.5

    def noam_factor(step):
        t = step + 1
        decay = t ** -0.5
        warmup = t * warmup_steps ** -1.5
        return lr_scale * model_scale * min(decay, warmup)

    return LambdaLR(optimizer, noam_factor)

In [None]:
# Training & Evaluation

def label_smoothing_loss(logits, targets, pad_idx: int = 0, smoothing: float = 0.1):
    """
    Cross-entropy with uniform label smoothing, averaged over non-pad tokens.

    The target distribution puts ``1 - smoothing`` on the gold token and
    spreads ``smoothing`` uniformly over the whole vocabulary, so per token
    ``loss = (1 - smoothing) * NLL(gold) + smoothing * mean_v(-log p_v)``.

    Args
        logits    : (B, T, V) - raw scores from the model
        targets   : (B, T)    - ground-truth token IDs
        pad_idx   : target id that marks padding and is excluded from the loss
        smoothing : probability mass moved from the gold token to the prior

    Returns
        Scalar tensor. If every target is padding, a zero that is still
        attached to ``logits`` is returned instead of NaN.
    """
    vocab = logits.size(-1)

    logits_flat = logits.contiguous().view(-1, vocab)  # (B*T, V)
    targets_flat = targets.contiguous().view(-1)  # (B*T)

    log_probs = F.log_softmax(logits_flat, dim=-1)

    # Negative log-likelihood of the gold token
    nll = -log_probs.gather(1, targets_flat.unsqueeze(1)).squeeze(1)
    # Cross-entropy against the uniform prior. This is the term that gives
    # smoothing its regularising gradient; a constant offset would not.
    uniform_ce = -log_probs.mean(dim=-1)

    loss = (1.0 - smoothing) * nll + smoothing * uniform_ce

    # Remove padding positions
    keep = targets_flat != pad_idx
    if not bool(keep.any()):
        return logits_flat.sum() * 0.0

    return loss[keep].mean()


def train_epoch(model, iterator, optimizer, device, scheduler, clip_norm=1.0):
    """Run one pass over ``iterator`` with teacher forcing.

    For every batch: forward, label-smoothed loss, backward, gradient-norm
    clipping to ``clip_norm``, one optimizer step and one scheduler step.

    Returns:
        Sum of (batch loss * batch size) divided by ``len(iterator.dataset)``.
    """
    model.train()
    weighted_loss_sum = 0.0

    progress = tqdm(iterator, desc="Train batches", leave=False)
    for src, src_lens, tgt, _ in progress:
        src = src.to(device)
        src_lens = src_lens.to(device)
        # Decoder input drops the last token; the expected output drops the first.
        decoder_in = tgt[:, :-1].to(device)
        decoder_gold = tgt[:, 1:].to(device)

        optimizer.zero_grad()
        batch_loss = label_smoothing_loss(model(src, src_lens, decoder_in), decoder_gold)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()
        scheduler.step()

        weighted_loss_sum += batch_loss.item() * src.size(0)

    return weighted_loss_sum / len(iterator.dataset)


def greedy_translate_ids(model, sp, src, src_lens, device, max_len: int = 60):
    """
    Greedy decoding: repeatedly append the arg-max token for every sentence.

    src      : (B, S) padded batch on CPU or GPU
    src_lens : (B,)   true lengths
    max_len  : maximum number of tokens generated per sentence
    returns  : List[List[int]] – one id list per sentence, without BOS and
               cut before the first EOS
    """
    bos_id = sp.bos_id()
    eos_id = sp.eos_id()
    model.eval()

    with torch.no_grad():
        src = src.to(device)
        src_lens = src_lens.to(device)
        batch_size = src.size(0)

        decoded = torch.full((batch_size, 1), bos_id, dtype=torch.long, device=device)
        done = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _step in range(max_len):
            # The model is re-run on the whole prefix; only the newest
            # position's scores are needed.
            last_scores = model(src, src_lens, decoded)[:, -1]
            choice = last_scores.argmax(-1, keepdim=True)

            decoded = torch.cat([decoded, choice], dim=1)
            done |= choice.squeeze(1) == eos_id
            if done.all():
                break

    def _strip(row):
        # Drop the leading BOS and everything from the first EOS onwards.
        body = row[1:]
        return body[: body.index(eos_id)] if eos_id in body else body

    return [_strip(row) for row in decoded.tolist()]

def beam_translate_ids(model, sp, src, src_lens, device, max_len: int = 60, beam_width: int = 4):
    """
    Batched beam search that re-runs the full model on every step (no
    separate encode/decode methods are required).

    A hypothesis that has produced EOS is frozen: its only continuation is
    another EOS at zero cost, so its score stays fixed and it competes fairly
    with hypotheses that are still growing. Scores are summed token
    log-probabilities without length normalisation.

    Args:
        src        : (B, S) padded source batch on CPU or GPU
        src_lens   : (B,)   true source lengths
        max_len    : maximum number of decoding steps
        beam_width : number of hypotheses kept per sentence

    Returns:
        List[List[int]] – best hypothesis per sentence, without BOS and cut
        before the first EOS.
    """
    BOS, EOS = sp.bos_id(), sp.eos_id()
    model.eval()

    with torch.no_grad():
        src, src_lens = src.to(device), src_lens.to(device)
        B = src.size(0)

        # Repeat source inputs so that every beam has its own row
        src = src.repeat_interleave(beam_width, dim=0)
        src_lens = src_lens.repeat_interleave(beam_width, dim=0)

        # Initialize target tokens with BOS
        tgt = torch.full((B * beam_width, 1), BOS, dtype=torch.long, device=device)
        beam_scores = torch.zeros(B, beam_width, device=device)
        beam_scores[:, 1:] = -1e9  # Initially deactivate all beams except first
        beam_scores = beam_scores.view(-1)

        finished = torch.zeros(B * beam_width, dtype=torch.bool, device=device)
        # Offset of each sentence's first beam in the flattened (B*beam) layout
        batch_offsets = torch.arange(B, device=src.device).unsqueeze(1) * beam_width

        for _ in range(max_len):
            logits = model(src, src_lens, tgt)  # (B*beam_width, T, V)
            log_probs = F.log_softmax(logits[:, -1, :], dim=-1)  # (B*beam_width, V)
            vocab = log_probs.size(-1)

            # Freeze finished hypotheses: only "EOS again" is allowed and it is
            # free. Otherwise they would keep accumulating the log-probability
            # of whatever the model emits after EOS and be unfairly penalised.
            if finished.any():
                log_probs = log_probs.masked_fill(finished.unsqueeze(1), float("-inf"))
                log_probs[finished, EOS] = 0.0

            scores = beam_scores.unsqueeze(1) + log_probs  # (B*beam_width, V)
            scores = scores.view(B, -1)  # (B, beam_width*V)

            top_scores, top_ids = scores.topk(beam_width, dim=-1)  # (B, beam_width)

            beam_indices = top_ids // vocab
            token_indices = top_ids % vocab

            # Row of the parent hypothesis for every surviving candidate
            parents = (beam_indices + batch_offsets).view(-1)

            # Reorder beams and append the chosen tokens
            tgt = torch.cat([tgt[parents], token_indices.view(-1, 1)], dim=-1)
            beam_scores = top_scores.view(-1)

            # The finished flag belongs to a hypothesis, not to a beam slot,
            # so it has to follow the same reordering as `tgt`.
            finished = finished[parents] | (token_indices.view(-1) == EOS)
            if finished.all():
                break

        # Choose best beams
        tgt = tgt.view(B, beam_width, -1)
        best_beam = beam_scores.view(B, beam_width).argmax(dim=-1)
        best_seqs = tgt[torch.arange(B, device=src.device), best_beam]

        out = []
        for seq in best_seqs.tolist():
            if EOS in seq:
                seq = seq[1:seq.index(EOS)]
            else:
                seq = seq[1:]
            out.append(seq)

    return out


def evaluate(model, data_iter, sp, device):
    """Translate every batch with beam search and score it against the references.

    Returns:
        ``(bleu, chrf)`` corpus-level sacreBLEU scores.
    """
    hypotheses, references = [], []

    for src, src_lens, tgt, tgt_lens in tqdm(data_iter, desc="Evaluate", leave=False):
        # Generate for the whole batch, then detokenize each hypothesis.
        for ids in beam_translate_ids(model, sp, src, src_lens, device):
            hypotheses.append(sp.decode(ids))

        # References: cut each padded row to its true length and drop BOS/EOS.
        for row, length in zip(tgt, tgt_lens):
            references.append(sp.decode(row[1:length - 1].tolist()))

    assert len(hypotheses) == len(references), "Mismatch between #hypotheses and #references!"

    bleu_score = sacrebleu.corpus_bleu(hypotheses, [references]).score
    chrf_score = sacrebleu.corpus_chrf(hypotheses, [references]).score
    return bleu_score, chrf_score



In [None]:
# Training-set sizes (sentence pairs) the experiment sweeps over.
sizes = [10_000, 50_000, 75_000, 100_000, 150_000, 200_000]

# Per-size tuning budget: number of Optuna trials and epochs per trial.
trials_per_size = dict(zip(sizes, (15, 20, 20, 25, 25, 25)))
epochs_per_size = dict(zip(sizes, (4, 5, 5, 8, 10, 10)))

# Set by the tuning loop; the Optuna objective reads it to look up its epochs.
current_size = None

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Working directories: raw/encoded data, Optuna studies, checkpoints, TensorBoard logs.
data_dir = Path("data")
data_dir.mkdir(parents=True, exist_ok=True)

tune_dir = Path("tune")
tune_dir.mkdir(parents=True, exist_ok=True)

train_dir = Path("train")
train_dir.mkdir(parents=True, exist_ok=True)

log_dir = Path("logs")
log_dir.mkdir(parents=True, exist_ok=True)

print("Downloading IWSLT-2017...")
train_raw, valid_raw, test_raw = download_iwslt17_de_en(data_dir)

print("Training SentencePiece model...")
spm_path = train_sentencepiece([train_raw], str(data_dir / "bpe8k"), vocab_size=8000)
sp = spm.SentencePieceProcessor(model_file=str(spm_path))

# Pre-encode full corpus once to speed up later sampling
print("Encoding full corpus... (this may take a minute)")
encoded_train = data_dir / "train.bpe.tsv"
encode_file(sp, train_raw, encoded_train)
encode_file(sp, valid_raw, data_dir / "valid.bpe.tsv")
encode_file(sp, test_raw, data_dir / "test.bpe.tsv")

# Read and shuffle the encoded corpus once; every subset is a prefix of the
# same seeded permutation. The file is written as UTF-8, so it is read back
# as UTF-8 explicitly. Lines are split on "\n" only, because str.splitlines()
# would also break on Unicode separators inside a sentence.
with encoded_train.open("r", encoding="utf-8") as fh:
    pairs = [ln.rstrip("\n") for ln in fh]
random.Random(42).shuffle(pairs)

for size in sizes:
    # Down-sample deterministically for reproducibility
    subset_path = data_dir / f"train_{size}.bpe.tsv"
    subset_path.write_text("\n".join(pairs[: size]) + "\n", encoding="utf-8")

In [None]:
def suggest_bilstm_params(trial: "optuna.Trial") -> dict:
    """
    Sample one Bi-LSTM configuration from a compact search space.

    Covers the architecture, the Adam optimizer settings (plus gradient
    clipping) and the StepLR schedule. Batch size is fixed by the caller.
    """
    params = {}

    # Architecture parameters
    params["emb_size"] = trial.suggest_int("emb_size", 128, 512, step=64)
    params["hidden_size"] = trial.suggest_int("hidden_size", 256, 1024, step=128)
    params["num_layers"] = trial.suggest_int("num_layers", 1, 3)
    params["dropout"] = trial.suggest_float("dropout", 0.1, 0.3)

    # Optimizer parameters
    params["lr"] = trial.suggest_float("lr", 3e-4, 5e-3, log=True)
    params["weight_decay"] = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
    params["clip_norm"] = trial.suggest_float("clip_norm", 0.1, 1.0)
    params["eps"] = trial.suggest_float("eps", 1e-9, 1e-6, log=True)
    params["beta1"] = trial.suggest_float("beta1", 0.8, 0.99, step=0.01)
    params["beta2"] = trial.suggest_float("beta2", 0.9, 0.999, step=0.001)

    # Scheduler parameters
    params["scheduler_step_size"] = trial.suggest_int("scheduler_step_size", 10, 40)
    params["scheduler_gamma"] = trial.suggest_float("scheduler_gamma", 0.7, 0.9)

    return params


def suggest_transformer_params(trial: "optuna.Trial") -> dict:
    """
    Sample one Transformer configuration from a compact search space.

    Covers the architecture, the Adam optimizer settings (plus gradient
    clipping) and the warm-up schedule. ``lr`` is a fixed 1.0 and is not a
    trial parameter: the effective rate comes from the schedule, scaled by
    ``lr_scale``.
    """
    params = {}

    # Architecture parameters
    params["d_model"] = trial.suggest_int("d_model", 256, 512, step=128)
    params["nhead"] = trial.suggest_categorical("nhead", [4, 8])
    params["num_layers"] = trial.suggest_int("num_layers", 2, 4)
    params["dropout"] = trial.suggest_float("dropout", 0.1, 0.3)

    # Optimizer parameters
    params["lr"] = 1.0
    params["lr_scale"] = trial.suggest_float("lr_scale", 0.2, 2.0, log=True)
    params["weight_decay"] = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
    params["clip_norm"] = trial.suggest_float("clip_norm", 0.1, 1.0)
    params["eps"] = trial.suggest_float("eps", 1e-9, 1e-6, log=True)
    params["beta1"] = trial.suggest_float("beta1", 0.8, 0.99, step=0.01)
    params["beta2"] = trial.suggest_float("beta2", 0.9, 0.999, step=0.001)

    # Schedule parameters
    params["warmup_steps"] = trial.suggest_int("warmup_steps", 400, 800, step=50)

    return params


def make_objective(model_class, train_iter, valid_iter):
    """Build the Optuna objective for ``model_class``.

    The returned ``objective(trial)`` samples hyper-parameters, trains for
    ``epochs_per_size[current_size]`` epochs, reports validation BLEU after
    each epoch so the pruner can stop weak trials, and returns the last BLEU.
    It reads the module-level ``current_size``, ``device``, ``sp`` and
    ``epochs_per_size`` at call time.

    Args:
        model_class: BiLSTMTranslator or TransformerTranslator.
        train_iter:  DataLoader used for training.
        valid_iter:  DataLoader used for validation BLEU.
    """
    def objective(trial):
        is_bilstm = (model_class.__name__ == 'BiLSTMTranslator')
        params = suggest_bilstm_params(trial) if is_bilstm else suggest_transformer_params(trial)

        # Build model
        if is_bilstm:
            # emb_size must be passed here: it is a tuned parameter and the
            # final model is built from the best trial's parameters, so the
            # tuned and the trained architecture would otherwise differ.
            model = model_class(8000,
                                emb_size=params['emb_size'],
                                hidden_size=params['hidden_size'],
                                num_layers=params['num_layers'],
                                dropout=params['dropout'])
        else:
            model = model_class(8000,
                                d_model=params['d_model'],
                                nhead=params['nhead'],
                                num_layers=params['num_layers'],
                                dropout=params['dropout'])

        model.to(device)
        optimizer = optim.Adam(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'], eps=params['eps'], betas=(params['beta1'], params['beta2']))
        if is_bilstm:
            scheduler = StepLR(optimizer, step_size=params['scheduler_step_size'], gamma=params['scheduler_gamma'])
        else:
            scheduler = get_noam_scheduler(optimizer, params['d_model'], warmup_steps=params['warmup_steps'], lr_scale=params['lr_scale'])

        max_epochs = epochs_per_size[current_size]
        bleu = 0.0  # defined even if max_epochs is 0
        for epoch in range(1, max_epochs + 1):
            train_epoch(model, train_iter, optimizer, device, scheduler, clip_norm=params['clip_norm'])
            bleu, _ = evaluate(model, valid_iter, sp, device)
            trial.report(bleu, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
        return bleu
    return objective

# One seed for Python/NumPy/PyTorch and for the Optuna sampler, so that a
# rerun explores the same trials and trains from the same initial weights.
SEED = 42

pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=2)
# size -> best params of the model tuned most recently (kept for existing
# readers); the second model overwrites the first one's entries.
best_params = {}
# model name -> {size -> best params}; nothing is overwritten here.
best_params_by_model = {}

# Validation and test data do not depend on the model or the training size,
# so they are parsed once instead of once per (model, size) combination.
valid_ds = ParallelDataset(data_dir / "valid.bpe.tsv", sp)
test_ds = ParallelDataset(data_dir / "test.bpe.tsv", sp)

for model_class in [BiLSTMTranslator, TransformerTranslator]:
    model_name = model_class.__name__
    prev_params = None
    best_params_by_model[model_name] = {}
    print(f"\nTuning {model_name} across dataset sizes...")

    batch_size = 2048 if model_class == BiLSTMTranslator else 1024
    max_steps = 2000 if model_class == BiLSTMTranslator else 4000
    print(f"Using\t Batch size (train & tune): {batch_size}\t Max steps (train): {max_steps}")

    valid_iter = tud.DataLoader(valid_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    test_iter = tud.DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    for size in sizes:
        current_size = size  # read by the Optuna objective
        print(f"\nDataset size: {size}")
        set_seed(SEED)

        # Load data slice
        train_ds = ParallelDataset(data_dir / f"train_{size}.bpe.tsv", sp)
        train_iter = tud.DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

        # Create study (raises if a study with this name already exists in the DB)
        study_name = f"{model_name}_{size}"
        study = optuna.create_study(
            storage=f'sqlite:///{tune_dir / f"{study_name}.db"}',
            direction='maximize',
            sampler=optuna.samplers.TPESampler(seed=SEED),
            pruner=pruner,
            study_name=study_name,
            load_if_exists=False,
        )

        # Warm-start with previous size's best
        if prev_params:
            study.enqueue_trial(prev_params)

        # Optimize
        study.optimize(make_objective(model_class, train_iter, valid_iter), n_trials=trials_per_size[size], gc_after_trial=True, show_progress_bar=True)

        # Record and carry forward
        best = study.best_params
        best_params[size] = best
        best_params_by_model[model_name][size] = best
        prev_params = best
        # write best params to file
        with open(tune_dir / f"{model_name}_{size}_best.json", "w") as f:
            json.dump(best, f, indent=2)

        print(f"Tuning done for {model_name} @ {size}!")
        print(f"Now training best model")

        # Re-seed so the final run does not depend on how much randomness
        # the tuning trials consumed.
        set_seed(SEED)
        model = model_class(8000, **best)
        model.to(device)

        # The Transformer search space has no tunable "lr" (fixed at 1.0 and
        # scaled by the schedule), hence the default.
        optimizer = optim.Adam(model.parameters(), lr=best.get('lr', 1.0), weight_decay=best['weight_decay'], eps=best['eps'], betas=(best['beta1'], best['beta2']))

        if model_class == TransformerTranslator:
            scheduler = get_noam_scheduler(optimizer, best['d_model'], warmup_steps=best['warmup_steps'], lr_scale=best['lr_scale'])
        else:
            scheduler = StepLR(optimizer, step_size=best['scheduler_step_size'], gamma=best['scheduler_gamma'])

        step = 0
        # Start below any possible score so the first evaluation always
        # produces a "best" checkpoint, even if BLEU is exactly 0.
        best_bleu = float("-inf")
        best_chrf = float("-inf")
        start_time = time.time()
        p_bar = tqdm(leave=False, dynamic_ncols=True, desc="Training", unit="it", total=max_steps)
        write = SummaryWriter(log_dir / f"{model_name}_{size}")
        cfg_name = f"{model_name}_{size}"
        with Timer() as run_timer:
            # Train whole epochs until the step budget or one hour is used up.
            while step < max_steps and (time.time() - start_time) < 3600:
                epoch_loss = train_epoch(model, train_iter, optimizer, device, scheduler, clip_norm=best['clip_norm'])
                step += len(train_iter)
                p_bar.set_postfix(loss=epoch_loss)
                p_bar.update(len(train_iter))
                bleu, chrf = evaluate(model, valid_iter, sp, device)
                if bleu > best_bleu:
                    # Keep the ChrF that belongs to the best-BLEU checkpoint.
                    best_bleu, best_chrf = bleu, chrf
                    torch.save(model.state_dict(), train_dir / f"{cfg_name}_best.pt")
                    with open(train_dir / f"{cfg_name}_best_num_steps.txt", "w") as f:
                        f.write(str(step))
                write.add_scalar("loss/train", epoch_loss, step)
                write.add_scalar("bleu/valid", bleu, step)
                write.add_scalar("chrf/valid", chrf, step)
                print(f"[{cfg_name}] step={step} loss={epoch_loss:.3f} BLEU={bleu:.2f} ChrF={chrf:.2f}")

        write.add_hparams(best, {"bleu": best_bleu, "chrf": best_chrf, "steps": step, "time": run_timer.elapsed})
        write.close()
        torch.save(model.state_dict(), train_dir / f"{cfg_name}_final.pt")
        print(f"Training complete for {model_name} @ {size} in {timedelta(seconds=run_timer.elapsed)}")
        print(f"Best BLEU: {best_bleu:.2f}, ChrF: {best_chrf:.2f}")
        p_bar.close()

In [None]:
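# # Welch T test (disabled). NOTE: as written this cannot work: best_params is keyed by size only, so both
# # models would be built from the same parameter dict, and with nobs=1 and std=0 the Welch statistic is undefined (NaN).
# # A meaningful test needs several BLEU samples per model (e.g. multiple seeds or bootstrap resamples).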
# for size in sizes:
# bilstm_path = train_dir / f"BiLSTMTranslator_{size}_best.pt"
# transformer_path = train_dir / f"TransformerTranslator_{size}_best.pt"
# if bilstm_path.exists() and transformer_path.exists():
# bilstm_model = BiLSTMTranslator(8000, **best_params[size])
# bilstm_model.load_state_dict(torch.load(bilstm_path, map_location=device))
# bilstm_model.to(device)

# transformer_model = TransformerTranslator(8000, **best_params[size])
# transformer_model.load_state_dict(torch.load(transformer_path, map_location=device))
# transformer_model.to(device)

# bilstm_bleu, _ = evaluate(bilstm_model, valid_iter, sp, device)
# transformer_bleu, _ = evaluate(transformer_model, valid_iter, sp, device)

# t_stat, p_value = stats.ttest_ind_from_stats(
# mean1=bilstm_bleu, std1=0.0, nobs1=1,
# mean2=transformer_bleu, std2=0.0, nobs2=1
# )
# print(f"Size {size}: BiLSTM BLEU={bilstm_bleu:.2f}, Transformer BLEU={transformer_bleu:.2f}, T-stat={t_stat:.3f}, p-value={p_value:.3e}")
# else:
# print(f"Skipping size {size} due to missing model files.")
# continue