import torch
from transformers import DistilBertForSequenceClassification

# Map class indices to human-readable question-type labels
label_dict = {0: "Clarification", 1: "Factual", 2: "Operational", 3: "Summarization"}


class CustomDistilBertClassifier(DistilBertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.label_map = label_dict  # Use your predefined label mapping

    def forward(self, input_ids, attention_mask):
        outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        # .item() assumes a single example per call (batch size 1)
        predicted_class = torch.argmax(logits, dim=-1).item()
        return self.label_map[predicted_class]  # Return the string label directly

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Pass-through override; the parent classmethod already handles loading weights
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        return model
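For context, here is a minimal usage sketch. It assumes a fine-tuned 4-label checkpoint saved with save_pretrained() at a hypothetical directory ("path/to/fine-tuned-checkpoint") that also contains the tokenizer files; note that because forward() calls .item(), it expects a single example at a time.

from transformers import DistilBertTokenizerFast

# Hypothetical path to your fine-tuned checkpoint directory
model_path = "path/to/fine-tuned-checkpoint"

tokenizer = DistilBertTokenizerFast.from_pretrained(model_path)
model = CustomDistilBertClassifier.from_pretrained(model_path, num_labels=4)
model.eval()

encoded = tokenizer("How do I reset my password?", return_tensors="pt")
with torch.no_grad():
    label = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
print(label)  # e.g. "Operational"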