# desc_compare / main.py
# (Hugging Face Space file header — author: chripto, commit 07fb4c5 "Update main.py", verified)
import os
import json
from typing import Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from huggingface_hub import InferenceClient
import uvicorn
import urllib.request
import urllib.error
app = FastAPI()

# Token from Settings → Secrets (required for the Inference API / Inference Endpoints).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model: defaults to Ministral-3-14B-Instruct (multilingual, system prompt, Apache 2.0).
# When using a dedicated Inference Endpoint, set HF_INFERENCE_ENDPOINT_URL to its URL.
MODEL_ID = os.environ.get("MODEL_ID", "mistralai/Ministral-3-14B-Instruct-2512")
INFERENCE_ENDPOINT_URL = os.environ.get("HF_INFERENCE_ENDPOINT_URL")

# A dedicated Inference Endpoint takes precedence when configured; otherwise use
# the auto-selected provider plus an explicit hf-inference client as fallback.
if INFERENCE_ENDPOINT_URL:
    client = InferenceClient(model=INFERENCE_ENDPOINT_URL, token=HF_TOKEN)
    client_hf = client
else:
    client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)
    client_hf = InferenceClient(model=MODEL_ID, token=HF_TOKEN, provider="hf-inference")
def build_messages(system_prompt: str, history: list, prompt: str) -> list:
    """Build the ``messages`` list for ``chat_completion`` (conversational task).

    ``history`` entries are either ``(user_msg, bot_msg)`` pairs or dicts with
    ``"user"``/``"bot"`` keys; each one expands into a user turn followed by an
    assistant turn. The final user ``prompt`` is appended last, and an optional
    system message (when ``system_prompt`` is non-empty) comes first.
    """
    messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
    for turn in history or []:
        if isinstance(turn, (list, tuple)):
            user_msg, bot_msg = turn[0], turn[1]
        else:
            user_msg = turn.get("user", "")
            bot_msg = turn.get("bot", "")
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": prompt})
    return messages
def format_prompt_for_text_generation(system_prompt: str, history: list, prompt: str) -> str:
    """Render a Mistral ``[INST]``-style prompt for ``text_generation``.

    Used when the provider does not support the chat task. The system prompt,
    when present, is prepended to the final user message (comma-joined); each
    history pair becomes an ``[INST] user [/INST] bot</s>`` segment.
    """
    final_user = f"{system_prompt}, {prompt}" if system_prompt else prompt
    parts = ["<s>"]
    for turn in history or []:
        if isinstance(turn, (list, tuple)):
            user_msg, bot_msg = turn[0], turn[1]
        else:
            user_msg = turn.get("user", "")
            bot_msg = turn.get("bot", "")
        parts.append(f"[INST] {user_msg} [/INST] {bot_msg}</s> ")
    parts.append(f"[INST] {final_user} [/INST]")
    return "".join(parts)
def chat_completion_via_http(messages: list, max_tokens: int, temperature: float, top_p: float) -> Optional[str]:
    """
    Direct call to the HF chat-completions endpoint (v1).

    Used when the SDK fails because the model does not declare its task
    (e.g. Ministral-3). Returns the assistant message content on success,
    or None on any failure so the caller can fall through to the next
    strategy. Raises HTTPException(503) only when the model reports it is
    still loading (body contains "estimated_time").
    """
    if not HF_TOKEN:
        return None
    # A dedicated endpoint URL wins; otherwise target the serverless Inference API.
    base = INFERENCE_ENDPOINT_URL.rstrip("/") if INFERENCE_ENDPOINT_URL else f"https://api-inference.huggingface.co/models/{MODEL_ID}"
    url = f"{base}/v1/chat/completions"
    body = {
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    data = json.dumps(body).encode("utf-8")
    req = urllib.request.Request(
        url,
        data=data,
        headers={
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "application/json",
        },
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=120) as resp:
            out = json.loads(resp.read().decode())
    except urllib.error.HTTPError as e:
        # A 503 whose JSON body carries "estimated_time" means the model is
        # cold-starting: surface that to the client instead of silently failing.
        if e.code == 503:
            err_body = e.read().decode() if e.fp else ""
            try:
                err_json = json.loads(err_body)
                if "estimated_time" in err_json:
                    raise HTTPException(
                        status_code=503,
                        detail=f"Modello in caricamento. Riprova tra {err_json.get('estimated_time', 0):.0f}s.",
                    )
            except (ValueError, TypeError):
                # Error body was not valid JSON (or not formattable) — fall through.
                pass
        return None
    except Exception:
        # Best-effort transport: any other error just signals "try another path".
        return None
    choices = out.get("choices") or []
    if not choices:
        return None
    msg = choices[0].get("message") or {}
    return (msg.get("content") or "").strip()
class Item(BaseModel):
    """Payload for /generate/. For Ministral-3: temperature < 0.1 recommended in production."""
    prompt: str
    # Conversation history: list of (user_msg, bot_msg) pairs or {"user", "bot"} dicts.
    history: list = Field(default_factory=list, description="Lista di coppie (user_msg, bot_msg)")
    system_prompt: str = ""
    temperature: float = Field(default=0.0, description="Ministral-3: < 0.1 per produzione")
    # NOTE(review): 1048 looks like a typo for 1024 — confirm before changing the default.
    max_new_tokens: int = 1048
    top_p: float = 0.15
    repetition_penalty: float = 1.0
def generate(item: Item) -> str:
    """Run inference for *item*, trying progressively more basic strategies.

    Order of attempts:
      1. ``chat_completion`` on the default provider client,
      2. ``chat_completion`` on the explicit hf-inference client,
      3. raw HTTP POST to the v1 chat-completions endpoint,
      4. plain ``text_generation`` with a Mistral ``[INST]``-formatted prompt.

    Returns the generated text (possibly empty when a chat call yields no
    choices). Raises HTTPException: 500 if HF_TOKEN is missing, 503 if the
    model reports it is still loading, 502 when every strategy failed.
    """
    if not HF_TOKEN:
        raise HTTPException(
            status_code=500,
            detail="HF_TOKEN non configurato. Imposta il secret HF_TOKEN in Settings → Secrets dello Space.",
        )
    # Clamp temperature away from exactly 0 (some providers reject 0.0).
    temperature = max(float(item.temperature), 1e-2)
    top_p = float(item.top_p)
    messages = build_messages(item.system_prompt, item.history or [], item.prompt)
    last_error = None

    # 1) Try chat_completion with the default provider (e.g. Together).
    try:
        response = client.chat_completion(
            messages=messages,
            max_tokens=item.max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        if response.choices and len(response.choices) > 0:
            return response.choices[0].message.content or ""
        return ""
    except Exception as e1:
        last_error = e1

    # 2) Retry chat on hf-inference (supports models such as Mixtral-8x7B that
    #    other providers do not expose as "chat").
    try:
        response = client_hf.chat_completion(
            messages=messages,
            max_tokens=item.max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        if response.choices and len(response.choices) > 0:
            return response.choices[0].message.content or ""
        return ""
    except Exception as e2:
        last_error = e2

    # 3) Chat completions over raw HTTP (v1 endpoint) — works for models that
    #    do not declare a task (e.g. Ministral-3).
    try:
        content = chat_completion_via_http(
            messages, item.max_new_tokens, temperature, top_p
        )
        if content is not None and content != "":
            return content
    except HTTPException:
        # "Model loading" (503) is meaningful to the client — propagate as-is.
        raise
    except Exception as e3:
        last_error = e3

    # 4) Last resort: text_generation (only for models supporting it on hf-inference).
    try:
        formatted = format_prompt_for_text_generation(
            item.system_prompt, item.history or [], item.prompt
        )
        stream = client_hf.text_generation(
            formatted,
            max_new_tokens=item.max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            # Fix: Item declares repetition_penalty but it was never forwarded.
            repetition_penalty=item.repetition_penalty,
            do_sample=True,
            return_full_text=False,
        )
        if isinstance(stream, str):
            return stream
        if hasattr(stream, "__iter__"):
            # Streamed responses yield token objects (or raw strings) — join their text.
            return "".join(
                r.token.text if hasattr(r, "token") and hasattr(r.token, "text") else (r if isinstance(r, str) else "")
                for r in stream
            )
        return str(stream)
    except Exception as e4:
        last_error = e4

    raise HTTPException(status_code=502, detail=f"Inference fallita: {str(last_error)}")
@app.get("/")
def root():
    """Liveness/info endpoint: reports service status and the configured model id."""
    payload = {"status": "ok", "model": MODEL_ID}
    return payload
@app.get("/health")
def health():
    """Simple health-check endpoint."""
    return dict(status="ok")
@app.post("/generate/")
async def generate_text(item: Item):
    """POST /generate/ — run inference for *item* and wrap the result in a JSON body."""
    text = generate(item)
    return {"response": text}