import inspect
import math
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.distributions as dists
from torch.nn.utils.rnn import pad_sequence

import sacrebleu
from rouge import Rouge
|
|
@dataclass
class DiscreteDiffusionGeneratorArguments:
    max_iterations: int = field(default=10)      # number of denoising steps
    mbr: int = field(default=1)                  # candidates per input for best-of selection
    length_beam: int = field(default=1)          # number of candidate target lengths
    oracle_length: bool = field(default=False)   # use the reference length instead of predicting it
    strategy: str = field(default="reparam-uncond-deterministic-cosine")  # "cmlm", "ar", or "reparam-<cond>-<topk>-<schedule>"
    argmax_decoding: bool = field(default=True)  # greedy decoding instead of sampling
    bpe: str = field(default="sentencepiece")    # BPE symbol passed to dictionary.string
    bleu_tokenize: str = field(default="13a")    # sacrebleu tokenizer
    return_history: bool = field(default=False)  # keep per-step token snapshots
    temperature: float = field(default=0.8)      # sampling temperature
|
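# Example: the defaults decode with the reparameterized, unconditional,
# deterministic strategy on a cosine schedule; a hypothetical override for
# best-of-5 selection might look like:
#
#   args = DiscreteDiffusionGeneratorArguments(max_iterations=25, mbr=5)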
|
|
def topk_masking(scores, cutoff_len, stochastic=False, temp=1.0):
    """
    scores: [b, n]
    cutoff_len: [b, 1]
    stochastic: whether to add Gumbel noise before selecting the top-k
    returns:
        mask: [b, n], True where the token is among the top-k lowest scores
    """
    if stochastic:
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(scores) + 1e-8) + 1e-8)
        _scores = scores + temp * gumbel_noise
    else:
        _scores = scores
    sorted_scores = _scores.sort(-1)[0]
    # The cutoff is each row's `cutoff_len`-th smallest score; everything
    # strictly below it gets masked.
    cutoff = sorted_scores.gather(dim=-1, index=cutoff_len)
    masking = _scores < cutoff
    # Sanity check: if every row's cutoff length is 0, nothing may be masked.
    assert (~(cutoff_len == 0).all()) | (~masking).all()
    return masking
|
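# Example (hypothetical values): mask the two lowest-scoring positions per row.
#
#   scores = torch.tensor([[0.5, 0.1, 0.9, 0.3]])
#   topk_masking(scores, torch.tensor([[2]]))
#   # -> tensor([[False,  True, False,  True]])
#
# With stochastic=True, Gumbel noise scaled by `temp` perturbs the scores
# before the cutoff is taken, so the selected set varies between calls.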
|
|
class MergeBLEU(object):
    def __call__(self, evalpreds):
        sys_stats, ref_stats = evalpreds[0], evalpreds[1]
        # Each row holds four n-gram counts plus a length; sum over all batches.
        sys_stats = sys_stats.reshape(-1, 5).astype(np.int64).sum(0).tolist()
        ref_stats = ref_stats.reshape(-1, 5).astype(np.int64).sum(0).tolist()
        try:
            from sacrebleu.metrics import BLEU
            comp_bleu = BLEU.compute_bleu
        except ImportError:
            # Compatibility with older sacrebleu versions.
            comp_bleu = sacrebleu.compute_bleu
        # The smoothing keyword was renamed across sacrebleu versions.
        fn_sig = inspect.getfullargspec(comp_bleu)[0]
        if "smooth_method" in fn_sig:
            smooth = {"smooth_method": "exp"}
        else:
            smooth = {"smooth": "exp"}
        return {
            "bleu": comp_bleu(
                correct=sys_stats[:4],
                total=ref_stats[:4],
                sys_len=sys_stats[-1],
                ref_len=ref_stats[-1],
                **smooth
            ).score
        }
|
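# Example (hypothetical stat arrays): following the keyword arguments above,
# each system row is [correct_1..4, sys_len] and each reference row is
# [total_1..4, ref_len]; rows are summed into one corpus-level BLEU.
#
#   sys_stats = np.array([[10, 8, 6, 4, 12], [9, 7, 5, 3, 11]], dtype=np.float64)
#   ref_stats = np.array([[12, 11, 10, 9, 12], [11, 10, 9, 8, 11]], dtype=np.float64)
#   MergeBLEU()((sys_stats, ref_stats))  # -> {"bleu": <float>}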
|
|
class MergeRouge(object):
    def __call__(self, evalpreds):
        avg_rouge, batch_size = evalpreds[0], evalpreds[1]
        # Batch-size-weighted average of the per-batch ROUGE means.
        rouge = (avg_rouge * batch_size).sum() / batch_size.sum()
        return {
            "rouge": rouge
        }
|
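# Example: batches of size 2 and 3 with mean ROUGE 0.5 and 0.7 merge to
# (0.5 * 2 + 0.7 * 3) / 5 = 0.62.
#
#   MergeRouge()((np.array([0.5, 0.7]), np.array([2, 3])))  # -> {"rouge": 0.62}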
|
|
class DiscreteDiffusionGenerator:
    def __init__(self, args, dictionary=None, tokenizer=None) -> None:
        self.args = args
        self.dictionary = dictionary
        self.tokenizer = tokenizer
        self.write_prediction = None

        # Exactly one of `dictionary` (fairseq-style) or `tokenizer`
        # (HuggingFace-style) must be provided.
        assert (dictionary is None) ^ (tokenizer is None)

        self.retain_history = args.return_history

        if dictionary is not None:
            self.pad_id = dictionary.pad()
            self.bos_id = dictionary.bos()
            self.eos_id = dictionary.eos()
            self.mask_id = dictionary.mask_index
        else:
            self.pad_id = tokenizer.pad_token_id
            self.bos_id = tokenizer.bos_token_id
            self.eos_id = tokenizer.eos_token_id
            self.mask_id = tokenizer.mask_token_id

        self.rouge = Rouge(["rouge-l"])

    def set_write_to(self, path):
        self.write_prediction = path
|
    def _reparam_decoding(
        self,
        output_tokens,
        output_scores,
        cur_tokens,
        cur_scores,
        decoding_strategy,
        xt_neq_x0,
        non_special_sym_mask,
        t,
        max_step,
        noise
    ):
        """
        Performs one step of reparameterized decoding: given the running state
        (output_tokens, output_scores) and the model's fresh predictions
        (cur_tokens, cur_scores), decide per position whether to re-noise it
        or to commit the prediction. Returns the updated noised-position mask.
        """
        # Strategy strings have the form "reparam-<condition>-<topk_mode>-<schedule>",
        # e.g. "reparam-uncond-deterministic-cosine".
        _, condition, topk_mode, schedule = decoding_strategy.split("-")

        # Fraction of tokens kept noised at step t; decays to 0 as t -> max_step.
        if schedule == "linear":
            rate = 1 - t / max_step
        elif schedule == "cosine":
            rate = np.cos(t / max_step * np.pi * 0.5)
        else:
            raise NotImplementedError

        # Number of (non-special) positions to keep noised at this step.
        cutoff_len = (
            non_special_sym_mask.sum(1, keepdim=True).type_as(output_scores) * rate
        ).long()
        # Give special symbols a large score so they are never selected.
        _scores_for_topk = cur_scores.masked_fill(~non_special_sym_mask, 1000.0)

        if topk_mode.startswith("stochastic"):
            noise_scale = float(topk_mode.replace("stochastic", ""))
            lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=True, temp=noise_scale * rate)
        elif topk_mode == "deterministic":
            lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=False)
        else:
            raise NotImplementedError

        # not_v1_t selects currently-denoised positions to re-noise; in the
        # conditioned variant, only low-score positions where the model repeats
        # the stored token with lower confidence are eligible.
        if condition == "cond":
            not_v1_t = (cur_tokens == output_tokens) & (cur_scores < output_scores) & lowest_k_mask
        elif condition == "uncond":
            not_v1_t = lowest_k_mask
        else:
            raise NotImplementedError

        # not_v2_t selects currently-noised positions that stay noised.
        not_v2_t = lowest_k_mask

        # Re-noise the selected positions and reset their scores.
        masked_to_noise = (~xt_neq_x0 & not_v1_t) | (xt_neq_x0 & not_v2_t)
        if isinstance(noise, torch.Tensor):
            output_tokens.masked_scatter_(masked_to_noise, noise[masked_to_noise])
        elif isinstance(noise, (int, float)):
            output_tokens.masked_fill_(masked_to_noise, noise)
        else:
            raise NotImplementedError("noise should be either a tensor or a scalar")
        output_scores.masked_fill_(masked_to_noise, -math.inf)

        # Commit predictions at noised positions that escaped re-masking.
        masked_to_x0 = xt_neq_x0 & ~not_v2_t
        output_tokens.masked_scatter_(masked_to_x0, cur_tokens[masked_to_x0])
        output_scores.masked_scatter_(masked_to_x0, cur_scores[masked_to_x0])

        # Track which positions still differ from x_0 after this step.
        new_xt_neq_x0 = (xt_neq_x0 | not_v1_t) & not_v2_t
        return new_xt_neq_x0
|
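    # Per-position effect of one _reparam_decoding step, writing s = xt_neq_x0
    # (is the position still noised?), v1 = not_v1_t, v2 = not_v2_t (note that
    # not_v1_t implies not_v2_t, since both derive from lowest_k_mask):
    #   s=0, v1=1 (so v2=1) -> re-noised
    #   s=0, v1=0           -> unchanged, stays committed
    #   s=1, v2=1           -> stays noised
    #   s=1, v2=0           -> committed to the current prediction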
|
    def denoise_step(self, model, decoder_out, partial_masks):
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        prev_step, cur_step = decoder_out.step, decoder_out.step + 1
        max_step = decoder_out.max_step
        temperature = self.args.temperature

        logits = model(output_tokens, partial_masks)
        # Never predict the mask token itself.
        logits[..., self.mask_id] = -math.inf
        scores = torch.log_softmax(logits, dim=-1)

        if self.args.strategy == "cmlm":
            # Mask-predict baseline: unmask each still-masked position
            # independently with probability 1 / (remaining steps).
            output_masks = output_tokens.eq(self.mask_id)
            unmask_prob = 1 / (max_step - prev_step)
            changes = torch.rand(output_tokens.shape, device=output_tokens.device) < unmask_prob
            changes = torch.bitwise_and(changes, output_masks)

            if self.args.argmax_decoding:
                output_scores, new_tokens = scores.max(-1)
            else:
                new_tokens = dists.Categorical(logits=scores / temperature).sample()
                output_scores = torch.gather(scores, -1, new_tokens.unsqueeze(-1)).squeeze(-1)
            output_tokens[changes] = new_tokens[changes]
        elif self.args.strategy == "ar":
            # Autoregressive baseline: reveal exactly the next position in
            # left-to-right order.
            output_masks = output_tokens.eq(self.mask_id)
            unmask_indices = (output_tokens.ne(self.mask_id) & output_tokens.ne(self.eos_id) & output_tokens.ne(self.pad_id)).sum(dim=-1)
            indices = torch.arange(output_tokens.size(-1)).expand(output_tokens.shape).to(output_masks.device)
            if self.args.argmax_decoding:
                output_scores, new_tokens = scores.max(-1)
            else:
                new_tokens = dists.Categorical(logits=scores / temperature).sample()
                output_scores = torch.gather(scores, -1, new_tokens.unsqueeze(-1)).squeeze(-1)
            output_tokens[unmask_indices[:, None] == indices] = new_tokens[unmask_indices[:, None] == indices]
        else:
            # Reparameterized decoding (default).
            if self.args.argmax_decoding:
                cur_scores, cur_tokens = scores.max(-1)
            else:
                cur_tokens = dists.Categorical(logits=scores / temperature).sample()
                cur_scores = torch.gather(scores, -1, cur_tokens.unsqueeze(-1)).squeeze(-1)
            cur_scores = cur_scores.to(output_scores)

            output_masks = self._reparam_decoding(
                output_tokens=output_tokens,
                output_scores=output_scores,
                cur_tokens=cur_tokens,
                cur_scores=cur_scores,
                decoding_strategy=self.args.strategy,
                xt_neq_x0=decoder_out.output_masks,
                non_special_sym_mask=decoder_out.non_fixed_sym_masks,
                t=cur_step,
                max_step=max_step,
                noise=self.mask_id
            )
        if self.retain_history:
            history = ([] if decoder_out.history is None else decoder_out.history) + [output_tokens.clone()]
        else:
            history = None

        return decoder_out._replace(
            step=cur_step,
            output_tokens=output_tokens,
            output_scores=output_scores,
            output_masks=output_masks,
            history=history,
        )
|
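    # Example: one manual refinement step (assuming `generator` is an instance
    # of this class and `out` is a decoder state produced by
    # initialize_decode_samples and _replace'd with step/max_step):
    #
    #   out = generator.denoise_step(model, out, partial_masks)
    #   out.step           # advanced by one
    #   out.output_tokens  # partially denoised sequence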
|
    def decode(self, seqs_tensors, preserve_special=False):
        # Negative ids (e.g. an ignore-index) are mapped to the pad token.
        seqs_tensors[seqs_tensors < 0] = self.pad_id
        if self.dictionary is not None:
            seqs = [
                self.dictionary.string(seq, self.args.bpe).strip()
                for seq in seqs_tensors
            ]
            if not preserve_special:
                seqs = [seq.replace(self.dictionary.pad_word, '') for seq in seqs]
        else:
            seqs = self.tokenizer.batch_decode(seqs_tensors, skip_special_tokens=(not preserve_special))
        return [seq.lower() for seq in seqs]
|
    def compute_bleu(self, hyps, refs):
        if isinstance(hyps, torch.Tensor):
            hyps = self.decode(hyps)
        if isinstance(refs, torch.Tensor):
            refs = self.decode(refs)
        return sacrebleu.corpus_bleu(hyps, [refs], tokenize=self.args.bleu_tokenize)
|
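    # Example (hypothetical strings): identical corpora score 100.
    #
    #   bleu = generator.compute_bleu(["the cat sat on the mat"],
    #                                 ["the cat sat on the mat"])
    #   bleu.score  # -> 100.0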
|
    def compute_rouge(self, hyps, refs):
        if isinstance(hyps, torch.Tensor):
            hyps = self.decode(hyps)
        if isinstance(refs, torch.Tensor):
            refs = self.decode(refs)
        return self.rouge.get_scores(hyps, [[ref] for ref in refs])['rouge-l']['f']
|
    def stepwise_generate(self, model, inputs):
        src_tokens = inputs["net_input"]["src_tokens"]
        partial_masks = inputs["net_input"]["partial_masks"]

        # Unwrap DDP-style wrappers so we can reach initialize_decode_samples.
        raw_model = model.module if hasattr(model, "module") else model
        if "prefix_masks" in inputs["net_input"]:
            prefix_masks = inputs["net_input"]["prefix_masks"]
        else:
            prefix_masks = partial_masks

        partial_masks, prev_decoder_out = raw_model.initialize_decode_samples(
            src_tokens, partial_masks, prefix_masks, oracle_length=self.args.oracle_length, length_beam=self.args.length_beam, mbr=self.args.mbr
        )
        prev_decoder_out = prev_decoder_out._replace(
            step=0, max_step=self.args.max_iterations
        )
        for step in range(self.args.max_iterations):
            prev_decoder_out = self.denoise_step(model, prev_decoder_out, partial_masks)
            yield prev_decoder_out
|
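    # Example: stream intermediate states, e.g. to visualize the denoising
    # trajectory (assuming `inputs` follows the "net_input" layout above):
    #
    #   for out in generator.stepwise_generate(model, inputs):
    #       print(out.step, generator.decode(out.output_tokens.clone())[0])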
|
    @torch.no_grad()
    def generate(self, model, inputs):
        src_tokens = inputs["net_input"]["src_tokens"]
        partial_masks = inputs["net_input"]["partial_masks"]

        if "prefix_masks" in inputs["net_input"]:
            prefix_masks = inputs["net_input"]["prefix_masks"]
        else:
            prefix_masks = partial_masks
        partial_masks, prev_decoder_out = model.initialize_decode_samples(
            src_tokens, partial_masks, prefix_masks, oracle_length=self.args.oracle_length, length_beam=self.args.length_beam, mbr=self.args.mbr
        )
        prev_decoder_out = prev_decoder_out._replace(
            step=0, max_step=self.args.max_iterations
        )

        for step in range(self.args.max_iterations):
            prev_decoder_out = self.denoise_step(model, prev_decoder_out, partial_masks)

        def finalized_hypos(tokens, scores, partial_mask, history=None):
            # Strip padding, BOS/EOS, and fixed (conditioning) positions.
            cutoff = (
                tokens.ne(self.pad_id) &
                tokens.ne(self.bos_id) &
                tokens.ne(self.eos_id) &
                (~partial_mask)
            )
            tokens = tokens[cutoff]
            if scores is None:
                score = None
            else:
                scores = scores[cutoff]
                score = scores.mean().item()
            ret_dict = {
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "alignment": None
            }
            if history is not None:
                ret_dict["history"] = [
                    finalized_hypos(history_tokens, None, partial_mask, history=None)
                    for history_tokens in history
                ]
            return ret_dict

        def mbr_select(hyps):
            # Pick the hypothesis with the highest mean ROUGE-L against all others.
            index = np.argmax([
                np.array([
                    self.rouge.get_scores([hyps[i]], [[hyps[j]]])['rouge-l']['f']
                    for j in range(len(hyps)) if i != j
                ]).mean()
                for i in range(len(hyps))
            ])
            return hyps[index]

        def score_select(hyps):
            # Pick the hypothesis with the highest mean token log-probability.
            index = np.argmax([hyp["score"] for hyp in hyps])
            return hyps[index]

        output_tokens, output_scores = prev_decoder_out.output_tokens, prev_decoder_out.output_scores
        if self.retain_history:
            full_history = prev_decoder_out.history
            histories = [[full_history[j][i] for j in range(self.args.max_iterations)] for i in range(output_tokens.size(0))]
            hyps = []
            for tokens, scores, partial_mask, history in zip(output_tokens, output_scores, partial_masks, histories):
                hyps.append(finalized_hypos(tokens, scores, partial_mask, history))
        else:
            hyps = [
                finalized_hypos(tokens, scores, partial_mask, None)
                for tokens, scores, partial_mask in zip(output_tokens, output_scores, partial_masks)
            ]
        # Each source expands into mbr * length_beam candidates; keep the
        # best-scoring one per source.
        repetition = self.args.mbr * self.args.length_beam
        if repetition > 1:
            hyps = [score_select(hyps[i:i + repetition]) for i in range(0, len(hyps), repetition)]

        finalized = pad_sequence([h["tokens"] for h in hyps], batch_first=True, padding_value=self.pad_id)
        history = [[item["tokens"] for item in h["history"]] for h in hyps] if self.retain_history else None
        return finalized, history
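
# Example end-to-end usage, assuming a model that is callable as
# model(tokens, partial_masks) and exposes initialize_decode_samples, plus a
# HuggingFace-style tokenizer (names here are illustrative):
#
#   args = DiscreteDiffusionGeneratorArguments(max_iterations=10)
#   generator = DiscreteDiffusionGenerator(args, tokenizer=tokenizer)
#   finalized, history = generator.generate(model, inputs)
#   print(generator.decode(finalized))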