File size: 903 Bytes
6ae42f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from transformers import DistilBertForSequenceClassification

# Maps the classifier's integer output index to its human-readable label name.
label_dict = dict(
    enumerate(("Clarification", "Factual", "Operational", "Summarization"))
)


class CustomDistilBertClassifier(DistilBertForSequenceClassification):
    """DistilBERT sequence classifier whose forward pass returns label names.

    Wraps ``DistilBertForSequenceClassification`` so that calling the model
    yields the predicted label string(s) looked up in ``label_dict`` instead
    of the parent's output object with logits. Because the result is a
    string (not a tensor), this class is meant for inference only; it cannot
    be trained through this ``forward``.

    ``from_pretrained`` is inherited unchanged from the parent class.
    """

    def __init__(self, config):
        super().__init__(config)
        # Index -> label-name mapping used to decode the argmax of the logits.
        self.label_map = label_dict

    def forward(self, input_ids, attention_mask=None):
        """Classify token sequences and return the predicted label name(s).

        Args:
            input_ids: Token-id tensor forwarded to the parent model;
                presumably shaped ``(batch, seq_len)`` — TODO confirm.
            attention_mask: Optional mask forwarded to the parent model.
                ``None`` defers to the parent model's default handling.

        Returns:
            A single label string when the input holds one example
            (matching the original contract), otherwise a list with one
            label string per example, in input order.

        Raises:
            KeyError: If a predicted class index has no entry in
                ``label_map`` (e.g. ``config.num_labels`` exceeds its size).
        """
        # The output is a discrete label, so no gradient can ever flow back;
        # skip building the autograd graph to save memory.
        with torch.no_grad():
            logits = super().forward(
                input_ids=input_ids, attention_mask=attention_mask
            ).logits
        # ``.tolist()`` instead of ``.item()`` so batches larger than one
        # don't raise; yields plain Python ints usable as dict keys.
        predicted = torch.argmax(logits, dim=-1).flatten().tolist()
        labels = [self.label_map[idx] for idx in predicted]
        # Preserve the single-example behaviour: return the bare string.
        return labels[0] if len(labels) == 1 else labels