distilbert-another-classifier / subclassing_for_string_conversion.py
Hristo-Karagyozov's picture
Upload subclassing_for_string_conversion.py
9434e7d verified
import torch
from transformers import AutoModelForSequenceClassification
from transformers import DistilBertForSequenceClassification
class FineTunedDistilBertWithStringLabels(DistilBertForSequenceClassification):
    """DistilBERT sequence classifier that also reports predictions as string labels.

    This subclasses the concrete ``DistilBertForSequenceClassification``, not
    ``AutoModelForSequenceClassification``. The Auto classes raise when
    ``__init__`` is called directly. Their ``from_pretrained``/``from_config``
    return an instance of the mapped concrete class, so a subclass of an Auto
    class never has its own ``__init__``/``forward`` executed.

    Args:
        config: Model configuration, passed unchanged to the parent class.
    """

    def __init__(self, config):
        super().__init__(config)
        # Class index -> human-readable name.
        # NOTE(review): assumes the checkpoint was trained with exactly these
        # four classes in this order; confirm against config.num_labels /
        # config.id2label.
        self.label_dict = {
            0: "Clarification",
            1: "Factual",
            2: "Operational",
            3: "Summarization",
        }

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        """Run the classifier and attach the predicted label name(s) to the output.

        Args:
            input_ids: Token ids, forwarded to the parent ``forward``.
            attention_mask: Attention mask, forwarded to the parent ``forward``.
            token_type_ids: Accepted for backward compatibility and ignored,
                because DistilBERT's ``forward`` has no such parameter.
            labels: Optional targets. When given, the parent also computes
                ``loss``.
            **kwargs: Any other keyword argument accepted by the parent
                ``forward`` (e.g. ``head_mask``, ``inputs_embeds``).
                ``return_dict`` is always forced to ``True``.

        Returns:
            The parent's output object with an extra ``label`` attribute. It
            is a single ``str`` when the batch holds one example, and a
            ``list[str]`` (one name per example) otherwise.

        Raises:
            KeyError: If a predicted class index is not in ``label_dict``.
        """
        # token_type_ids is deliberately not forwarded (see docstring).
        # The label is attached as an attribute of the output object, so a
        # tuple-style output (return_dict=False) cannot carry it.
        kwargs["return_dict"] = True
        # Pass by keyword: DistilBERT's positional order is
        # (input_ids, attention_mask, head_mask, inputs_embeds, labels, ...),
        # so positional passing would misroute labels/token_type_ids.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs,
        )
        # One predicted class index per example. Flatten so .tolist() always
        # yields a list (a 0-d tensor would give a bare int).
        predicted = torch.argmax(outputs.logits, dim=-1).reshape(-1).tolist()
        names = [self.label_dict[index] for index in predicted]
        # Keep the original single-example contract (a plain string).
        # Batches, which previously crashed on .item(), get a list.
        outputs.label = names[0] if len(names) == 1 else names
        return outputs