import torch
from transformers import (
    AutoModelForSequenceClassification,
    DistilBertForSequenceClassification,
)
class FineTunedDistilBertWithStringLabels(DistilBertForSequenceClassification):
    """DistilBERT sequence classifier whose outputs also carry string labels.

    Subclasses the concrete DistilBERT classification model rather than the
    ``AutoModelForSequenceClassification`` factory, which is only meant to be
    used through ``from_pretrained`` / ``from_config`` and cannot be
    instantiated or subclassed as a regular model.

    Attributes:
        label_dict: Mapping from predicted class index to its string label.
    """

    def __init__(self, config):
        """Build the underlying classifier and the index -> label mapping.

        Args:
            config: Model configuration forwarded unchanged to the parent
                constructor.
        """
        super().__init__(config)
        # Class index -> human-readable label attached to the outputs in forward().
        self.label_dict = {
            0: "Clarification",
            1: "Factual",
            2: "Operational",
            3: "Summarization",
        }

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        """Run the classifier and attach the predicted string label(s).

        Args:
            input_ids: Token ids passed to the parent ``forward``.
            attention_mask: Attention mask passed to the parent ``forward``.
            token_type_ids: Accepted so tokenizer outputs can be splatted into
                the call, but not forwarded: DistilBERT has no segment
                embeddings and its ``forward`` takes no such argument.
            labels: Optional target class indices passed to the parent
                ``forward`` (used there to compute the loss).

        Returns:
            The parent model's output object with an extra ``label``
            attribute: a single string when exactly one prediction was made,
            otherwise a list with one string per example in the batch.

        Raises:
            KeyError: If a predicted class index is not in ``label_dict``.
        """
        # Keyword arguments are required here: passed positionally,
        # token_type_ids and labels would fall into unrelated parameters of
        # the parent forward. return_dict=True guarantees `.logits` exists.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

        # argmax over the class dimension; flatten so a batch of any size
        # becomes a flat list of Python ints (.item() only works for one value).
        predicted_indices = torch.argmax(outputs.logits, dim=-1).reshape(-1).tolist()
        predicted_labels = [self.label_dict[index] for index in predicted_indices]

        # Keep the original single-example contract (a plain string) and
        # extend it to batches instead of crashing.
        if len(predicted_labels) == 1:
            outputs.label = predicted_labels[0]
        else:
            outputs.label = predicted_labels
        return outputs