"""FastAPI proxy in front of Hugging Face inference (Inference API / Endpoints).

Exposes POST /generate/ which tries, in order:
  1. SDK chat_completion with the default provider,
  2. SDK chat_completion pinned to the hf-inference provider,
  3. raw HTTP call to the /v1/chat/completions endpoint,
  4. SDK text_generation with a Mistral [INST] formatted prompt.
"""

import json
import os
import urllib.error
import urllib.request
from typing import Optional

from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from pydantic import BaseModel, Field
import uvicorn

app = FastAPI()

# Token from Settings → Secrets (required for Inference API / Inference Endpoints).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model: defaults to Ministral-3-14B-Instruct (multilingual, system prompt, Apache 2.0).
# When using Inference Endpoints, set HF_INFERENCE_ENDPOINT_URL to the endpoint URL.
MODEL_ID = os.environ.get("MODEL_ID", "mistralai/Ministral-3-14B-Instruct-2512")
INFERENCE_ENDPOINT_URL = os.environ.get("HF_INFERENCE_ENDPOINT_URL")

# With a dedicated Inference Endpoint (e.g. for Ministral-3-14B) use it for both
# clients; otherwise keep an auto-provider client plus an explicit hf-inference
# client used as a fallback.
if INFERENCE_ENDPOINT_URL:
    client = InferenceClient(model=INFERENCE_ENDPOINT_URL, token=HF_TOKEN)
    client_hf = client
else:
    client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)
    client_hf = InferenceClient(model=MODEL_ID, token=HF_TOKEN, provider="hf-inference")


def build_messages(system_prompt: str, history: list, prompt: str) -> list:
    """Build the messages list for chat_completion (conversational task).

    history items may be (user_msg, bot_msg) pairs or {"user": ..., "bot": ...}
    dicts; both shapes are accepted.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    for pair in history or []:
        user_msg = pair[0] if isinstance(pair, (list, tuple)) else pair.get("user", "")
        bot_msg = pair[1] if isinstance(pair, (list, tuple)) else pair.get("bot", "")
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": prompt})
    return messages


def format_prompt_for_text_generation(system_prompt: str, history: list, prompt: str) -> str:
    """Mistral [INST] format for text_generation when the provider has no chat task.

    NOTE(review): the system prompt is joined to the final user prompt with a
    comma rather than being placed in its own leading [INST] block — looks
    deliberate for this deployment; confirm before changing.
    """
    full_message = f"{system_prompt}, {prompt}" if system_prompt else prompt
    out = ""
    for pair in history or []:
        user_msg = pair[0] if isinstance(pair, (list, tuple)) else pair.get("user", "")
        bot_msg = pair[1] if isinstance(pair, (list, tuple)) else pair.get("bot", "")
        out += f"[INST] {user_msg} [/INST] {bot_msg} "
    out += f"[INST] {full_message} [/INST]"
    return out


def chat_completion_via_http(messages: list, max_tokens: int, temperature: float, top_p: float) -> Optional[str]:
    """Call the HF chat-completions endpoint (v1) directly over HTTP.

    Used when the SDK fails because the model does not declare the
    conversational task (e.g. Ministral-3).

    Returns the assistant message content, or None on any failure — the
    broad exception handling is intentional: this is a best-effort fallback
    and the caller continues down its cascade on None.

    Raises:
        HTTPException(503) when the model is still loading (so the client
        can retry after ``estimated_time``).
    """
    if not HF_TOKEN:
        return None
    base = (
        INFERENCE_ENDPOINT_URL.rstrip("/")
        if INFERENCE_ENDPOINT_URL
        else f"https://api-inference.huggingface.co/models/{MODEL_ID}"
    )
    url = f"{base}/v1/chat/completions"
    body = {
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    data = json.dumps(body).encode("utf-8")
    req = urllib.request.Request(
        url,
        data=data,
        headers={
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "application/json",
        },
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=120) as resp:
            out = json.loads(resp.read().decode())
    except urllib.error.HTTPError as e:
        # 503 with an "estimated_time" field means the model is cold-loading:
        # surface that to the client instead of silently falling through.
        if e.code == 503:
            err_body = e.read().decode() if e.fp else ""
            try:
                err_json = json.loads(err_body)
                if "estimated_time" in err_json:
                    raise HTTPException(
                        status_code=503,
                        detail=f"Modello in caricamento. Riprova tra {err_json.get('estimated_time', 0):.0f}s.",
                    )
            except (ValueError, TypeError):
                pass
        return None
    except Exception:
        return None
    choices = out.get("choices") or []
    if not choices:
        return None
    msg = choices[0].get("message") or {}
    return (msg.get("content") or "").strip()


class Item(BaseModel):
    """Payload for /generate/.

    For Ministral-3: temperature < 0.1 is recommended in production.
    """

    prompt: str
    history: list = Field(default_factory=list, description="Lista di coppie (user_msg, bot_msg)")
    system_prompt: str = ""
    temperature: float = Field(default=0.0, description="Ministral-3: < 0.1 per produzione")
    # NOTE(review): 1048 looks like a typo for 1024, but the default is part of
    # the API contract — confirm with callers before changing.
    max_new_tokens: int = 1048
    top_p: float = 0.15
    repetition_penalty: float = 1.0


def generate(item: Item) -> str:
    """Run inference for *item*, cascading through four strategies.

    Order: SDK chat (default provider) → SDK chat (hf-inference) → raw HTTP
    chat-completions → text_generation with [INST] formatting. The first
    strategy that yields a non-empty result wins; if all fail, a 502 is
    raised carrying the last error.

    Raises:
        HTTPException(500) when HF_TOKEN is missing,
        HTTPException(503) propagated from the HTTP fallback (model loading),
        HTTPException(502) when every strategy failed.
    """
    if not HF_TOKEN:
        raise HTTPException(
            status_code=500,
            detail="HF_TOKEN non configurato. Imposta il secret HF_TOKEN in Settings → Secrets dello Space.",
        )

    # Some providers reject temperature == 0; clamp to a small positive value.
    temperature = float(item.temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(item.top_p)

    messages = build_messages(item.system_prompt, item.history or [], item.prompt)
    last_error = None

    # 1) Try chat_completion with the default provider (e.g. Together).
    try:
        response = client.chat_completion(
            messages=messages,
            max_tokens=item.max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        if response.choices and len(response.choices) > 0:
            return response.choices[0].message.content or ""
        return ""
    except Exception as e1:
        last_error = e1

    # 2) Retry chat with hf-inference (supports models like Mixtral-8x7B that
    #    other providers do not expose as "chat").
    try:
        response = client_hf.chat_completion(
            messages=messages,
            max_tokens=item.max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        if response.choices and len(response.choices) > 0:
            return response.choices[0].message.content or ""
        return ""
    except Exception as e2:
        last_error = e2

    # 3) Chat completions via raw HTTP (v1 endpoint) — works for models that
    #    do not declare the task (e.g. Ministral-3).
    try:
        content = chat_completion_via_http(
            messages, item.max_new_tokens, temperature, top_p
        )
        if content is not None and content != "":
            return content
    except HTTPException:
        raise  # 503 "model loading" must reach the client unchanged
    except Exception as e3:
        last_error = e3

    # 4) Last resort: text_generation (only for models supporting it on
    #    hf-inference). repetition_penalty is only honored on this path —
    #    the chat APIs above do not accept it.
    try:
        formatted = format_prompt_for_text_generation(
            item.system_prompt, item.history or [], item.prompt
        )
        stream = client_hf.text_generation(
            formatted,
            max_new_tokens=item.max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=float(item.repetition_penalty),
            do_sample=True,
            return_full_text=False,
        )
        if isinstance(stream, str):
            return stream
        if hasattr(stream, "__iter__") and not isinstance(stream, str):
            # Streamed responses yield token objects (or plain strings).
            return "".join(
                r.token.text
                if hasattr(r, "token") and hasattr(r.token, "text")
                else (r if isinstance(r, str) else "")
                for r in stream
            )
        return str(stream)
    except Exception as e4:
        last_error = e4

    raise HTTPException(status_code=502, detail=f"Inference fallita: {str(last_error)}")


@app.get("/")
def root():
    """Liveness endpoint reporting the configured model id."""
    return {"status": "ok", "model": MODEL_ID}


@app.get("/health")
def health():
    """Plain health-check endpoint."""
    return {"status": "ok"}


@app.post("/generate/")
async def generate_text(item: Item):
    """Generate a completion for the given payload."""
    return {"response": generate(item)}


if __name__ == "__main__":
    # Allow running the file directly (e.g. in a HF Space container).
    # 7860 is the conventional Spaces port; override with PORT.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))