Clemylia commited on
Commit
9b4c36a
·
verified ·
1 Parent(s): 662a0e8

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +88 -0
handler.py CHANGED
@@ -1,2 +1,90 @@
1
 
2
  # (copier-coller le contenu du handler.py ici)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
  # (copier-coller le contenu du handler.py ici)
3
+ # handler.py
4
+
5
+ import json
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers.utils import is_torch_available
9
+
10
+ # On va utiliser le tokenizer et le modèle que nous avons créés
11
+ def simple_tokenizer(text):
12
+ return text.lower().split()
13
+
14
+ class SimpleClassifier(nn.Module):
15
+ def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
16
+ super().__init__()
17
+ self.embedding = nn.Embedding(vocab_size, embedding_dim)
18
+ self.lstm = nn.LSTM(embedding_dim, hidden_dim)
19
+ self.fc = nn.Linear(hidden_dim, output_dim)
20
+
21
+ def forward(self, text):
22
+ embedded = self.embedding(text)
23
+ _, (hidden, _) = self.lstm(embedded.view(len(text), 1, -1))
24
+ output = self.fc(hidden.squeeze(0))
25
+ return output
26
+
27
+ class InferenceHandler:
28
+ def __init__(self):
29
+ self.initialized = False
30
+ self.word_to_idx = None
31
+ self.model = None
32
+
33
+ def initialize(self, context):
34
+ # Cette fonction est appelée une seule fois pour charger le modèle
35
+ # On charge le vocabulaire
36
+ vocab_path = "vocab.json"
37
+ with open(vocab_path, "r") as f:
38
+ self.word_to_idx = json.load(f)
39
+
40
+ # On charge la configuration du modèle
41
+ config_path = "config.json"
42
+ with open(config_path, "r") as f:
43
+ config = json.load(f)
44
+
45
+ # On crée le modèle
46
+ self.model = SimpleClassifier(
47
+ vocab_size=config['vocab_size'],
48
+ embedding_dim=config['embedding_dim'],
49
+ hidden_dim=config['hidden_dim'],
50
+ output_dim=config['output_dim']
51
+ )
52
+
53
+ # On charge les poids entraînés
54
+ model_path = "pytorch_model.bin"
55
+ self.model.load_state_dict(torch.load(model_path))
56
+
57
+ # On met le modèle en mode évaluation
58
+ self.model.eval()
59
+ self.initialized = True
60
+
61
+ def preprocess(self, inputs):
62
+ # Cette fonction traite les données d'entrée avant l'inférence
63
+ # 'inputs' est le dictionnaire envoyé par l'API
64
+ text = inputs.get("inputs", "")
65
+ if not text:
66
+ raise ValueError("Aucun texte fourni pour l'inférence.")
67
+
68
+ # Tokenisation
69
+ tokens = simple_tokenizer(text)
70
+ token_indices = [self.word_to_idx.get(token, 0) for token in tokens]
71
+
72
+ # Création du tenseur
73
+ input_tensor = torch.tensor(token_indices, dtype=torch.long)
74
+
75
+ return input_tensor.view(-1, 1)
76
+
77
+ def inference(self, input_tensor):
78
+ # Cette fonction fait la prédiction
79
+ with torch.no_grad():
80
+ outputs = self.model(input_tensor)
81
+ return outputs
82
+
83
+ def postprocess(self, outputs):
84
+ # Cette fonction convertit la sortie du modèle en un format lisible
85
+ prediction = torch.argmax(outputs, dim=1).item()
86
+
87
+ labels = {0: "Animaux", 1: "Capitales"}
88
+ predicted_label = labels.get(prediction, "Inconnu")
89
+
90
+ return [{"label": predicted_label, "score": outputs.softmax(dim=1)[0][prediction].item()}]