# Hugging Face Space: FastAPI proxy exposing chat inference for a HF-hosted model.
import json
import os
import urllib.error
import urllib.request
from typing import Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from pydantic import BaseModel, Field
# FastAPI application exposing the inference endpoints defined below.
app = FastAPI()

# Token from Settings → Secrets (required for the Inference API / Inference Endpoints).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model: default Ministral-3-14B-Instruct (multilingual, system prompt, Apache 2.0).
# When using Inference Endpoints, set HF_INFERENCE_ENDPOINT_URL to the endpoint URL.
MODEL_ID = os.environ.get("MODEL_ID", "mistralai/Ministral-3-14B-Instruct-2512")
INFERENCE_ENDPOINT_URL = os.environ.get("HF_INFERENCE_ENDPOINT_URL")

# With a dedicated Inference Endpoint, use it for both clients; otherwise keep a
# default-provider client plus an explicit "hf-inference" client as fallback.
if INFERENCE_ENDPOINT_URL:
    client = InferenceClient(model=INFERENCE_ENDPOINT_URL, token=HF_TOKEN)
    client_hf = client
else:
    client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)
    client_hf = InferenceClient(model=MODEL_ID, token=HF_TOKEN, provider="hf-inference")
def build_messages(system_prompt: str, history: list, prompt: str) -> list:
    """Build the ``messages`` list for chat_completion (conversational task).

    ``history`` holds prior turns, each either a ``(user_msg, bot_msg)``
    pair or a dict with ``"user"``/``"bot"`` keys; ``None`` or an empty
    history is accepted. The current ``prompt`` is appended as the final
    user turn.
    """
    msgs = [{"role": "system", "content": system_prompt}] if system_prompt else []
    for turn in history or []:
        if isinstance(turn, (list, tuple)):
            user_text, bot_text = turn[0], turn[1]
        else:
            user_text = turn.get("user", "")
            bot_text = turn.get("bot", "")
        msgs.append({"role": "user", "content": user_text})
        msgs.append({"role": "assistant", "content": bot_text})
    msgs.append({"role": "user", "content": prompt})
    return msgs
def format_prompt_for_text_generation(system_prompt: str, history: list, prompt: str) -> str:
    """Mistral ``[INST]`` format for text_generation when the provider has no chat task.

    NOTE(review): the system prompt is joined to the final user message with a
    comma rather than occupying a dedicated slot — preserved as-is.
    """
    final_user = f"{system_prompt}, {prompt}" if system_prompt else prompt
    parts = ["<s>"]
    for turn in history or []:
        if isinstance(turn, (list, tuple)):
            user_text, bot_text = turn[0], turn[1]
        else:
            user_text = turn.get("user", "")
            bot_text = turn.get("bot", "")
        parts.append(f"[INST] {user_text} [/INST] {bot_text}</s> ")
    parts.append(f"[INST] {final_user} [/INST]")
    return "".join(parts)
def chat_completion_via_http(messages: list, max_tokens: int, temperature: float, top_p: float) -> Optional[str]:
    """
    Direct call to the HF chat-completions endpoint (v1).

    Used when the SDK fails because the model does not declare a task
    (e.g. Ministral-3). Returns the assistant message content, or None on
    any failure so the caller can try another strategy.

    Raises:
        HTTPException: 503 when the model is still loading and the error
            body reports an ``estimated_time``.
    """
    # Without a token the call cannot be authorized; signal "not usable".
    if not HF_TOKEN:
        return None
    # Dedicated endpoint if configured, otherwise the serverless model URL.
    base = INFERENCE_ENDPOINT_URL.rstrip("/") if INFERENCE_ENDPOINT_URL else f"https://api-inference.huggingface.co/models/{MODEL_ID}"
    url = f"{base}/v1/chat/completions"
    body = {
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    data = json.dumps(body).encode("utf-8")
    req = urllib.request.Request(
        url,
        data=data,
        headers={
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "application/json",
        },
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=120) as resp:
            out = json.loads(resp.read().decode())
    except urllib.error.HTTPError as e:
        # 503 usually means the model is cold-starting; surface the wait time.
        if e.code == 503:
            err_body = e.read().decode() if e.fp else ""
            try:
                err_json = json.loads(err_body)
                if "estimated_time" in err_json:
                    # HTTPException is not caught by the (ValueError, TypeError)
                    # handler below, so it propagates to the caller.
                    raise HTTPException(
                        status_code=503,
                        detail=f"Modello in caricamento. Riprova tra {err_json.get('estimated_time', 0):.0f}s.",
                    )
            except (ValueError, TypeError):
                # Error body was not valid JSON — fall through to generic failure.
                pass
        return None
    except Exception:
        # Best-effort helper: any other network/parse error means "no answer".
        return None
    choices = out.get("choices") or []
    if not choices:
        return None
    msg = choices[0].get("message") or {}
    return (msg.get("content") or "").strip()
class Item(BaseModel):
    """Payload for /generate/. For Ministral-3: temperature < 0.1 recommended in production."""
    # User prompt for the current turn.
    prompt: str
    # Prior turns as (user_msg, bot_msg) pairs (or {"user": ..., "bot": ...} dicts).
    history: list = Field(default_factory=list, description="Lista di coppie (user_msg, bot_msg)")
    # Optional system prompt prepended to the conversation.
    system_prompt: str = ""
    temperature: float = Field(default=0.0, description="Ministral-3: < 0.1 per produzione")
    # NOTE(review): 1048 looks like a typo for 1024 — confirm before changing.
    max_new_tokens: int = 1048
    top_p: float = 0.15
    # NOTE(review): accepted but never forwarded to inference by generate() — confirm intent.
    repetition_penalty: float = 1.0
def generate(item: Item) -> str:
    """Run inference through a chain of fallbacks and return the assistant text.

    Order of attempts:
      1. chat_completion on the default client,
      2. chat_completion on the explicit hf-inference client,
      3. raw HTTP call to the /v1/chat/completions endpoint,
      4. text_generation with a hand-built [INST] prompt.

    Raises:
        HTTPException: 500 when HF_TOKEN is missing; 503 propagated from the
            HTTP fallback (model loading); 502 when every strategy fails.
    """
    if not HF_TOKEN:
        raise HTTPException(
            status_code=500,
            detail="HF_TOKEN non configurato. Imposta il secret HF_TOKEN in Settings → Secrets dello Space.",
        )
    temperature = float(item.temperature)
    # Clamp to a small positive value — presumably some providers reject
    # temperature == 0; TODO confirm.
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(item.top_p)
    messages = build_messages(item.system_prompt, item.history or [], item.prompt)
    last_error = None
    # 1) Try chat_completion with the default provider (e.g. Together).
    try:
        response = client.chat_completion(
            messages=messages,
            max_tokens=item.max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        if response.choices and len(response.choices) > 0:
            return response.choices[0].message.content or ""
        return ""
    except Exception as e1:
        last_error = e1
    # 2) Retry chat with hf-inference (supports models such as Mixtral-8x7B
    #    that are not exposed as "chat" on other providers).
    try:
        response = client_hf.chat_completion(
            messages=messages,
            max_tokens=item.max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        if response.choices and len(response.choices) > 0:
            return response.choices[0].message.content or ""
        return ""
    except Exception as e2:
        last_error = e2
    # 3) Chat completions over raw HTTP (v1 endpoint) — works for models that
    #    do not declare a task (e.g. Ministral-3).
    try:
        content = chat_completion_via_http(
            messages, item.max_new_tokens, temperature, top_p
        )
        if content is not None and content != "":
            return content
    except HTTPException:
        # Propagate e.g. the 503 "model loading" error to the client as-is.
        raise
    except Exception as e3:
        last_error = e3
    # 4) Last resort: text_generation (only for models supporting it on hf-inference).
    try:
        formatted = format_prompt_for_text_generation(
            item.system_prompt, item.history or [], item.prompt
        )
        stream = client_hf.text_generation(
            formatted,
            max_new_tokens=item.max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            return_full_text=False,
        )
        if isinstance(stream, str):
            return stream
        # Some client versions yield token chunks instead of a plain string.
        if hasattr(stream, "__iter__") and not isinstance(stream, str):
            return "".join(
                r.token.text if hasattr(r, "token") and hasattr(r.token, "text") else (r if isinstance(r, str) else "")
                for r in stream
            )
        return str(stream)
    except Exception as e4:
        last_error = e4
    raise HTTPException(status_code=502, detail=f"Inference fallita: {str(last_error)}")
def root():
    """Liveness endpoint: report service status and the configured model id.

    NOTE(review): no @app route decorator is visible here — it may have been
    lost in extraction; confirm route registration.
    """
    return dict(status="ok", model=MODEL_ID)
def health():
    """Health-check endpoint; always reports OK.

    NOTE(review): no @app route decorator is visible here — it may have been
    lost in extraction; confirm route registration.
    """
    return dict(status="ok")
async def generate_text(item: Item):
    """POST handler: run generate() on the payload and wrap the reply.

    NOTE(review): no @app route decorator is visible here — it may have been
    lost in extraction; confirm route registration.
    """
    return {"response": generate(item)}