import torch
import torch.distributions as dists
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedModel, AutoModelForMaskedLM, AutoConfig

try:
    from .configuration_dlm import DiscreteDiffusionConfig
except ImportError:
    from configuration_dlm import DiscreteDiffusionConfig

from collections import namedtuple
import math
import numpy as np
from typing import List, Optional, Tuple, Union


# Rolling decoder state passed between denoising steps.
decoder_out_t = namedtuple(
    "decoder_out_t",
    ["output_tokens", "output_scores", "output_masks", "non_fixed_sym_masks", "attn", "step", "max_step", "history"],
)

def topk_masking(scores, cutoff_len, stochastic=False, temp=1.0):
    """
    scores: [b, n]
    cutoff_len: [b, 1]
    stochastic: bool, whether to add Gumbel noise to the scores before selecting the top-k
    returns:
        mask: [b, n], True if the token is among the cutoff_len lowest scores, False otherwise
    """
    if stochastic:
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(scores) + 1e-8) + 1e-8)
        _scores = scores + temp * gumbel_noise
    else:
        _scores = scores
    # Sort scores ascending and read off the value at position cutoff_len as the threshold.
    sorted_scores = _scores.sort(-1)[0]
    cutoff = sorted_scores.gather(dim=-1, index=cutoff_len)
    # Tokens strictly below the threshold are selected.
    masking = _scores < cutoff
    return masking
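
# Illustrative example (a sketch, not executed at import): with scores
# [[0.1, 0.9, 0.4, 0.2]] and cutoff_len [[2]], the ascending sort gives
# [0.1, 0.2, 0.4, 0.9], the threshold is 0.4, and the returned mask is
# [[True, False, False, True]] -- the two lowest-scoring positions.
#
#   scores = torch.tensor([[0.1, 0.9, 0.4, 0.2]])
#   cutoff_len = torch.tensor([[2]])
#   topk_masking(scores, cutoff_len)   # tensor([[ True, False, False,  True]])
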

class DiscreteDiffusionModel(PreTrainedModel):
    config_class = DiscreteDiffusionConfig
    _keys_to_ignore_on_load_missing = ["fake_layer", "length_trm", "length_predictor", "model.lm_head.decoder.weight"]

    def __init__(self, config: DiscreteDiffusionConfig):
        super().__init__(config)
        self.config = config
        # Alias kept for code paths that expect an `args` attribute.
        self.args = config

        # The masked-LM backbone (a RoBERTa-style encoder) is built from the nested config.
        if config.backbone_config:
            backbone_config_obj = AutoConfig.for_model(**config.backbone_config)
            self.model = AutoModelForMaskedLM.from_config(backbone_config_obj)
        else:
            raise ValueError("backbone_config must be provided in config")

        if config.tie_word_embeddings:
            self.model.lm_head.decoder.weight = self.model.roberta.embeddings.word_embeddings.weight

        self.mask_id = config.mask_token_id
        self.bos_id = config.bos_token_id
        self.eos_id = config.eos_token_id
        self.pad_id = config.pad_token_id

        if config.lora:
            self.add_fake_layer()

        # Length-prediction head: one Transformer encoder layer over the backbone features,
        # followed by an MLP that scores every candidate target length.
        self.length_trm = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.config.hidden_size,
                nhead=self.config.num_attention_heads,
                dim_feedforward=self.config.intermediate_size,
                batch_first=True,
            ),
            num_layers=1,
        )
        self.length_predictor = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.intermediate_size),
            nn.Tanh(),
            nn.Linear(self.config.intermediate_size, self.config.max_position_embeddings),
        )

    def add_fake_layer(self):
        # Dummy parameter used when LoRA is enabled; it is mixed into the embeddings with
        # weight 0 in forward() so that a trainable leaf always participates in the graph.
        self.fake_layer = nn.Parameter(torch.zeros((self.config.hidden_size,)))

    def gradient_checkpointing_enable(self):
        self.model.gradient_checkpointing_enable()

    def _tie_weights(self):
        """Tie the weights between the input embeddings and the output embeddings."""
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(
                self.model.lm_head.decoder,
                self.model.roberta.embeddings.word_embeddings,
            )

    def _init_weights(self, module):
        """Initialize the weights, then re-tie the embeddings."""
        super()._init_weights(module)
        self._tie_weights()

    @property
    def _tied_weights_keys(self):
        """Return the keys of tied weights."""
        if self.config.tie_word_embeddings:
            return ["model.lm_head.decoder.weight"]
        return []

    def q_sample_coupled(self, x_0, t1, t2, maskable_mask):
        assert self.config.diffusion_type == "absorbing", "we only support absorbing diffusion temporarily"
        t1_eq_t2_mask = (t1 == t2)
        # Order the two timesteps so that t1 >= t2; the t2 corruption is then nested inside t1's.
        t1, t2 = torch.maximum(t1, t2).float(), torch.minimum(t1, t2).float()

        # Sample x_{t1}: each maskable token is absorbed into [MASK] with probability t1 / T.
        u = torch.rand_like(x_0, dtype=torch.float)
        t1_mask = (u < (t1 / self.config.num_diffusion_timesteps)[:, None]) & maskable_mask
        x_t1 = x_0.masked_fill(t1_mask, self.mask_id)

        # Sample x_{t2} as a subset of the t1 masks (kept with probability t2 / t1), so the
        # two corruptions are coupled and the marginal masking rate stays t2 / T.
        u = torch.rand_like(x_0, dtype=torch.float)
        t2_mask = t1_mask & (u > ((t1 - t2) / t1)[:, None])
        # When t1 == t2 the nested construction degenerates; resample those rows independently.
        u = torch.rand_like(x_0[t1_eq_t2_mask], dtype=torch.float)
        t2_mask[t1_eq_t2_mask] = (u < (t1[t1_eq_t2_mask] / self.config.num_diffusion_timesteps)[:, None]) & (maskable_mask[t1_eq_t2_mask])
        x_t2 = x_0.masked_fill(t2_mask, self.mask_id)

        return {
            "x_t": torch.cat([x_t1, x_t2], dim=0),
            "t": torch.cat([t1, t2]),
            "mask_mask": torch.cat([t1_mask, t2_mask], dim=0)
        }
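
    # Illustrative call (a sketch, not executed at import). For a batch of size b and
    # length n, the method returns tensors with batch size 2*b: the first half is the
    # more heavily corrupted x_{t1}, the second half the nested x_{t2}.
    #
    #   x_0 = torch.randint(5, 100, (2, 16))                        # [b, n] token ids
    #   t1 = torch.randint(1, model.config.num_diffusion_timesteps + 1, (2,))
    #   t2 = torch.randint(1, model.config.num_diffusion_timesteps + 1, (2,))
    #   out = model.q_sample_coupled(x_0, t1, t2, maskable_mask=torch.ones_like(x_0).bool())
    #   out["x_t"].shape, out["t"].shape, out["mask_mask"].shape    # [4, 16], [4], [4, 16]
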

    def initialize_decode_samples(self, tokens, partial_masks, prefix_masks, oracle_length=False, length_beam=1, mbr=1):
        if tokens is None:
            raise NotImplementedError
        else:
            if not oracle_length:
                # Predict the target length from the prefix, then append that many [MASK]
                # tokens (plus an EOS) to every prefix, once per length-beam / MBR sample.
                inputs_tokens = tokens.masked_fill(~prefix_masks, self.pad_id)
                src_length = inputs_tokens.ne(self.pad_id).sum(dim=-1)
                inputs_tokens = inputs_tokens[:, :src_length.max()]
                length_logits = self.forward_length(inputs_tokens)

                # Cap the generated length at 100 tokens, 3x the prefix length, and the
                # positional-embedding budget, whichever is smallest.
                max_allowed_length = torch.min(
                    torch.tensor([100]).to(src_length.device),
                    (src_length * 3)[:, None]
                )
                length = torch.min(
                    torch.min(
                        length_logits.topk(length_beam, dim=-1).indices + 1,
                        max_allowed_length
                    ),
                    self.config.max_position_embeddings - 2 - src_length[:, None] - 1
                )
                output_tokens = []
                new_partial_masks = []
                for i, token in enumerate(inputs_tokens):
                    for b in range(length_beam):
                        for m in range(mbr):
                            seq = torch.cat([
                                token[:src_length[i]],
                                torch.tensor([self.mask_id] * length[i][b] + [self.eos_id]).to(token)
                            ])
                            output_tokens.append(seq)

                            p_mask = torch.cat([
                                partial_masks[i][:src_length[i]],
                                torch.tensor([False] * (length[i][b] + 1)).to(partial_masks)
                            ])
                            new_partial_masks.append(p_mask)

                output_tokens = pad_sequence(output_tokens, batch_first=True, padding_value=self.pad_id)
                partial_masks = pad_sequence(new_partial_masks, batch_first=True, padding_value=True)

                output_mask = output_tokens.eq(self.mask_id)
                non_fixed_sym_masks = (
                    output_tokens.ne(self.pad_id) &
                    output_tokens.ne(self.bos_id) &
                    ~partial_masks
                )
            else:
                # Oracle length: keep the reference sequence shape and mask only the
                # positions that are neither special tokens nor part of the prefix.
                output_tokens = torch.stack([token for token in tokens for m in range(mbr)])
                partial_masks = torch.stack([mask for mask in partial_masks for m in range(mbr)])
                prefix_masks = torch.stack([mask for mask in prefix_masks for m in range(mbr)])
                output_mask = (
                    output_tokens.ne(self.pad_id) &
                    output_tokens.ne(self.bos_id) &
                    output_tokens.ne(self.eos_id) &
                    ~prefix_masks
                )
                output_tokens = output_tokens.masked_fill(output_mask, self.mask_id)
                non_fixed_sym_masks = output_mask.clone()
        output_scores = torch.zeros_like(output_tokens, dtype=torch.float)

        return partial_masks, decoder_out_t(
            output_tokens=output_tokens,
            output_scores=output_scores,
            output_masks=output_mask,
            non_fixed_sym_masks=non_fixed_sym_masks,
            attn=None,
            step=0,
            max_step=math.inf,
            history=None
        )

    def forward_length(self, input_ids):
        attention_mask = input_ids.ne(self.pad_id).int()
        # The backbone is used only as a frozen feature extractor for length prediction.
        with torch.no_grad():
            _feature = self.model.roberta(input_ids, attention_mask=attention_mask)[0]
        feature = self.length_trm(_feature, src_key_padding_mask=(1 - attention_mask).bool())
        # Mean-pool the non-padding positions and score every candidate target length.
        length = attention_mask.sum(dim=-1)
        pooled_feature = feature.masked_fill((attention_mask == 0)[:, :, None], 0).float().sum(1) / length[:, None]
        length_logits = self.length_predictor(pooled_feature.to(feature))
        return length_logits
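
    # Illustrative call (a sketch, not executed at import): given padded prefix ids of
    # shape [b, n], forward_length returns one logit per candidate target length, i.e. a
    # tensor of shape [b, max_position_embeddings]; topk over the last dimension is how
    # initialize_decode_samples picks its length beams.
    #
    #   length_logits = model.forward_length(prefix_ids)          # [b, max_position_embeddings]
    #   lengths = length_logits.topk(k=3, dim=-1).indices + 1     # 3 candidate lengths per sample
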

    def forward(self, prev_output_tokens, partial_mask, attention_mask=None, loss_mask=None, cache=None):
        input_ids = prev_output_tokens
        if attention_mask is None:
            attention_mask = prev_output_tokens.ne(self.pad_id).int()

        embeddings = self.model.roberta.embeddings.word_embeddings(input_ids)

        # With LoRA, mix in the dummy parameter with weight 0 so that a trainable leaf
        # always participates in the autograd graph.
        if hasattr(self, "fake_layer") and self.training:
            self.fake_layer.requires_grad = True
            embeddings = embeddings + self.fake_layer * 0

        if self.config.attention_strategy == "prefix_lm":
            # Prefix-LM attention: prefix (conditioning) queries may only attend to other
            # prefix keys, while the remaining positions attend over the full sequence.
            ext_partial_mask = partial_mask.float()
            ext_partial_mask = torch.bmm(ext_partial_mask[:, :, None], ext_partial_mask[:, None, :]).int()
            ext_mask = attention_mask[:, None, :].repeat(1, attention_mask.size(-1), 1)
            ext_mask[partial_mask] = ext_partial_mask[partial_mask]
            outputs = self.model.roberta(inputs_embeds=embeddings, attention_mask=ext_mask)[0]
        else:
            outputs = self.model.roberta(inputs_embeds=embeddings, attention_mask=attention_mask)[0]

        # Guard against NaNs leaking out of the backbone.
        if torch.isnan(outputs).any():
            outputs.masked_fill_(outputs.isnan(), 0)

        outputs = outputs[loss_mask] if loss_mask is not None else outputs
        return self.model.lm_head(outputs)
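
    # Worked example for the prefix_lm mask (a sketch): with partial_mask = [1, 1, 0, 0]
    # (the first two tokens are the prefix) and no padding, the extended mask becomes
    #
    #   [[1, 1, 0, 0],    # prefix queries attend only to prefix keys
    #    [1, 1, 0, 0],
    #    [1, 1, 1, 1],    # non-prefix queries attend to the whole sequence
    #    [1, 1, 1, 1]]
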

    def _reparam_decoding(
        self,
        output_tokens,
        output_scores,
        cur_tokens,
        cur_scores,
        decoding_strategy,
        xt_neq_x0,
        non_special_sym_mask,
        t,
        max_step,
        noise
    ):
        # The strategy string has the form "reparam-<condition>-<topk_mode>-<schedule>",
        # e.g. "reparam-uncond-deterministic-cosine".
        _, condition, topk_mode, schedule = decoding_strategy.split("-")

        # Fraction of non-special tokens that should remain noised at step t.
        if schedule == "linear":
            rate = 1 - t / max_step
        elif schedule == "cosine":
            rate = np.cos(t / max_step * np.pi * 0.5)
        else:
            raise NotImplementedError

        # Re-mask the cutoff_len lowest-scoring positions; special symbols are pushed out
        # of the bottom-k by assigning them a large score.
        cutoff_len = (
            non_special_sym_mask.sum(1, keepdim=True).type_as(output_scores) * rate
        ).long()
        _scores_for_topk = cur_scores.masked_fill(~non_special_sym_mask, 1000.0)

        if topk_mode.startswith("stochastic"):
            noise_scale = float(topk_mode.replace("stochastic", ""))
            lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=True, temp=noise_scale * rate)
        elif topk_mode == "deterministic":
            lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=False)
        else:
            raise NotImplementedError

        # not_v1_t: clean (xt == x0) positions that will be pushed back to noise;
        # not_v2_t: noised positions that will remain noised.
        if condition == "cond":
            not_v1_t = (cur_tokens == output_tokens) & (cur_scores < output_scores) & lowest_k_mask
        elif condition == "uncond":
            not_v1_t = lowest_k_mask
        else:
            raise NotImplementedError

        not_v2_t = lowest_k_mask

        # Positions to push back to noise ([MASK]) vs. positions to commit to the current prediction.
        masked_to_noise = (~xt_neq_x0 & not_v1_t) | (xt_neq_x0 & not_v2_t)
        if isinstance(noise, torch.Tensor):
            output_tokens.masked_scatter_(masked_to_noise, noise[masked_to_noise])
        elif isinstance(noise, (int, float)):
            output_tokens.masked_fill_(masked_to_noise, noise)
        else:
            raise NotImplementedError("noise should be either a tensor or a scalar")
        output_scores.masked_fill_(masked_to_noise, -math.inf)

        masked_to_x0 = xt_neq_x0 & ~not_v2_t
        output_tokens.masked_scatter_(masked_to_x0, cur_tokens[masked_to_x0])
        output_scores.masked_scatter_(masked_to_x0, cur_scores[masked_to_x0])

        # Updated "still differs from x0" mask carried to the next step.
        new_xt_neq_x0 = (xt_neq_x0 | not_v1_t) & not_v2_t
        return new_xt_neq_x0
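
    # Worked schedule example (a sketch): with 20 non-special tokens, max_step = 10 and
    # the cosine schedule, step t = 5 gives rate = cos(0.25 * pi) ~= 0.707, so
    # cutoff_len = int(20 * 0.707) = 14: the 14 lowest-scoring positions are pushed back
    # to [MASK] (assuming distinct scores) and the remaining 6 are committed.
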

    def denoise_step(self, decoder_out, partial_masks, temperature=1.0, strategy="reparam-uncond-deterministic-cosine"):
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        prev_step, cur_step = decoder_out.step, decoder_out.step + 1
        max_step = decoder_out.max_step

        logits = self.forward(output_tokens, partial_masks)

        # Never predict the mask token itself.
        logits[..., self.mask_id] = -math.inf
        scores = torch.log_softmax(logits, dim=-1)

        if strategy == "cmlm":
            # CMLM-style decoding: unmask a random fraction 1 / (remaining steps) of the
            # still-masked positions; in the final step this probability reaches 1.
            output_masks = output_tokens.eq(self.mask_id)
            unmask_prob = 1 / (max_step - prev_step)

            changes = torch.rand(output_tokens.shape, device=output_tokens.device) < unmask_prob
            changes = torch.bitwise_and(changes, output_masks)

            if getattr(self.config, "argmax_decoding", False):
                output_scores, new_tokens = scores.max(-1)
            else:
                new_tokens = dists.Categorical(logits=scores / temperature).sample()
                output_scores = torch.gather(scores, -1, new_tokens.unsqueeze(-1)).squeeze(-1)
            output_tokens[changes] = new_tokens[changes]
        elif strategy == "ar":
            # Left-to-right decoding: reveal one position per step, the one right after
            # the already-revealed prefix.
            output_masks = output_tokens.eq(self.mask_id)
            unmask_indices = (output_tokens.ne(self.mask_id) & output_tokens.ne(self.eos_id) & output_tokens.ne(self.pad_id)).sum(dim=-1)
            indices = torch.arange(output_tokens.size(-1)).expand(output_tokens.shape).to(output_masks.device)
            if getattr(self.config, "argmax_decoding", False):
                output_scores, new_tokens = scores.max(-1)
            else:
                new_tokens = dists.Categorical(logits=scores / temperature).sample()
                output_scores = torch.gather(scores, -1, new_tokens.unsqueeze(-1)).squeeze(-1)
            output_tokens[unmask_indices[:, None] == indices] = new_tokens[unmask_indices[:, None] == indices]
        else:
            # Reparameterized decoding (default): predict every position, then decide per
            # position whether to commit the prediction or push it back to [MASK].
            if getattr(self.config, "argmax_decoding", False):
                cur_scores, cur_tokens = scores.max(-1)
            else:
                cur_tokens = dists.Categorical(logits=scores / temperature).sample()
                cur_scores = torch.gather(scores, -1, cur_tokens.unsqueeze(-1)).squeeze(-1)
            cur_scores = cur_scores.to(output_scores)

            output_masks = self._reparam_decoding(
                output_tokens=output_tokens,
                output_scores=output_scores,
                cur_tokens=cur_tokens,
                cur_scores=cur_scores,
                decoding_strategy=strategy,
                xt_neq_x0=decoder_out.output_masks,
                non_special_sym_mask=decoder_out.non_fixed_sym_masks,
                t=cur_step,
                max_step=max_step,
                noise=self.mask_id
            )

        history = (
            decoder_out.history + [output_tokens.clone()]
            if decoder_out.history is not None else None
        )

        return decoder_out._replace(
            step=cur_step,
            output_tokens=output_tokens,
            output_scores=output_scores,
            output_masks=output_masks,
            history=history,
        )
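
    # Illustrative single step (a sketch, not executed at import): given prepared
    # `tokens`, `partial_masks` and `prefix_masks` tensors, one call refines the masked
    # positions in place and advances `step` by one. initialize_decode_samples leaves
    # max_step at infinity, so it is overridden here.
    #
    #   partial_masks, state = model.initialize_decode_samples(tokens, partial_masks, prefix_masks, oracle_length=True)
    #   state = state._replace(max_step=10)
    #   state = model.denoise_step(state, partial_masks, temperature=1.0)
    #   state.step   # 1
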

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask=None,
        max_iterations=10,
        strategy="reparam-uncond-deterministic-cosine",
        temperature=1.0,
        return_history=False,
        max_length=128,
        **kwargs
    ):
        src_tokens = input_ids

        # The attention mask doubles as the prefix mask: positions marked True are kept
        # fixed as conditioning; if no mask is given the whole input is treated as prefix.
        if attention_mask is None:
            partial_masks = torch.ones_like(src_tokens).bool()
        else:
            partial_masks = attention_mask.bool()

        prefix_masks = partial_masks

        batch_size = src_tokens.size(0)
        src_length = src_tokens.ne(self.pad_id).sum(dim=-1)

        # Build the initial canvas for every sample: prefix + max_length [MASK] tokens + EOS.
        output_tokens = []
        new_partial_masks = []

        for i in range(batch_size):
            src_len = src_length[i].item()
            src_seq = src_tokens[i, :src_len]

            # Drop a trailing EOS from the prefix; it is re-appended after the masked span.
            if src_seq[-1] == self.eos_id:
                src_seq = src_seq[:-1]
                src_len -= 1

            seq = torch.cat([
                src_seq,
                torch.full((max_length,), self.mask_id, dtype=src_tokens.dtype, device=src_tokens.device),
                torch.tensor([self.eos_id], dtype=src_tokens.dtype, device=src_tokens.device)
            ])
            output_tokens.append(seq)

            mask = torch.cat([
                torch.ones(src_len, dtype=torch.bool, device=src_tokens.device),
                torch.zeros(max_length + 1, dtype=torch.bool, device=src_tokens.device)
            ])
            new_partial_masks.append(mask)

        output_tokens = pad_sequence(output_tokens, batch_first=True, padding_value=self.pad_id)
        partial_masks = pad_sequence(new_partial_masks, batch_first=True, padding_value=True)

        output_mask = output_tokens.eq(self.mask_id)
        non_fixed_sym_masks = (
            output_tokens.ne(self.pad_id) &
            output_tokens.ne(self.bos_id) &
            ~partial_masks
        )

        output_scores = torch.zeros_like(output_tokens, dtype=torch.float)

        prev_decoder_out = decoder_out_t(
            output_tokens=output_tokens,
            output_scores=output_scores,
            output_masks=output_mask,
            non_fixed_sym_masks=non_fixed_sym_masks,
            attn=None,
            step=0,
            max_step=max_iterations,
            history=None
        )

        if return_history:
            prev_decoder_out = prev_decoder_out._replace(history=[])

        # Iterative refinement: each step predicts the non-fixed positions and re-masks
        # the least confident ones according to the chosen strategy.
        for _ in range(max_iterations):
            prev_decoder_out = self.denoise_step(prev_decoder_out, partial_masks, temperature=temperature, strategy=strategy)

        def finalized_hypos(tokens, scores, partial_mask, history=None):
            # Keep the untruncated prefix mask for the recursive history calls below,
            # since history snapshots may place their first EOS at a different position.
            full_partial_mask = partial_mask

            # Truncate at the first EOS, then strip special tokens and the fixed prefix.
            eos_positions = (tokens == self.eos_id).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                first_eos = eos_positions[0].item()
                tokens = tokens[:first_eos]
                if scores is not None:
                    scores = scores[:first_eos]
                partial_mask = partial_mask[:first_eos]

            cutoff = (
                tokens.ne(self.pad_id) &
                tokens.ne(self.bos_id) &
                tokens.ne(self.eos_id) &
                (~partial_mask)
            )
            tokens = tokens[cutoff]
            if scores is None:
                score = None
            else:
                scores = scores[cutoff]
                score = scores.mean().item() if len(scores) > 0 else 0.0
            ret_dict = {
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "alignment": None
            }
            if history is not None:
                ret_dict["history"] = [
                    finalized_hypos(history_tokens, None, full_partial_mask, history=None)
                    for history_tokens in history
                ]
            return ret_dict

        def score_select(hyps):
            # Pick the hypothesis with the highest average token log-probability.
            index = np.argmax([hyp["score"] for hyp in hyps])
            return hyps[index]

        output_tokens, output_scores = prev_decoder_out.output_tokens, prev_decoder_out.output_scores

        if return_history and prev_decoder_out.history is not None:
            # Re-group the per-step snapshots into one trajectory per sample.
            full_history = prev_decoder_out.history
            histories = [[full_history[j][i] for j in range(max_iterations)] for i in range(output_tokens.size(0))]
            hyps = []
            for tokens, scores, partial_mask, history in zip(output_tokens, output_scores, partial_masks, histories):
                hyps.append(finalized_hypos(tokens, scores, partial_mask, history))
        else:
            hyps = [
                finalized_hypos(tokens, scores, partial_mask, None)
                for tokens, scores, partial_mask in zip(output_tokens, output_scores, partial_masks)
            ]

        # If several candidates were decoded per input (MBR / length beam), keep the best one.
        repetition = kwargs.get("mbr", 1) * kwargs.get("length_beam", 1)
        if repetition > 1:
            hyps = [score_select(hyps[i:i + repetition]) for i in range(0, len(hyps), repetition)]

        finalized = pad_sequence([h["tokens"] for h in hyps], batch_first=True, padding_value=self.pad_id)

        return finalized
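

# Example end-to-end usage (an illustrative sketch; the checkpoint path and tokenizer
# below are placeholders, not provided by this file):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("path/to/dlm-checkpoint")
#   model = DiscreteDiffusionModel.from_pretrained("path/to/dlm-checkpoint").eval()
#
#   enc = tokenizer(["A prompt to continue"], return_tensors="pt")
#   out = model.generate(
#       enc["input_ids"],
#       attention_mask=enc["attention_mask"],
#       max_iterations=10,
#       max_length=64,
#   )
#   print(tokenizer.batch_decode(out, skip_special_tokens=True))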