Spaces:
Running
Running
# File: core_logic.py
# Core logic for the Charlotte API: SQLite-backed API-key/quota management
# and single-instance model loading/inference for the FastAPI server.
import sqlite3
import datetime
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import uuid  # needed to generate unique API-key IDs

# --- GLOBAL CONFIGURATION ---
DB_NAME = "charlotte_apy.db"            # SQLite database file for API keys
MODEL_NAME = "Clemylia/Tiny-charlotte"  # Hugging Face model identifier
MAX_QUOTA = 1200                        # default daily token quota per key
MAX_TOKENS_PER_RESPONSE = 100           # max new tokens generated per request
# ----------------------------------------------------
# A. SQLite DATABASE LOGIC (reusable functions)
# ----------------------------------------------------
def init_db():
    """Create the api_keys table if it does not exist yet.

    Idempotent (CREATE TABLE IF NOT EXISTS); safe to call on every import.
    The connection is now closed even if the DDL statement raises.
    """
    conn = sqlite3.connect(DB_NAME)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS api_keys (
                key_id TEXT PRIMARY KEY,
                quota_remaining INTEGER,
                max_quota INTEGER,
                date_last_use TEXT
            )
        """)
        conn.commit()
    finally:
        conn.close()
def get_all_keys():
    """Return every API key as {key_id: {quota_remaining, max_quota, date_last_use}}.

    The connection is closed via try/finally even if the query raises.
    """
    conn = sqlite3.connect(DB_NAME)
    try:
        rows = conn.execute(
            "SELECT key_id, quota_remaining, max_quota, date_last_use FROM api_keys"
        ).fetchall()
    finally:
        conn.close()
    return {
        key_id: {
            'quota_remaining': quota_remaining,
            'max_quota': max_quota,
            'date_last_use': date_last_use,
        }
        for key_id, quota_remaining, max_quota, date_last_use in rows
    }
def get_key_data(key_id):
    """Return the quota record for one key, or None if the key is unknown.

    Result shape: {'quota_remaining': int, 'max_quota': int, 'date_last_use': str}.
    The connection is closed via try/finally even if the query raises.
    """
    conn = sqlite3.connect(DB_NAME)
    try:
        row = conn.execute(
            "SELECT quota_remaining, max_quota, date_last_use FROM api_keys WHERE key_id = ?",
            (key_id,),
        ).fetchone()
    finally:
        conn.close()
    if row is None:
        return None
    return {'quota_remaining': row[0], 'max_quota': row[1], 'date_last_use': row[2]}
def add_key_to_db(key_id, max_quota=MAX_QUOTA):
    """Insert a new API key with a full quota dated today.

    Returns True on success, False if the key already exists
    (sqlite3.IntegrityError on the PRIMARY KEY). The connection is now
    closed in a finally clause, so it no longer leaks when an unexpected
    exception (anything other than IntegrityError) escapes.
    """
    conn = sqlite3.connect(DB_NAME)
    try:
        today = datetime.date.today().isoformat()
        conn.execute(
            "INSERT INTO api_keys (key_id, quota_remaining, max_quota, date_last_use) VALUES (?, ?, ?, ?)",
            (key_id, max_quota, max_quota, today),
        )
        conn.commit()
        return True
    except sqlite3.IntegrityError:
        return False
    finally:
        conn.close()
# 🔑 Creates and registers a unique API key.
def create_api_key(max_quota=MAX_QUOTA):
    """Generate a unique API key ("Tn-charlotte-…") and register it in the DB.

    BUGFIX: the 20-character uppercase hex suffix occasionally contains fewer
    than the 5 digits required by validate_key(), in which case the old code
    silently returned None. We now regenerate until the key passes validation
    (the prefix already supplies the 7-letter minimum, so only the digit count
    can fail, and only with vanishingly small probability per attempt).

    Returns the key string, or None if the DB insert fails (duplicate key).
    """
    while True:
        api_key = f"Tn-charlotte-{uuid.uuid4().hex[:20].upper()}"
        is_valid, _ = validate_key(api_key)
        if is_valid:
            break
    # Register in the DB; None signals an insertion failure.
    return api_key if add_key_to_db(api_key, max_quota) else None
def delete_key_from_db(key_id):
    """Delete a key from the database (no-op if the key does not exist).

    The connection is closed via try/finally even if the DELETE raises.
    """
    conn = sqlite3.connect(DB_NAME)
    try:
        conn.execute("DELETE FROM api_keys WHERE key_id = ?", (key_id,))
        conn.commit()
    finally:
        conn.close()
def update_key_quota_in_db(key_id, new_remaining_quota, new_date_last_use):
    """Persist a key's remaining quota and last-use date.

    The connection is closed via try/finally even if the UPDATE raises.
    """
    conn = sqlite3.connect(DB_NAME)
    try:
        conn.execute(
            "UPDATE api_keys SET quota_remaining = ?, date_last_use = ? WHERE key_id = ?",
            (new_remaining_quota, new_date_last_use, key_id),
        )
        conn.commit()
    finally:
        conn.close()
def reset_key_quota_in_db(key_id):
    """Restore a key's remaining quota to its max_quota, dated today.

    The connection is closed via try/finally even if the UPDATE raises.
    """
    conn = sqlite3.connect(DB_NAME)
    try:
        today = datetime.date.today().isoformat()
        conn.execute(
            "UPDATE api_keys SET quota_remaining = max_quota, date_last_use = ? WHERE key_id = ?",
            (today, key_id),
        )
        conn.commit()
    finally:
        conn.close()
def validate_key(key_str):
    """Check that an API key matches the expected format.

    Rules: must start with "Tn-charlotte", contain at least 5 digits and at
    least 7 ASCII letters. Returns a (bool, message) pair; the message
    explains the first failed rule, or confirms validity.
    """
    if not key_str.startswith("Tn-charlotte"):
        return False, "La clé doit commencer par Tn-charlotte."

    digit_count = sum(1 for _ in re.finditer(r'\d', key_str))
    if digit_count < 5:
        return False, f"La clé doit contenir au moins 5 chiffres (actuel : {digit_count})."

    letter_count = sum(1 for _ in re.finditer(r'[a-zA-Z]', key_str))
    if letter_count < 7:
        return False, f"La clé doit contenir au moins 7 lettres (actuel : {letter_count})."

    return True, "Clé valide !"
# Initial call: make sure the schema exists as soon as this module is imported.
init_db()
# ----------------------------------------------------
# B. MODEL LOADING (loaded once for the server)
# ----------------------------------------------------
# These globals are populated by load_tiny_charlotte_server() and read by the
# FastAPI server; they stay None until the model is successfully loaded.
MODEL, TOKENIZER, DEVICE = None, None, None
def load_tiny_charlotte_server():
    """Load the Tiny-Charlotte model exactly once for the API server.

    Populates the MODEL, TOKENIZER and DEVICE globals; later calls are no-ops.
    On failure the error is printed and the globals stay None (the inference
    path then answers 500).
    """
    global MODEL, TOKENIZER, DEVICE
    if MODEL is not None:
        return  # already loaded — nothing to do
    try:
        print("INFO: Chargement du modèle Tiny-Charlotte...")
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
        MODEL = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
        print(f"INFO: Modèle chargé avec succès sur {DEVICE}.")
    except Exception as e:
        print(f"ERREUR: Échec du chargement du modèle: {e}")
# The FastAPI server calls this function at startup.
| # ---------------------------------------------------- | |
| # C. LOGIQUE D'INFÉRENCE POUR L'API (Retourne données et code statut) | |
| # ---------------------------------------------------- | |
def run_inference_api(api_key, prompt):
    """Run one generation request for the API server, enforcing the key's quota.

    Returns a (json_payload, http_status) tuple:
      * 500 if the model is not loaded or inference fails,
      * 401 for an unknown API key,
      * 429 when the remaining daily quota cannot cover a full response,
      * 200 with the generated text and usage statistics on success.

    BUGFIX: the "limit" field in the usage payload previously reported the
    global default MAX_QUOTA; it now reports this key's own max_quota, which
    matters for keys created with a custom quota.
    """
    if MODEL is None or TOKENIZER is None:
        return {"error": "Internal Server Error: Model not loaded."}, 500

    today = datetime.date.today().isoformat()
    key_data = get_key_data(api_key)
    if key_data is None:
        # 401 Unauthorized: unknown key
        return {"error": "Unauthorized: Invalid API Key."}, 401

    # Automatic daily quota reset: first use on a new day restores the full quota.
    if key_data['date_last_use'] != today:
        key_data['quota_remaining'] = key_data['max_quota']
        key_data['date_last_use'] = today
        update_key_quota_in_db(api_key, key_data['quota_remaining'], today)

    # Quota check: refuse if a full-length response could exceed what is left.
    if key_data['quota_remaining'] < MAX_TOKENS_PER_RESPONSE:
        # 429 Too Many Requests: quota reached
        return {
            "error": "Quota Exceeded: Daily token limit reached.",
            "usage": {
                "tokens_remaining": key_data['quota_remaining'],
                "limit": key_data['max_quota'],
            },
        }, 429

    try:
        # Encode the prompt and generate up to MAX_TOKENS_PER_RESPONSE new tokens.
        input_ids = TOKENIZER.encode(prompt, return_tensors='pt').to(DEVICE)
        output = MODEL.generate(
            input_ids,
            max_length=input_ids.shape[1] + MAX_TOKENS_PER_RESPONSE,
            do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1,
            pad_token_id=TOKENIZER.eos_token_id,
        )
        response_text = TOKENIZER.decode(output[0], skip_special_tokens=True)
        tokens_generated = output.shape[1] - input_ids.shape[1]
    except Exception as e:
        print(f"ERREUR d'inférence: {e}")
        return {"error": "Internal Server Error: Inference failed."}, 500

    # Persist the deduction so subsequent calls see the updated quota.
    new_remaining = key_data['quota_remaining'] - tokens_generated
    update_key_quota_in_db(api_key, new_remaining, today)

    # 200 OK
    return {
        "generated_text": response_text,
        "model": MODEL_NAME,
        "usage": {
            "tokens_used": tokens_generated,
            "tokens_remaining": new_remaining,
            "limit": key_data['max_quota'],
        },
    }, 200
# ----------------------------------------------------
# D. INFERENCE LOGIC FOR STREAMLIT (requires the Streamlit tooling)
# ----------------------------------------------------
# (Placeholder: the Streamlit-specific logic lives in charlotte_apy.py.)