DualTrack Alignment Module

Anonymous submission to ACL 2026

A cross-encoder model for detecting citation drift in Retrieval-Augmented Generation (RAG) systems. Given a user-facing claim, an evidence representation, and a source passage, the model predicts whether the citation is valid (the source supports the claim).

Model Description

This model addresses a critical reliability problem in RAG systems: citation drift, where generated text diverges from source documents in ways that break attribution. The problem is particularly severe in cross-lingual settings where the answer language differs from source document language.

Architecture

Input: "[CLS] User claim: {claim} [SEP] Evidence: {evidence} [SEP] Source passage: {context} [SEP]"
         ↓
    DeBERTa-v3-base (184M parameters)
         ↓
    [CLS] embedding (768-dim)
         ↓
    Linear(768, 2) → Softmax
         ↓
    Output: P(valid citation)
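
A quick shape sanity check for this stack (a sketch built on the public microsoft/deberta-v3-base backbone, which the diagram above wraps with a 2-way linear head):

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tok = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
clf = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/deberta-v3-base", num_labels=2  # fresh 2-way head, as in the diagram
)
batch = tok("User claim: ... Evidence: ... Source passage: ...", return_tensors="pt")
logits = clf(**batch).logits                   # shape (1, 2)
p_valid = torch.softmax(logits, dim=-1)[0, 1]  # P(valid citation)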

Why Cross-Encoder?

Unlike embedding-based approaches that encode texts separately, the cross-encoder sees all three components together, enabling:

  • Cross-attention between claim and source
  • Detection of subtle semantic mismatches
  • Better handling of paraphrases vs. factual errors

Intended Use

Primary Use Cases

  1. Post-hoc citation verification: Validate citations in RAG outputs before serving to users
  2. Citation drift detection: Identify claims that have semantically drifted from their sources
  3. Training signal: Provide rewards for citation-aware generation

Out of Scope

  • General NLI/entailment (model is specialized for RAG citation patterns)
  • Fact-checking against world knowledge (requires source passage)
  • Non-English source documents (trained on English sources only)

How to Use

Installation

pip install transformers torch sentencepiece

(sentencepiece is required by the DeBERTa-v3 tokenizer.)

Basic Usage

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load model
model_name = "anonymous-acl2026/dualtrack-alignment"  # Replace with actual path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

def check_citation(user_claim: str, evidence: str, source: str, threshold: float = 0.5) -> tuple[bool, float]:
    """
    Check if a citation is valid.
    
    Args:
        user_claim: The claim shown to the user
        evidence: Evidence track representation (can be same as user_claim)
        source: The source passage being cited
        threshold: Classification threshold (default from training)
    
    Returns:
        (is_valid, probability)
    """
    # Format input
    text = f"User claim: {user_claim}\n\nEvidence: {evidence}\n\nSource passage: {source}"
    
    # Tokenize
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    
    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
        prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()
    
    return prob >= threshold, prob

# Example: Valid citation
is_valid, prob = check_citation(
    user_claim="Python was created by Guido van Rossum.",
    evidence="Python was created by Guido van Rossum.",
    source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Output: Valid: True, Probability: 0.95

# Example: Invalid citation (wrong date)
is_valid, prob = check_citation(
    user_claim="Python was created in 1989.",
    evidence="Python was created in 1989.",
    source="Python is a programming language created by Guido van Rossum in 1991."
)
print(f"Valid: {is_valid}, Probability: {prob:.3f}")
# Output: Valid: False, Probability: 0.12

Batch Processing

def batch_check_citations(examples: list[dict], batch_size: int = 16) -> list[float]:
    """
    Check multiple citations efficiently.
    
    Args:
        examples: List of dicts with keys 'user', 'evidence', 'source'
        batch_size: Batch size for inference
    
    Returns:
        List of probabilities
    """
    all_probs = []
    
    for i in range(0, len(examples), batch_size):
        batch = examples[i:i + batch_size]
        
        texts = [
            f"User claim: {ex['user']}\n\nEvidence: {ex['evidence']}\n\nSource passage: {ex['source']}"
            for ex in batch
        ]
        
        inputs = tokenizer(
            texts, 
            return_tensors="pt", 
            truncation=True, 
            max_length=512, 
            padding=True
        )
        
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)[:, 1].tolist()
        
        all_probs.extend(probs)
    
    return all_probs
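
For example, reusing the invalid-citation example from above (the probability shown is illustrative):

probs = batch_check_citations([
    {
        "user": "Python was created in 1989.",
        "evidence": "Python was created in 1989.",
        "source": "Python is a programming language created by Guido van Rossum in 1991.",
    }
])
print(probs)  # e.g. [0.12]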

Integration with DualTrack

class DualTrackAlignmentModule:
    """
    Alignment module for the DualTrack RAG system.
    
    Detects citation drift between user track and source documents.
    """
    
    def __init__(self, model_path: str, threshold: float | None = None, device: str | None = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()
        
        # Prefer an explicit threshold; otherwise fall back to the optimal
        # threshold stored in the model's metadata, then to 0.5. The explicit
        # `is not None` check avoids silently discarding threshold=0.0.
        import json
        import os
        if threshold is not None:
            self.threshold = threshold
        else:
            self.threshold = 0.5
            metadata_path = os.path.join(model_path, "metadata.json")
            if os.path.exists(metadata_path):
                with open(metadata_path) as f:
                    metadata = json.load(f)
                self.threshold = metadata.get("optimal_threshold", 0.5)
    
    def detect_drift(
        self, 
        user_claims: list[str], 
        evidence_claims: list[str], 
        sources: list[str]
    ) -> list[dict]:
        """
        Detect citation drift for multiple claim-source pairs.
        
        Returns list of {is_valid, probability, drift_detected}.
        """
        results = []
        
        for user, evidence, source in zip(user_claims, evidence_claims, sources):
            text = f"User claim: {user}\n\nEvidence: {evidence}\n\nSource passage: {source}"
            
            inputs = self.tokenizer(
                text, return_tensors="pt", truncation=True, max_length=512
            ).to(self.device)
            
            with torch.no_grad():
                outputs = self.model(**inputs)
                prob = torch.softmax(outputs.logits, dim=-1)[0, 1].item()
            
            results.append({
                "is_valid": prob >= self.threshold,
                "probability": prob,
                "drift_detected": prob < self.threshold
            })
        
        return results
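
A usage sketch (the model path is the placeholder from above; the output values are illustrative):

module = DualTrackAlignmentModule("anonymous-acl2026/dualtrack-alignment")
results = module.detect_drift(
    user_claims=["Python was created in 1989."],
    evidence_claims=["Python was created in 1989."],
    sources=["Python is a programming language created by Guido van Rossum in 1991."],
)
print(results[0])  # e.g. {'is_valid': False, 'probability': 0.12, 'drift_detected': True}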

Training Details

Training Data

The model was trained on a curated dataset combining multiple sources:

Source   Examples   Description
FEVER    ~8,000     Fact verification with SUPPORTS/REFUTES labels
HAGRID   ~2,000     Attributed QA with quote-based evidence
ASQA     ~3,000     Ambiguous questions with long-form answers

Label Generation (V3 - LLM-Supervised):

  • Training labels verified by GPT-4o-mini ("Does context support claim?")
  • Evaluation uses independent NLI model (DeBERTa-MNLI)
  • This breaks circularity: the model learns from LLM judgments but is evaluated by an independent NLI model (a sketch of the labeling call follows this list)
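
The exact prompt is not part of this card; a hypothetical reconstruction of the labeling call (using the OpenAI Python client) might look like:

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

def llm_label(claim: str, context: str) -> bool:
    """Hypothetical V3 labeling step: ask GPT-4o-mini whether the context supports the claim."""
    resp = client.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0,
        messages=[{
            "role": "user",
            "content": (
                "Does the context support the claim? Answer YES or NO.\n\n"
                f"Claim: {claim}\n\nContext: {context}"
            ),
        }],
    )
    return resp.choices[0].message.content.strip().upper().startswith("YES")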

Data Augmentation:

  • Negative perturbations: date_change, number_change, entity_swap, false_detail, negation, topic_drift (an illustrative date_change sketch follows this list)
  • Positive perturbations: paraphrase, synonym_swap, formal_informal register changes
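
A minimal sketch of one such perturbation (hypothetical; the actual augmentation code is not included in this card):

import random
import re

def date_change(claim: str) -> str | None:
    """Hypothetical negative perturbation: shift a four-digit year so the
    claim no longer matches its source. Returns None if no year is found."""
    m = re.search(r"\b(19|20)\d{2}\b", claim)
    if m is None:
        return None
    shifted = int(m.group()) + random.choice([-3, -2, -1, 1, 2, 3])
    return claim[:m.start()] + str(shifted) + claim[m.end():]

# date_change("Python was created in 1991.") -> e.g. "Python was created in 1989."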

Training Procedure

Hyperparameter           Value
Base model               microsoft/deberta-v3-base
Max sequence length      512
Batch size               8
Gradient accumulation    2
Effective batch size     16
Learning rate            2e-5
Warmup ratio             0.1
Weight decay             0.01
Epochs                   5
Early stopping patience  3
FP16 training            Yes
Optimizer                AdamW
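
These settings map onto a HuggingFace Trainer configuration roughly as follows (a sketch; the authors' training script is not part of this card, and AdamW is the Trainer default optimizer):

from transformers import TrainingArguments, EarlyStoppingCallback

training_args = TrainingArguments(
    output_dir="dualtrack-alignment",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # effective batch size 16
    learning_rate=2e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    num_train_epochs=5,
    fp16=True,
    eval_strategy="epoch",          # "evaluation_strategy" on older transformers
    save_strategy="epoch",
    load_best_model_at_end=True,    # needed for early stopping
    metric_for_best_model="f1",
)
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)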

Training Infrastructure:

  • Single GPU (NVIDIA T4/V100)
  • Training time: ~2-3 hours
  • Framework: HuggingFace Transformers + PyTorch

Evaluation

Validation Set Performance (15% held-out, stratified):

Metric     Score
Accuracy   0.87
Precision  0.88
Recall     0.90
F1         0.89
ROC-AUC    0.94

Optimal Threshold: 0.50 (determined via F1 maximization on validation set)
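
One standard way to derive such a threshold (a sketch using scikit-learn; y_true and probs stand for the validation labels and model probabilities):

import numpy as np
from sklearn.metrics import precision_recall_curve

def best_f1_threshold(y_true: np.ndarray, probs: np.ndarray) -> float:
    """Pick the decision threshold that maximizes F1 on validation data."""
    precision, recall, thresholds = precision_recall_curve(y_true, probs)
    f1 = 2 * precision * recall / (precision + recall + 1e-12)
    return float(thresholds[np.argmax(f1[:-1])])  # the last P/R point has no threshold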

Performance by Perturbation Type:

Type         Accuracy   Notes
original     0.91       Clean examples
paraphrase   0.88       Meaning-preserving rewrites
entity_swap  0.94       Wrong person/place/org
date_change  0.92       Incorrect dates
negation     0.89       Reversed claims
topic_drift  0.85       Subtle semantic shifts

Limitations

  1. English only: Trained on English source passages. Cross-lingual application requires translation or multilingual encoder.

  2. RAG-specific: Optimized for RAG citation patterns; may not generalize to arbitrary NLI tasks.

  3. Passage length: Max 512 tokens. Long documents require chunking or summarization (a chunking sketch follows this list).

  4. Threshold sensitivity: Default threshold (0.5) may need tuning for specific applications. High-precision applications should use higher thresholds.

  5. Training data bias: Performance may vary on domains not represented in FEVER/HAGRID/ASQA (e.g., legal, medical, code).
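
For limitation 3, a hypothetical chunking workaround (reusing the tokenizer and check_citation defined above; the chunk sizes are illustrative):

def check_long_source(user_claim: str, evidence: str, source: str,
                      chunk_tokens: int = 400, stride: int = 200) -> float:
    """Hypothetical workaround for the 512-token limit: score overlapping
    chunks of the source and keep the best probability, since a valid
    citation only needs one supporting span."""
    ids = tokenizer(source, add_special_tokens=False)["input_ids"]
    best = 0.0
    for start in range(0, max(len(ids) - stride, 1), stride):
        chunk = tokenizer.decode(ids[start:start + chunk_tokens])
        _, prob = check_citation(user_claim, evidence, chunk)
        best = max(best, prob)
    return best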

Ethical Considerations

Intended Benefits

  • Improved reliability of AI-generated citations
  • Reduced misinformation from RAG hallucinations
  • Better transparency in AI-assisted research

Potential Risks

  • Over-reliance on automated verification (human review still recommended for high-stakes applications)
  • False negatives may incorrectly flag valid citations
  • False positives may miss genuine attribution errors

Recommendations

  • Use as one signal among many, not sole arbiter
  • Monitor performance on domain-specific data
  • Combine with human review for critical applications

This model is part of an anonymous submission to ACL 2026. Author information will be added upon acceptance.
