import numpy as np
import torch

from extra_utils import res_to_list, res_to_seq


class AbScores:
    """Scoring utilities for AbLang: embeddings, logits, and (pseudo) log likelihoods."""

    def __init__(self, device='cpu', ncpu=1):
        # Stored as `used_device`, which every scoring method below expects.
        self.used_device = torch.device(device)
        self.ncpu = ncpu

    def _initiate_abencoding(self, model, tokenizer):
        self.AbLang = model
        self.tokenizer = tokenizer

    def _encode_sequences(self, seqs):
        """Residue embeddings (last hidden states) for the given sequences."""
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
        with torch.no_grad():
            # Move to CPU before converting, so this also works on GPU devices.
            return self.AbLang.AbRep(tokens).last_hidden_states.cpu().numpy()

    def _predict_logits(self, seqs):
        """Raw logits for the given sequences, together with their tokens."""
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
        with torch.no_grad():
            return self.AbLang(tokens), tokens

    def pseudo_log_likelihood(self, seqs, **kwargs):
        """
        Pseudo log likelihood of sequences.

        Each residue is masked in turn, and the mean log likelihood of the
        true residues at their masked positions is returned, one value per
        sequence. Special tokens are excluded.
        """
        special_tokens = torch.tensor(self.tokenizer.all_special_tokens, device=self.used_device)
        plls = []
        for seq in seqs:
            labels = self.tokenizer(
                seq, pad=True, w_extra_tkns=False, device=self.used_device
            )
            # Indices of actual residues, i.e. everything except special tokens.
            idxs = (~torch.isin(labels, special_tokens)).nonzero()
            # One copy of the sequence per residue, each copy masking one residue.
            masked_tokens = labels.repeat(len(idxs), 1)
            for num, idx in enumerate(idxs):
                masked_tokens[num, idx[1]] = self.tokenizer.mask_token

            with torch.no_grad():
                logits = self.AbLang(masked_tokens)
            # Special tokens are never valid predictions.
            logits[:, :, self.tokenizer.all_special_tokens] = -float("inf")
            # Keep only the logits at each masked position.
            logits = torch.stack([logits[num, idx[1]] for num, idx in enumerate(idxs)])
            labels = labels[:, idxs[:, 1:]].squeeze(2)[0]

            # Mean negative log likelihood over the masked positions; the
            # pseudo log likelihood is its negation.
            nll = torch.nn.functional.cross_entropy(
                logits,
                labels,
                reduction="mean",
            )
            plls.append(-nll)

        return torch.stack(plls, dim=0).cpu().numpy()
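
    # Illustration of the masking above (toy token ids, assuming the mask
    # token is 23 and 0/21/22 are special tokens): for labels [[0, 5, 7, 22]]
    # the residue positions are 1 and 2, so the masked batch is
    #   [[0, 23, 7, 22],
    #    [0, 5, 23, 22]]
    # and the PLL averages the log likelihoods of 5 and 7 at the masked spots.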

    def confidence(self, seqs, **kwargs):
        """
        Log likelihood of sequences without masking.

        A cheaper proxy for pseudo_log_likelihood: each sequence needs only a
        single forward pass.
        """
        labels = self.tokenizer(
            seqs, pad=True, w_extra_tkns=False, device=self.used_device
        )
        with torch.no_grad():
            logits = self.AbLang(labels)
        # Special tokens are never valid predictions.
        logits[:, :, self.tokenizer.all_special_tokens] = -float("inf")

        special_tokens = torch.tensor(self.tokenizer.all_special_tokens, device=self.used_device)
        plls = []
        for label, logit in zip(labels, logits):
            # Indices of actual residues, i.e. everything except special tokens.
            idxs = (~torch.isin(label, special_tokens)).nonzero().squeeze(1)
            nll = torch.nn.functional.cross_entropy(
                logit[idxs],
                label[idxs],
                reduction="mean",
            )
            plls.append(-nll)

        return torch.stack(plls, dim=0).cpu().numpy()
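
# A minimal usage sketch (hypothetical, for illustration only): `load_model`
# stands in for whatever loads the trained AbLang model and its tokenizer;
# only the AbScores methods above are taken from this file.
#
#   scorer = AbScores(device="cpu")
#   model, tokenizer = load_model()              # hypothetical loader
#   scorer._initiate_abencoding(model, tokenizer)
#
#   seqs = ["EVQLVESGGGLVQPGG"]                  # example antibody sequence
#   scorer.pseudo_log_likelihood(seqs)           # one PLL per sequence
#   scorer.confidence(seqs)                      # unmasked log likelihood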