Hristo-Karagyozov commited on
Commit
9434e7d
·
verified ·
1 Parent(s): b49d532

Upload subclassing_for_string_conversion.py

Browse files
subclassing_for_string_conversion.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSequenceClassification
2
+ import torch
3
+
4
+
5
+ class FineTunedDistilBertWithStringLabels(AutoModelForSequenceClassification):
6
+ def __init__(self, config):
7
+ super().__init__(config)
8
+ self.label_dict = {
9
+ 0: "Clarification",
10
+ 1: "Factual",
11
+ 2: "Operational",
12
+ 3: "Summarization"
13
+ }
14
+
15
+ def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
16
+ # Perform the usual forward pass
17
+ outputs = super().forward(input_ids, attention_mask, token_type_ids, labels)
18
+
19
+ # Get logits (raw model output)
20
+ logits = outputs.logits
21
+ predicted_class_index = torch.argmax(logits, dim=-1).item() # .item() -> extract regular python number
22
+
23
+ # Map index -> string label using dictionary
24
+ outputs.label = self.label_dict[predicted_class_index]
25
+ return outputs