from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from pathlib import Path
from typing import List, Optional

app = FastAPI(title="DNAI Humour Chatbot API", version="1.1")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model/tokenizer handles, populated at startup
model = None
tokenizer = None
MODEL_NAME = "DarkNeuronAI/dnai-humour-0.5B-instruct"

@app.on_event("startup")
async def load_model():
    # Load the model once when the app starts so requests can reuse it
    global model, tokenizer
    try:
        print(f"Loading {MODEL_NAME} on CPU...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Load weights in float32 on CPU with reduced peak memory usage
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map="cpu",
            low_cpu_mem_usage=True
        )
        model.eval()
        print("Model loaded on CPU successfully!")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise

class Message(BaseModel):
    role: str
    content: str

# Request model, including generation settings
class ChatRequest(BaseModel):
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    max_tokens: Optional[int] = 256
    system_prompt: Optional[str] = "You are DNAI, a helpful and humorous AI assistant."

def format_chat_prompt(messages: List[Message], system_prompt: str) -> str:
    # Prepend the system prompt, then lay out the conversation turn by turn
    formatted = f"System: {system_prompt}\n"
    for msg in messages:
        if msg.role == "user":
            formatted += f"User: {msg.content}\n"
        elif msg.role == "assistant":
            formatted += f"Assistant: {msg.content}\n"
    formatted += "Assistant:"
    return formatted

@app.get("/", response_class=HTMLResponse)
async def root():
    # Serve the chat UI if the bundled index.html is present
    html_path = Path(__file__).parent / "index.html"
    if html_path.exists():
        with open(html_path, 'r', encoding='utf-8') as f:
            return HTMLResponse(content=f.read(), status_code=200)
    return "<h1>Error: index.html not found</h1>"

@app.post("/chat")
async def chat(request: ChatRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model is still loading")
    try:
        # Build the prompt, passing the system prompt explicitly
        prompt = format_chat_prompt(request.messages, request.system_prompt)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the newly generated assistant turn, trimming any echoed user turn
        response = generated_text[len(prompt):].strip()
        if "User:" in response:
            response = response.split("User:")[0].strip()
        return {"response": response}
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
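
Once the server is running, the chat endpoint can be exercised with a small client. The sketch below assumes the /chat route and port 7860 used above and the ChatRequest fields defined in this file; it relies on the requests library, which is not part of the app itself.

import requests

# Minimal client sketch: send one user turn with explicit sampling settings
payload = {
    "messages": [{"role": "user", "content": "Tell me a joke about neural networks."}],
    "temperature": 0.7,
    "top_p": 0.9,
    "max_tokens": 128,
    "system_prompt": "You are DNAI, a helpful and humorous AI assistant.",
}

resp = requests.post("http://localhost:7860/chat", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])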