Limy-basique / handler.py
Clemylia's picture
Ajout du handler et des requirements pour l'API d'inférence
48885c1 verified
# (copier-coller le contenu du handler.py ici)
# handler.py
import json
import torch
import torch.nn as nn
from transformers.utils import is_torch_available
# On va utiliser le tokenizer et le modèle que nous avons créés
def simple_tokenizer(text):
return text.lower().split()
class SimpleClassifier(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.lstm = nn.LSTM(embedding_dim, hidden_dim)
self.fc = nn.Linear(hidden_dim, output_dim)
def forward(self, text):
embedded = self.embedding(text)
_, (hidden, _) = self.lstm(embedded.view(len(text), 1, -1))
output = self.fc(hidden.squeeze(0))
return output
class InferenceHandler:
def __init__(self):
self.initialized = False
self.word_to_idx = None
self.model = None
def initialize(self, context):
# Cette fonction est appelée une seule fois pour charger le modèle
# On charge le vocabulaire
vocab_path = "vocab.json"
with open(vocab_path, "r") as f:
self.word_to_idx = json.load(f)
# On charge la configuration du modèle
config_path = "config.json"
with open(config_path, "r") as f:
config = json.load(f)
# On crée le modèle
self.model = SimpleClassifier(
vocab_size=config['vocab_size'],
embedding_dim=config['embedding_dim'],
hidden_dim=config['hidden_dim'],
output_dim=config['output_dim']
)
# On charge les poids entraînés
model_path = "pytorch_model.bin"
self.model.load_state_dict(torch.load(model_path))
# On met le modèle en mode évaluation
self.model.eval()
self.initialized = True
def preprocess(self, inputs):
# Cette fonction traite les données d'entrée avant l'inférence
# 'inputs' est le dictionnaire envoyé par l'API
text = inputs.get("inputs", "")
if not text:
raise ValueError("Aucun texte fourni pour l'inférence.")
# Tokenisation
tokens = simple_tokenizer(text)
token_indices = [self.word_to_idx.get(token, 0) for token in tokens]
# Création du tenseur
input_tensor = torch.tensor(token_indices, dtype=torch.long)
return input_tensor.view(-1, 1)
def inference(self, input_tensor):
# Cette fonction fait la prédiction
with torch.no_grad():
outputs = self.model(input_tensor)
return outputs
def postprocess(self, outputs):
# Cette fonction convertit la sortie du modèle en un format lisible
prediction = torch.argmax(outputs, dim=1).item()
labels = {0: "Animaux", 1: "Capitales"}
predicted_label = labels.get(prediction, "Inconnu")
return [{"label": predicted_label, "score": outputs.softmax(dim=1)[0][prediction].item()}]