Quality Estimation for Machine Translation
This model is a fine-tuned version of answerdotai/ModernBERT-base on the ymoslem/wmt-da-human-evaluation-long-context dataset. It achieves the following results on the evaluation set:
- Loss: 0.0214
- Pearson: 0.5013
- MAE: 0.1024
- RMSE: 0.1464
- R2: 0.251
Model description
This model performs reference-free, long-context quality estimation (QE) of machine translation (MT) output. It was trained on translation pairs in which the source and the translation each contain up to 32 sentences (up to 64 sentences in total per input). This makes it suitable for document-level quality estimation.
Training and evaluation data
The model was trained on the long-context dataset ymoslem/wmt-da-human-evaluation-long-context.
This long-context (document-level) dataset for quality estimation of machine translation is an augmented variant of the sentence-level WMT DA Human Evaluation dataset.
In addition to individual sentences, it contains augmented entries that concatenate 2, 4, 8, 16, and 32 sentences within each language pair (`lp`) and domain.
The `raw` column holds a weighted average of the scores of the concatenated sentences, using the character lengths of `src` and `mt` as weights.
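The augmentation and score weighting described above can be sketched as follows. This is a hypothetical reconstruction, not the dataset's actual build script. The function name `augment_segments`, the single space used to join sentences, dropping an incomplete trailing group, and the per-sentence weight `len(src) + len(mt)` are all assumptions.

```python
def augment_segments(segments, group_size):
    """Hypothetical sketch: concatenate consecutive segments into groups.

    `segments` is a list of dicts with "src", "mt" and "raw" keys, assumed to
    come from a single language pair and domain.
    """
    if group_size < 1:
        raise ValueError("group_size must be at least 1")
    merged = []
    # Step through full groups only; a trailing partial group is dropped (assumption)
    for start in range(0, len(segments) - group_size + 1, group_size):
        group = segments[start:start + group_size]
        # Assumed weight: combined character length of the source and the MT output
        weights = [len(seg["src"]) + len(seg["mt"]) for seg in group]
        total = sum(weights)
        if total == 0:
            # Degenerate case (all empty strings): fall back to an unweighted mean
            weights, total = [1] * len(group), len(group)
        raw = sum(w * seg["raw"] for w, seg in zip(weights, group)) / total
        merged.append({
            "src": " ".join(seg["src"] for seg in group),
            "mt": " ".join(seg["mt"] for seg in group),
            "raw": raw,
        })
    return merged
```

Calling it with `group_size` set to 2, 4, 8, 16 and 32 on the segments of one language pair and domain would yield augmented variants analogous to those described above.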
- Training data: 7.65 million long-context texts
- Test data: 59,235 long-context texts
Training procedure
The model was trained on a single H200 SXM GPU (143 GB VRAM) for approximately 26 hours.
- tokenizer.model_max_length: 8192 (full context length)
- attn_implementation: flash_attention_2
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 0.0003
- train_batch_size: 128
- eval_batch_size: 128
- seed: 42
- optimizer: AdamW, fused PyTorch implementation (`adamw_torch_fused`), with betas=(0.9, 0.999) and epsilon=1e-08; no additional optimizer arguments
- lr_scheduler_type: linear
- training_steps: 60000 (approx. 1 epoch)
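The "approx. 1 epoch" figure and the epoch column in the table below can be checked with simple arithmetic: epoch = step × batch size / number of training examples. The sketch assumes a single device, no gradient accumulation, and the rounded dataset size of 7.65 million. It also models the linear schedule in the shape of `get_linear_schedule_with_warmup` from transformers. The number of warmup steps is not reported above, so the sketch assumes 0.

```python
def epoch_at_step(step, batch_size=128, num_train_examples=7_650_000):
    """Fraction of an epoch seen after `step` optimizer steps.

    Assumes one device and no gradient accumulation; 7.65M is the rounded
    training-set size reported above.
    """
    return step * batch_size / num_train_examples


def linear_lr(step, base_lr=3e-4, total_steps=60_000, warmup_steps=0):
    """Learning rate under linear warmup followed by linear decay to zero.

    warmup_steps=0 is an assumption; the card does not report warmup.
    """
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


print(round(epoch_at_step(60_000), 4))  # matches the last row of the table
print(linear_lr(30_000))                # half of the peak learning rate
```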
Training results
| Training Loss | Epoch | Step | Validation Loss |
|---|---|---|---|
| 0.0233 | 0.0167 | 1000 | 0.0233 |
| 0.0232 | 0.0335 | 2000 | 0.0230 |
| 0.0225 | 0.0502 | 3000 | 0.0230 |
| 0.023 | 0.0669 | 4000 | 0.0224 |
| 0.0226 | 0.0837 | 5000 | 0.0223 |
| 0.0226 | 0.1004 | 6000 | 0.0225 |
| 0.0219 | 0.1171 | 7000 | 0.0222 |
| 0.022 | 0.1339 | 8000 | 0.0222 |
| 0.0213 | 0.1506 | 9000 | 0.0221 |
| 0.0213 | 0.1673 | 10000 | 0.0220 |
| 0.0218 | 0.1840 | 11000 | 0.0219 |
| 0.0215 | 0.2008 | 12000 | 0.0225 |
| 0.0218 | 0.2175 | 13000 | 0.0219 |
| 0.0218 | 0.2342 | 14000 | 0.0218 |
| 0.0217 | 0.2510 | 15000 | 0.0219 |
| 0.0219 | 0.2677 | 16000 | 0.0219 |
| 0.0212 | 0.2844 | 17000 | 0.0219 |
| 0.0219 | 0.3012 | 18000 | 0.0219 |
| 0.0218 | 0.3179 | 19000 | 0.0219 |
| 0.0213 | 0.3346 | 20000 | 0.0217 |
| 0.0218 | 0.3514 | 21000 | 0.0217 |
| 0.021 | 0.3681 | 22000 | 0.0217 |
| 0.0219 | 0.3848 | 23000 | 0.0220 |
| 0.0211 | 0.4016 | 24000 | 0.0216 |
| 0.0211 | 0.4183 | 25000 | 0.0216 |
| 0.0206 | 0.4350 | 26000 | 0.0216 |
| 0.021 | 0.4517 | 27000 | 0.0215 |
| 0.0214 | 0.4685 | 28000 | 0.0215 |
| 0.0214 | 0.4852 | 29000 | 0.0216 |
| 0.0204 | 0.5019 | 30000 | 0.0216 |
| 0.022 | 0.5187 | 31000 | 0.0216 |
| 0.0212 | 0.5354 | 32000 | 0.0217 |
| 0.0211 | 0.5521 | 33000 | 0.0216 |
| 0.0208 | 0.5689 | 34000 | 0.0215 |
| 0.0208 | 0.5856 | 35000 | 0.0215 |
| 0.0215 | 0.6023 | 36000 | 0.0215 |
| 0.0212 | 0.6191 | 37000 | 0.0215 |
| 0.0213 | 0.6358 | 38000 | 0.0215 |
| 0.0211 | 0.6525 | 39000 | 0.0215 |
| 0.0208 | 0.6693 | 40000 | 0.0215 |
| 0.0205 | 0.6860 | 41000 | 0.0215 |
| 0.0209 | 0.7027 | 42000 | 0.0215 |
| 0.021 | 0.7194 | 43000 | 0.0215 |
| 0.0207 | 0.7362 | 44000 | 0.0215 |
| 0.0197 | 0.7529 | 45000 | 0.0215 |
| 0.0211 | 0.7696 | 46000 | 0.0214 |
| 0.021 | 0.7864 | 47000 | 0.0215 |
| 0.0207 | 0.8031 | 48000 | 0.0214 |
| 0.0219 | 0.8198 | 49000 | 0.0215 |
| 0.0208 | 0.8366 | 50000 | 0.0215 |
| 0.0202 | 0.8533 | 51000 | 0.0215 |
| 0.02 | 0.8700 | 52000 | 0.0215 |
| 0.0205 | 0.8868 | 53000 | 0.0214 |
| 0.0214 | 0.9035 | 54000 | 0.0215 |
| 0.0205 | 0.9202 | 55000 | 0.0214 |
| 0.0209 | 0.9370 | 56000 | 0.0214 |
| 0.0206 | 0.9537 | 57000 | 0.0214 |
| 0.0204 | 0.9704 | 58000 | 0.0214 |
| 0.0203 | 0.9872 | 59000 | 0.0214 |
| 0.0209 | 1.0039 | 60000 | 0.0214 |
Framework versions
- Transformers 4.48.1
- Pytorch 2.4.1+cu124
- Datasets 3.2.0
- Tokenizers 0.21.0
Inference
- Install the required libraries.
```bash
pip3 install --upgrade datasets accelerate transformers
pip3 install --upgrade flash_attn triton
```
- Load the test dataset.
```python
from datasets import load_dataset

test_dataset = load_dataset(
    "ymoslem/wmt-da-human-evaluation",
    split="test",
    trust_remote_code=True,
)
print(test_dataset)
```
- Load the model and tokenizer:
```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Load the fine-tuned model and tokenizer
# Note: flash_attention_2 requires a CUDA GPU; drop that argument to run on CPU
model_name = "ymoslem/ModernBERT-base-long-context-qe-v1"
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
```
- Prepare the dataset. Each source segment (`src`) and its translation (`mt`) are joined, separated by the tokenizer's `sep_token`, which is `[SEP]` for ModernBERT.
```python
sep_token = tokenizer.sep_token
input_test_texts = [
    f"{src} {sep_token} {tgt}"
    for src, tgt in zip(test_dataset["src"], test_dataset["mt"])
]
```
- Generate predictions.
Although `model.config.problem_type` is `"regression"`, you can still use the `"text-classification"` pipeline as follows (cf. the pipeline documentation). Setting `function_to_apply="none"` ensures the pipeline returns the raw regression output without applying an activation function to it.
```python
from transformers import pipeline

classifier = pipeline(
    "text-classification",
    model=model_name,
    tokenizer=tokenizer,
    device=0,
)

predictions = classifier(
    input_test_texts,
    batch_size=128,
    truncation=True,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    function_to_apply="none",  # keep raw regression scores
)
predictions = [prediction["score"] for prediction in predictions]
```
Alternatively, you can use a more detailed version of the code, which is slightly faster and gives you more control.
```python
from torch.utils.data import DataLoader
import torch
from tqdm.auto import tqdm

# Tokenization function: join source and MT with the separator token
def process_batch(batch, tokenizer, device):
    sep_token = tokenizer.sep_token
    input_texts = [f"{src} {sep_token} {tgt}" for src, tgt in zip(batch["src"], batch["mt"])]
    tokens = tokenizer(
        input_texts,
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).to(device)
    return tokens

# Create a DataLoader for batching; keeping only the text columns lets the
# default collate function simply group the strings into lists
test_dataloader = DataLoader(
    test_dataset.select_columns(["src", "mt"]),
    batch_size=128,  # Adjust batch size as needed
    shuffle=False,
)

# List to store all predictions
predictions = []

with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Inference Progress", unit="batch"):
        tokens = process_batch(batch, tokenizer, device)
        # Forward pass: generate the model's logits
        outputs = model(**tokens)
        # Logits have shape (batch_size, 1) for this single-output regression head
        logits = outputs.logits
        # Squeeze only the last dimension so a final batch of size 1 still yields a list
        batch_predictions = logits.squeeze(-1).float()
        # Extend the list with the predictions
        predictions.extend(batch_predictions.tolist())
```
Evaluation results
Self-reported on ymoslem/wmt-da-human-evaluation-long-context:
- Pearson Correlation: 0.501
- Mean Absolute Error: 0.102
- Root Mean Squared Error: 0.146
- R-Squared: 0.251