import torch
from transformers import DistilBertForSequenceClassification

# Maps the classifier's output index to a human-readable label name.
label_dict = {0: "Clarification", 1: "Factual", 2: "Operational", 3: "Summarization"}


class CustomDistilBertClassifier(DistilBertForSequenceClassification):
    """DistilBERT sequence classifier whose ``forward`` returns label names.

    Unlike the parent class, calling the model does not return a model-output
    object. It returns the predicted label name(s) looked up in ``label_map``.
    That is convenient for inference, but because no loss or logits are
    returned, the model cannot be trained with loss-based loops as-is.

    NOTE(review): ``label_dict`` has 4 entries; this assumes the loaded config
    has ``num_labels == 4`` -- confirm against the checkpoint being loaded.
    """

    def __init__(self, config):
        super().__init__(config)
        # Copy so that mutating one instance's mapping does not alter the
        # module-level ``label_dict`` (or any other instance's mapping).
        self.label_map = dict(label_dict)

    def forward(self, input_ids, attention_mask=None):
        """Predict label names for the given token ids.

        Args:
            input_ids: token-id tensor forwarded to the parent classifier.
            attention_mask: optional mask tensor forwarded to the parent
                classifier; ``None`` is passed through unchanged.

        Returns:
            A single label string when exactly one prediction is produced
            (the original behaviour). Otherwise, a list of label strings with
            one entry per prediction, in batch order.

        Raises:
            KeyError: if a predicted index has no entry in ``label_map``
                (e.g. the config has more labels than the mapping).
        """
        # return_dict=True guarantees ``.logits`` exists even when the config
        # sets ``use_return_dict`` to False (which would yield a plain tuple).
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        predicted = torch.argmax(outputs.logits, dim=-1)
        if predicted.numel() == 1:
            return self.label_map[predicted.item()]
        # ``.item()`` only works on single-element tensors, so decode each
        # prediction individually for larger batches.
        return [self.label_map[idx] for idx in predicted.reshape(-1).tolist()]

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a checkpoint via the parent implementation.

        This is a thin pass-through. ``super()`` inside a classmethod binds
        to ``cls``, so the returned object is a ``CustomDistilBertClassifier``.
        """
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)