# distilbert-another-classifier / modeling_custom.py
# Author: Hristo-Karagyozov
# Commit 6ae42f6 (verified): "Upload folder using huggingface_hub"
import torch
from transformers import DistilBertForSequenceClassification
# Class index (position of the winning logit) -> human-readable label name.
# Indices are contiguous from 0, so the mapping is built from the ordered names.
label_dict = dict(
    enumerate(
        (
            "Clarification",
            "Factual",
            "Operational",
            "Summarization",
        )
    )
)
class CustomDistilBertClassifier(DistilBertForSequenceClassification):
    """DistilBERT sequence classifier whose ``forward`` returns label names.

    Instead of the usual model-output object, ``forward`` takes the argmax of
    the classification logits and maps each predicted class index through
    ``label_map`` (a copy of the module-level ``label_dict``).

    NOTE(review): because ``forward`` returns strings rather than logits or a
    loss, this class is usable for inference only, not with a loss-based
    training loop.
    """

    def __init__(self, config):
        super().__init__(config)
        # Copy so that mutating one instance's mapping cannot silently alter
        # the module-level label_dict (and every other instance with it).
        self.label_map = dict(label_dict)

    def forward(self, input_ids, attention_mask=None):
        """Classify token sequences and return the predicted label name(s).

        Args:
            input_ids: Token-id tensor forwarded unchanged to the DistilBERT
                base model (presumably ``(batch, seq_len)`` tokenizer output;
                confirm against callers).
            attention_mask: Optional mask forwarded unchanged to the base
                model; ``None`` (the base model's own default) means no mask.

        Returns:
            A single label string when exactly one example is classified
            (the original behavior); otherwise a list of label strings, one
            per example, in batch order.

        Raises:
            KeyError: if the model predicts a class index missing from
                ``label_map`` (presumably when ``config.num_labels`` > 4).
        """
        # The output is a discrete label, so no gradient can flow through it;
        # skip building the autograd graph.
        with torch.no_grad():
            outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask)
        predicted = torch.argmax(outputs.logits, dim=-1)
        # .item() only works on a one-element tensor; the original raised a
        # RuntimeError for any batch with more than one example.
        if predicted.numel() == 1:
            return self.label_map[predicted.item()]
        return [self.label_map[idx] for idx in predicted.reshape(-1).tolist()]