File size: 3,087 Bytes
712d350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import numpy as np
import torch

from extra_utils import res_to_list, res_to_seq


class AbScores:
    """Scoring utilities for an AbLang-style antibody language model:
    residue embeddings, raw logits, and (pseudo) log-likelihoods.

    NOTE(review): the scoring methods read ``self.used_device``, which is
    never assigned in ``__init__`` (only ``self.device`` is). It appears to
    be set externally by an owner/subclass before these methods are called —
    confirm against the caller before relying on it.
    """

    def __init__(self, device = 'cpu', ncpu = 1):
        # Requested compute device and worker count. ``used_device`` is
        # expected to be provided externally (see class NOTE above).
        self.device = device
        self.ncpu = ncpu

    def _initiate_abencoding(self, model, tokenizer):
        # Attach the language model and its tokenizer after construction.
        self.AbLang = model
        self.tokenizer = tokenizer

    def _encode_sequences(self, seqs):
        """Return per-residue hidden states for *seqs* as a numpy array."""
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
        with torch.no_grad():
            # .cpu() before .numpy(): a CUDA tensor cannot be converted to
            # numpy directly; this is a no-op when already on the CPU.
            return self.AbLang.AbRep(tokens).last_hidden_states.cpu().numpy()

    def _predict_logits(self, seqs):
        """Return ``(logits, tokens)`` for *seqs* without gradient tracking."""
        tokens = self.tokenizer(seqs, pad=True, w_extra_tkns=False, device=self.used_device)
        with torch.no_grad():
            return self.AbLang(tokens), tokens

    def pseudo_log_likelihood(self, seqs, **kwargs):
        """
        Pseudo log likelihood of sequences.

        Each non-special position is masked one at a time and re-scored by
        the model; the per-sequence score is the negated mean cross-entropy
        over those positions. Returns a numpy array with one value per
        sequence in *seqs*.
        """
        # Loop invariant hoisted out of the per-sequence loop: special-token
        # ids used to decide which positions get masked and scored.
        special_tkns = torch.Tensor(self.tokenizer.all_special_tokens).to(self.used_device)

        plls = []
        for seq in seqs:

            labels = self.tokenizer(
                seq, pad=True, w_extra_tkns=False, device=self.used_device
            )

            # Indices of residue (non-special) positions. ``labels`` carries a
            # leading batch dimension of 1, so each index row is (batch, pos).
            idxs = (~torch.isin(labels, special_tkns)).nonzero()

            # One copy of the sequence per residue, each copy with exactly
            # that residue replaced by the mask token.
            masked_tokens = labels.repeat(len(idxs), 1)
            for num, idx in enumerate(idxs):
                masked_tokens[num, idx[1]] = self.tokenizer.mask_token

            with torch.no_grad():
                logits = self.AbLang(masked_tokens)

            # Exclude special tokens from the prediction distribution.
            logits[:, :, self.tokenizer.all_special_tokens] = -float("inf")
            # Keep only the logits at each masked position.
            logits = torch.stack([logits[num, idx[1]] for num, idx in enumerate(idxs)])

            # True token ids at the scored positions (drop the batch dim).
            labels = labels[:, idxs[:, 1:]].squeeze(2)[0]

            nll = torch.nn.functional.cross_entropy(
                    logits,
                    labels,
                    reduction="mean",
                )

            plls.append(-nll)

        return torch.stack(plls, dim=0).cpu().numpy()

    def confidence(self, seqs, **kwargs):
        """
        Log likelihood of sequences without masking.

        Scores all sequences in a single forward pass; the per-sequence score
        is the negated mean cross-entropy over residue (non-special)
        positions. Returns a numpy array with one value per sequence.
        """
        labels = self.tokenizer(
                seqs, pad=True, w_extra_tkns=False, device=self.used_device
            )
        with torch.no_grad():
            logits = self.AbLang(labels)
            # Exclude special tokens from the prediction distribution.
            logits[:, :, self.tokenizer.all_special_tokens] = -float("inf")

        # Loop invariant hoisted out of the per-sequence loop.
        special_tkns = torch.Tensor(self.tokenizer.all_special_tokens).to(self.used_device)

        plls = []
        for label, logit in zip(labels, logits):

            # Residue (non-special) positions of this sequence.
            idxs = (~torch.isin(label, special_tkns)).nonzero().squeeze(1)

            nll = torch.nn.functional.cross_entropy(
                        logit[idxs],
                        label[idxs],
                        reduction="mean",
                    )

            plls.append(-nll)

        return torch.stack(plls, dim=0).cpu().numpy()