dlm-vi2en-checkpoint-66000 / modeling_dlm.py
myduy's picture
Update modeling with argmax_decoding support
400e4bd verified
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedModel, AutoModelForMaskedLM, AutoConfig
try:
from .configuration_dlm import DiscreteDiffusionConfig
except ImportError:
from configuration_dlm import DiscreteDiffusionConfig
from collections import namedtuple
import math
import numpy as np
from typing import List, Optional, Tuple, Union
decoder_out_t = namedtuple(
"decoder_out_t",
["output_tokens", "output_scores", "output_masks", "non_fixed_sym_masks", "attn", "step", "max_step", "history"],
)
def topk_masking(scores, cutoff_len, stochastic=False, temp=1.0):
"""
scores: [b, n]
cutoff_len: [b, 1]
stochastic: bool, whether to add noise to select top_k or not
returns:
mask: [b, n], with 1 if the token is in top-k lowest scores, 0 otherwise
"""
if stochastic:
gumbel_noise = -torch.log(-torch.log(torch.rand_like(scores) + 1e-8) + 1e-8)
_scores = scores + temp * gumbel_noise
else:
_scores = scores
sorted_index = _scores.sort(-1)[0]
cutoff = sorted_index.gather(dim=-1, index=cutoff_len) # + 1e-10
# cutoff_len = k -> select k + 1 tokens
masking = _scores < cutoff
return masking
class DiscreteDiffusionModel(PreTrainedModel):
config_class = DiscreteDiffusionConfig
_keys_to_ignore_on_load_missing = ["fake_layer", "length_trm", "length_predictor", "model.lm_head.decoder.weight"]
def __init__(self, config: DiscreteDiffusionConfig):
super().__init__(config)
self.config = config
self.args = config # Alias for compatibility with existing code
# Initialize backbone
if config.backbone_config:
# We assume backbone_config is a dict
backbone_config_obj = AutoConfig.for_model(**config.backbone_config)
self.model = AutoModelForMaskedLM.from_config(backbone_config_obj)
else:
# Fallback or error
raise ValueError("backbone_config must be provided in config")
if config.tie_word_embeddings:
self.model.lm_head.decoder.weight = self.model.roberta.embeddings.word_embeddings.weight
self.mask_id = config.mask_token_id
self.bos_id = config.bos_token_id
self.eos_id = config.eos_token_id
self.pad_id = config.pad_token_id
# Lora
if config.lora:
self.add_fake_layer()
# Length predictor (optional, as in original code)
self.length_trm = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=self.config.hidden_size,
nhead=self.config.num_attention_heads,
dim_feedforward=self.config.intermediate_size,
batch_first=True
),
num_layers=1,
)
self.length_predictor = nn.Sequential(
nn.Linear(self.config.hidden_size , self.config.intermediate_size),
nn.Tanh(),
nn.Linear(self.config.intermediate_size, self.config.max_position_embeddings)
)
def add_fake_layer(self):
self.fake_layer = nn.Parameter(torch.zeros((self.config.hidden_size, )))
def gradient_checkpointing_enable(self):
self.model.gradient_checkpointing_enable()
def _tie_weights(self):
"""Tie the weights between the input embeddings and the output embeddings."""
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(
self.model.lm_head.decoder,
self.model.roberta.embeddings.word_embeddings
)
def _init_weights(self, module):
"""Initialize the weights - called after loading checkpoint."""
# Call parent init_weights
super()._init_weights(module)
# Ensure weights are tied after initialization
self._tie_weights()
@property
def _tied_weights_keys(self):
"""Return the keys of tied weights."""
if self.config.tie_word_embeddings:
return ["model.lm_head.decoder.weight"]
return []
def q_sample_coupled(self, x_0, t1, t2, maskable_mask):
# ... copy from DiscreteDiffusionBase ...
assert self.config.diffusion_type == "absorbing", "we only support absorbing diffusion temporarily"
t1_eq_t2_mask = (t1 == t2)
t1, t2 = torch.maximum(t1, t2).float(), torch.minimum(t1, t2).float()
u = torch.rand_like(x_0, dtype=torch.float)
t1_mask = (u < (t1 / self.config.num_diffusion_timesteps)[:, None]) & maskable_mask
x_t1 = x_0.masked_fill(t1_mask, self.mask_id)
u = torch.rand_like(x_0, dtype=torch.float)
t2_mask = t1_mask & (u > ((t1 - t2) / t1)[:, None])
u = torch.rand_like(x_0[t1_eq_t2_mask], dtype=torch.float)
t2_mask[t1_eq_t2_mask] = (u < (t1[t1_eq_t2_mask] / self.config.num_diffusion_timesteps)[:, None]) & (maskable_mask[t1_eq_t2_mask])
x_t2 = x_0.masked_fill(t2_mask, self.mask_id)
return {
"x_t": torch.cat([x_t1, x_t2], dim=0),
"t": torch.cat([t1, t2]),
"mask_mask": torch.cat([t1_mask, t2_mask], dim=0)
}
def initialize_decode_samples(self, tokens, partial_masks, prefix_masks, oracle_length=False, length_beam=1, mbr=1):
# ... copy from DiscreteDiffusionBase ...
if tokens is None:
raise NotImplementedError
else:
if not oracle_length:
inputs_tokens = tokens.masked_fill(~prefix_masks, self.pad_id)
src_length = inputs_tokens.ne(self.pad_id).sum(dim=-1)
inputs_tokens = inputs_tokens[:, :src_length.max()]
length_logits = self.forward_length(inputs_tokens)
# Giới hạn độ dài output tối đa: không quá 3x độ dài source và không quá 100 tokens
max_allowed_length = torch.min(
torch.tensor([100]).to(src_length.device),
(src_length * 3)[:, None]
)
length = (
torch.min(
torch.min(
length_logits.topk(length_beam, dim=-1).indices + 1,
max_allowed_length
),
self.config.max_position_embeddings - 2 - src_length[:, None] - 1
)
)
output_tokens = []
new_partial_masks = []
for i, token in enumerate(inputs_tokens):
for b in range(length_beam):
for m in range(mbr):
# Create output token sequence
seq = torch.cat([
token[:src_length[i]],
torch.tensor([self.mask_id] * length[i][b] + [self.eos_id]).to(token)
])
output_tokens.append(seq)
# Create corresponding partial mask
# True for fixed (source), False for generated (mask/eos)
# partial_masks[i] corresponds to token[i]
# We assume partial_masks[i] has same length as token[i] (or at least src_length[i])
p_mask = torch.cat([
partial_masks[i][:src_length[i]],
torch.tensor([False] * (length[i][b] + 1)).to(partial_masks)
])
new_partial_masks.append(p_mask)
output_tokens = pad_sequence(output_tokens, batch_first=True, padding_value=self.pad_id)
# Pad partial masks to match output_tokens length
# We need to pad with True (fixed) or False (maskable)?
# Usually padding tokens should be ignored.
# In finalized_hypos: cutoff = tokens.ne(pad) & ... & (~partial_mask)
# If we pad partial_mask with True, ~partial_mask is False, so it's filtered out.
# If we pad with False, ~partial_mask is True, so it's kept (if not pad_id).
# Since we check tokens.ne(pad_id), padding tokens are filtered anyway.
# But for safety, let's pad with True (fixed) so they are treated as non-generated?
# Actually, pad_sequence pads with 0. For bool tensor, 0 is False.
# So if we use pad_sequence on bool tensor, it pads with False.
partial_masks = pad_sequence(new_partial_masks, batch_first=True, padding_value=True) # Pad with True to be safe?
# Wait, if we pad with True, then ~partial_mask is False.
output_mask = output_tokens.eq(self.mask_id)
# non_fixed_sym_masks should be all positions that can be modified (not source, not pad, not special tokens)
# This is critical for _reparam_decoding to work correctly!
non_fixed_sym_masks = (
output_tokens.ne(self.pad_id) &
output_tokens.ne(self.bos_id) &
~partial_masks # Not source tokens
)
else:
output_tokens = torch.stack([token for token in tokens for m in range(mbr)])
partial_masks = torch.stack([mask for mask in partial_masks for m in range(mbr)])
prefix_masks = torch.stack([mask for mask in prefix_masks for m in range(mbr)])
output_mask = (
output_tokens.ne(self.pad_id) &
output_tokens.ne(self.bos_id) &
output_tokens.ne(self.eos_id) &
~prefix_masks
)
output_tokens = output_tokens.masked_fill(output_mask, self.mask_id)
non_fixed_sym_masks = output_mask.clone()
output_scores = torch.zeros_like(output_tokens, dtype=torch.float)
return partial_masks, decoder_out_t(
output_tokens=output_tokens,
output_scores=output_scores,
output_masks=output_mask,
non_fixed_sym_masks=non_fixed_sym_masks,
attn=None,
step=0,
max_step=math.inf,
history=None
)
def forward_length(self, input_ids):
attention_mask = input_ids.ne(self.pad_id).int()
with torch.no_grad():
_feature = self.model.roberta(input_ids, attention_mask=attention_mask)[0]
feature = self.length_trm(_feature, src_key_padding_mask=(1-attention_mask).bool())
length = attention_mask.sum(dim=-1)
pooled_feature = feature.masked_fill((attention_mask==0)[:, :, None], 0).float().sum(1) / length[:, None]
length_logits = self.length_predictor(pooled_feature.to(feature))
return length_logits
def forward(self, prev_output_tokens, partial_mask, attention_mask=None, loss_mask=None, cache=None):
input_ids = prev_output_tokens
if attention_mask is None:
attention_mask = prev_output_tokens.ne(self.pad_id).int()
embeddings = self.model.roberta.embeddings.word_embeddings(input_ids)
if hasattr(self, "fake_layer") and self.training:
self.fake_layer.requires_grad = True
embeddings = embeddings + self.fake_layer * 0
if self.config.attention_strategy == "prefix_lm":
# ... simplified for now, assuming full attention or handling it ...
# Copying logic from original
ext_partial_mask = partial_mask.float()
ext_partial_mask = torch.bmm(ext_partial_mask[:, :, None], ext_partial_mask[:, None, :]).int()
ext_mask = attention_mask[:, None, :].repeat(1, attention_mask.size(-1), 1)
ext_mask[partial_mask] = ext_partial_mask[partial_mask]
outputs = self.model.roberta(inputs_embeds=embeddings, attention_mask=ext_mask)[0]
else:
outputs = self.model.roberta(inputs_embeds=embeddings, attention_mask=attention_mask)[0]
if not (~torch.isnan(outputs)).all():
outputs.masked_fill_(outputs.isnan(), 0)
outputs = outputs[loss_mask] if loss_mask is not None else outputs
return self.model.lm_head(outputs)
def _reparam_decoding(
self,
output_tokens,
output_scores,
cur_tokens,
cur_scores,
decoding_strategy,
xt_neq_x0,
non_special_sym_mask,
t,
max_step,
noise
):
_, condition, topk_mode, schedule = decoding_strategy.split("-")
if schedule == "linear":
rate = 1 - t / max_step
elif schedule == "cosine":
rate = np.cos(t / max_step * np.pi * 0.5)
else:
raise NotImplementedError
cutoff_len = (
non_special_sym_mask.sum(1, keepdim=True).type_as(output_scores) * rate
).long()
_scores_for_topk = cur_scores.masked_fill(~non_special_sym_mask, 1000.0)
if topk_mode.startswith("stochastic"):
noise_scale = float(topk_mode.replace("stochastic", ""))
lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=True, temp=noise_scale * rate)
elif topk_mode == "deterministic":
lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=False)
else:
raise NotImplementedError
if condition == "cond":
not_v1_t = (cur_tokens == output_tokens) & (cur_scores < output_scores) & lowest_k_mask
elif condition == "uncond":
not_v1_t = lowest_k_mask
else:
raise NotImplementedError
not_v2_t = lowest_k_mask
masked_to_noise = (~xt_neq_x0 & not_v1_t) | (xt_neq_x0 & not_v2_t)
if isinstance(noise, torch.Tensor):
output_tokens.masked_scatter_(masked_to_noise, noise[masked_to_noise])
elif isinstance(noise, (int, float)):
output_tokens.masked_fill_(masked_to_noise, noise)
else:
raise NotImplementedError("noise should be either a tensor or a scalar")
output_scores.masked_fill_(masked_to_noise, -math.inf)
masked_to_x0 = xt_neq_x0 & ~not_v2_t
output_tokens.masked_scatter_(masked_to_x0, cur_tokens[masked_to_x0])
output_scores.masked_scatter_(masked_to_x0, cur_scores[masked_to_x0])
new_xt_neq_x0 = (xt_neq_x0 | not_v1_t) & not_v2_t
return new_xt_neq_x0
def denoise_step(self, decoder_out, partial_masks, temperature=1.0, strategy="reparam-uncond-deterministic-cosine"):
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
prev_step, cur_step = decoder_out.step, decoder_out.step + 1
max_step = decoder_out.max_step
logits = self.forward(output_tokens, partial_masks)
logits[..., self.mask_id] = -math.inf
scores = torch.log_softmax(logits, dim=-1)
if strategy == "cmlm":
# get the mask
# <bos>, <eos> are ignored in this case since
# they are not equal to unk.
output_masks = output_tokens.eq(self.mask_id)
unmask_prob = 1 / (max_step - prev_step)
# where to unmask
changes = torch.rand(output_tokens.shape, device=output_tokens.device) < unmask_prob
# don't unmask somewhere already unmasked
changes = torch.bitwise_and(changes, output_masks)
if getattr(self.config, "argmax_decoding", False):
output_scores, new_tokens = scores.max(-1)
else:
# Assuming dists is imported or available, otherwise use torch.multinomial or similar
# But let's stick to what was in generator if possible, or implement simple sampling
# The generator used: dists.Categorical(logits=scores / temperature).sample()
# We need to import dists or use torch.distributions
import torch.distributions as dists
new_tokens = dists.Categorical(logits=scores / temperature).sample()
output_scores = torch.gather(scores, -1, new_tokens.unsqueeze(-1)).squeeze(-1)
output_tokens[changes] = new_tokens[changes]
elif strategy == "ar":
output_masks = output_tokens.eq(self.mask_id)
unmask_indices = (output_tokens.ne(self.mask_id) & output_tokens.ne(self.eos_id) & output_tokens.ne(self.pad_id)).sum(dim=-1)
indices = torch.arange(output_tokens.size(-1)).expand(output_tokens.shape).to(output_masks.device)
if getattr(self.config, "argmax_decoding", False):
output_scores, new_tokens = scores.max(-1)
else:
import torch.distributions as dists
new_tokens = dists.Categorical(logits=scores / temperature).sample()
output_scores = torch.gather(scores, -1, new_tokens.unsqueeze(-1)).squeeze(-1)
output_tokens[unmask_indices[:, None]==indices] = new_tokens[unmask_indices[:, None]==indices]
else:
if getattr(self.config, "argmax_decoding", False):
cur_scores, cur_tokens = scores.max(-1)
else:
import torch.distributions as dists
cur_tokens = dists.Categorical(logits=scores / temperature).sample()
cur_scores = torch.gather(scores, -1, cur_tokens.unsqueeze(-1)).squeeze(-1)
cur_scores = cur_scores.to(output_scores)
output_masks = self._reparam_decoding(
output_tokens=output_tokens,
output_scores=output_scores,
cur_tokens=cur_tokens,
cur_scores=cur_scores,
decoding_strategy=strategy,
xt_neq_x0=decoder_out.output_masks,
non_special_sym_mask=decoder_out.non_fixed_sym_masks,
t=cur_step,
max_step=max_step,
noise=self.mask_id
)
history = (
([] if decoder_out.history is None else decoder_out.history) + [output_tokens.clone()]
if decoder_out.history is not None else None
)
return decoder_out._replace(
step=cur_step,
output_tokens=output_tokens,
output_scores=output_scores,
output_masks=output_masks,
history=history,
)
@torch.no_grad()
def generate(
self,
input_ids,
attention_mask=None,
max_iterations=10,
strategy="reparam-uncond-deterministic-cosine",
temperature=1.0,
return_history=False,
max_length=128, # Fixed generation length hyperparameter (like LLaDA)
**kwargs
):
# Prepare inputs
src_tokens = input_ids
if attention_mask is None:
partial_masks = torch.ones_like(src_tokens).bool()
else:
partial_masks = attention_mask.bool()
prefix_masks = partial_masks
# Initialize canvas with fixed length (LLaDA approach)
# Instead of predicting length, use max_length as hyperparameter
batch_size = src_tokens.size(0)
src_length = src_tokens.ne(self.pad_id).sum(dim=-1)
# Create fully masked response of fixed length
output_tokens = []
new_partial_masks = []
for i in range(batch_size):
# Format: <source_without_eos> <mask>...<mask> <eos>
# Remove EOS from source if it exists
src_len = src_length[i].item()
src_seq = src_tokens[i, :src_len]
# Remove trailing EOS from source
if src_seq[-1] == self.eos_id:
src_seq = src_seq[:-1]
src_len -= 1
seq = torch.cat([
src_seq,
torch.full((max_length,), self.mask_id, dtype=src_tokens.dtype, device=src_tokens.device),
torch.tensor([self.eos_id], dtype=src_tokens.dtype, device=src_tokens.device)
])
output_tokens.append(seq)
# Mask: True for source (fixed), False for generated part
mask = torch.cat([
torch.ones(src_len, dtype=torch.bool, device=src_tokens.device),
torch.zeros(max_length + 1, dtype=torch.bool, device=src_tokens.device) # +1 for eos
])
new_partial_masks.append(mask)
output_tokens = pad_sequence(output_tokens, batch_first=True, padding_value=self.pad_id)
partial_masks = pad_sequence(new_partial_masks, batch_first=True, padding_value=True)
# Create masks for decoding
output_mask = output_tokens.eq(self.mask_id)
non_fixed_sym_masks = (
output_tokens.ne(self.pad_id) &
output_tokens.ne(self.bos_id) &
~partial_masks # Not source tokens
)
output_scores = torch.zeros_like(output_tokens, dtype=torch.float)
prev_decoder_out = decoder_out_t(
output_tokens=output_tokens,
output_scores=output_scores,
output_masks=output_mask,
non_fixed_sym_masks=non_fixed_sym_masks,
attn=None,
step=0,
max_step=max_iterations,
history=None
)
if return_history:
prev_decoder_out = prev_decoder_out._replace(history=[])
for step in range(max_iterations):
prev_decoder_out = self.denoise_step(prev_decoder_out, partial_masks, temperature=temperature, strategy=strategy)
# Finalize: discard tokens after EOS (LLaDA approach)
def finalized_hypos(tokens, scores, partial_mask, history=None):
# First, find EOS position and cut there
eos_positions = (tokens == self.eos_id).nonzero(as_tuple=True)[0]
if len(eos_positions) > 0:
first_eos = eos_positions[0].item()
# Cut everything after EOS
tokens = tokens[:first_eos] # Exclude EOS
if scores is not None:
scores = scores[:first_eos]
partial_mask = partial_mask[:first_eos]
# Then apply cutoff logic: keep only generated tokens (not source, not special)
cutoff = (
tokens.ne(self.pad_id) &
tokens.ne(self.bos_id) &
tokens.ne(self.eos_id) &
(~partial_mask) # Not source tokens (partial_mask=False for generated)
)
tokens = tokens[cutoff]
if scores is None:
score = None
else:
scores = scores[cutoff]
score = scores.mean().item() if len(scores) > 0 else 0.0
ret_dict = {
"tokens": tokens,
"positional_scores": scores,
"score": score,
"alignment": None
}
if history is not None:
ret_dict["history"] = [
finalized_hypos(history_tokens, None, partial_mask, history=None)
for history_tokens in history
]
return ret_dict
def score_select(hyps):
index = np.argmax([hyp["score"] for hyp in hyps])
return hyps[index]
output_tokens, output_scores = prev_decoder_out.output_tokens, prev_decoder_out.output_scores
# Handle history if needed
if return_history and prev_decoder_out.history is not None:
full_history = prev_decoder_out.history
histories = [[full_history[j][i] for j in range(max_iterations)] for i in range(output_tokens.size(0))]
hyps = []
for tokens, scores, partial_mask, history in zip(output_tokens, output_scores, partial_masks, histories):
hyps.append(finalized_hypos(tokens, scores, partial_mask, history))
else:
hyps = [
finalized_hypos(tokens, scores, partial_mask, None)
for tokens, scores, partial_mask in zip(output_tokens, output_scores, partial_masks)
]
repeatition = kwargs.get("mbr", 1) * kwargs.get("length_beam", 1)
if repeatition > 1:
hyps = [score_select(hyps[i:i+repeatition]) for i in range(0, len(hyps), repeatition)]
finalized = pad_sequence([h["tokens"] for h in hyps ], batch_first=True, padding_value=self.pad_id)
# If the user expects just tokens, we return finalized tokens.
# The original model.generate returned just tokens.
return finalized