from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from pathlib import Path
from typing import List, Optional

app = FastAPI(title="DNAI Humour Chatbot API", version="1.1")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables
model = None
tokenizer = None
MODEL_NAME = "DarkNeuronAI/dnai-humour-0.5B-instruct"

@app.on_event("startup")
async def load_model():
    global model, tokenizer
    try:
        print(f"🔄 Loading {MODEL_NAME} on CPU...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Load weights with reduced peak memory usage; float32 on CPU
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map="cpu",
            low_cpu_mem_usage=True
        )
        model.eval()
        print("✅ Model loaded on CPU successfully!")
    except Exception as e:
        print(f"❌ Error loading model: {str(e)}")
        raise
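
# Note (an assumption, not part of the original config): on a GPU-backed host
# the same from_pretrained call could instead use device_map="auto" with
# torch_dtype=torch.float16 to cut memory use and latency; CPU + float32 is
# the conservative default used here for a CPU-only Space.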

class Message(BaseModel):
    role: str
    content: str

# Request model: chat history plus optional sampling settings
class ChatRequest(BaseModel):
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    max_tokens: Optional[int] = 256
    system_prompt: Optional[str] = "You are DNAI, a helpful and humorous AI assistant."

def format_chat_prompt(messages: List[Message], system_prompt: str) -> str:
    # Prepend the system prompt, then replay the conversation turns
    formatted = f"System: {system_prompt}\n"
    for msg in messages:
        if msg.role == "user":
            formatted += f"User: {msg.content}\n"
        elif msg.role == "assistant":
            formatted += f"Assistant: {msg.content}\n"
    # Trailing cue so the model continues as the assistant
    formatted += "Assistant:"
    return formatted
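
# Illustrative example: with the default system prompt and a single user
# message "Tell me a joke", format_chat_prompt returns:
#
#   System: You are DNAI, a helpful and humorous AI assistant.
#   User: Tell me a joke
#   Assistant:
#
# The bare trailing "Assistant:" is what the model completes from.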

@app.get("/", response_class=HTMLResponse)
async def root():
    html_path = Path(__file__).parent / "index.html"
    if html_path.exists():
        with open(html_path, 'r', encoding='utf-8') as f:
            return HTMLResponse(content=f.read(), status_code=200)
    return HTMLResponse(content="<h1>Error: index.html not found</h1>", status_code=404)

@app.post("/api/chat")
async def chat(request: ChatRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model is still loading")
    try:
        # Pass the caller's system prompt explicitly
        prompt = format_chat_prompt(request.messages, request.system_prompt)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        # Robust extraction: decode only the newly generated tokens. Slicing
        # the decoded string by len(prompt) is fragile, since decoding can
        # normalize whitespace and drop special tokens.
        prompt_length = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
        # Trim anything the model generates past its own turn
        if "User:" in response:
            response = response.split("User:")[0].strip()
        return {"response": response}
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
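
# Illustrative usage (assumes the server is running locally on port 7860):
#
#   curl -X POST http://localhost:7860/api/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Tell me a joke"}],
#          "temperature": 0.7, "top_p": 0.9, "max_tokens": 128}'
#
# The endpoint responds with JSON of the form {"response": "<assistant text>"}.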

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)