import torch
from transformers import AutoModelForSequenceClassification, DistilBertForSequenceClassification


class FineTunedDistilBertWithStringLabels(DistilBertForSequenceClassification):
    """DistilBERT sequence classifier that also reports its prediction as a string.

    The class derives from the concrete ``DistilBertForSequenceClassification``.
    ``AutoModelForSequenceClassification`` is only a factory: calling its
    ``__init__`` raises, and it has no model ``forward`` of its own.

    NOTE(review): assumes the checkpoint's classification head has exactly
    ``len(LABEL_DICT)`` (4) outputs; confirm ``config.num_labels`` matches.
    """

    # Class index -> human-readable label for the fine-tuned head.
    LABEL_DICT = {
        0: "Clarification",
        1: "Factual",
        2: "Operational",
        3: "Summarization",
    }

    def __init__(self, config):
        super().__init__(config)
        # Per-instance copy so callers may customise it without touching the class.
        self.label_dict = dict(self.LABEL_DICT)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
        """Run the classifier and attach string predictions to the output.

        Args:
            input_ids: Token ids, forwarded to the DistilBERT model.
            attention_mask: Attention mask, forwarded to the DistilBERT model.
            token_type_ids: Accepted for signature compatibility only.
                DistilBERT has no segment embeddings, so it is ignored.
            labels: Optional targets; when given, the parent computes ``loss``.
            **kwargs: Extra keyword arguments forwarded to the parent
                ``forward`` (e.g. ``head_mask``, ``output_attentions``).

        Returns:
            The parent's ``SequenceClassifierOutput``, with two extra attributes:
            ``predicted_labels`` is a list holding one string per example, and
            ``label`` is that string alone when the batch has one example,
            otherwise the same list.
        """
        # We read ``outputs.logits`` below, so a tuple return is not acceptable.
        kwargs["return_dict"] = True
        # Keyword arguments matter: DistilBERT's positional order is
        # (input_ids, attention_mask, head_mask, inputs_embeds, labels, ...).
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs,
        )

        predicted = torch.argmax(outputs.logits, dim=-1)
        # Flatten and convert to Python ints so every batch size is handled
        # (the original ``.item()`` only worked for a single example).
        names = [self.label_dict[int(idx)] for idx in predicted.reshape(-1).tolist()]

        outputs.predicted_labels = names
        # Keep the original single-string contract for one-example batches.
        outputs.label = names[0] if len(names) == 1 else names
        return outputs