Spaces:
Sleeping
Sleeping
File size: 3,660 Bytes
699e0ab c67c6e5 699e0ab a36268e 699e0ab a36268e 3c5bfb7 0352cc7 a36268e 0352cc7 699e0ab fc1751b 699e0ab a36268e 3c5bfb7 a36268e 3c5bfb7 699e0ab fc1751b 699e0ab d1b906b 3c5bfb7 a36268e 699e0ab a36268e 699e0ab fc1751b 699e0ab a36268e 0352cc7 fc1751b a36268e fc1751b a36268e fc1751b a36268e fc1751b 53aad97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
# src/chatbot.py
from typing import Dict, Any, Optional
from src.templates import TEMPLATES
from src.intent import detect_intent
# Separator placed between prompt sections and between formatted memory messages.
MSG_SEPARATOR = "\n"

# Default sampling parameters forwarded to the LLM callable.
# ask() copies this dict before applying caller overrides, so the module
# constant itself is never mutated.
DEFAULT_GEN_ARGS = {
    "max_tokens": 350,   # llama.cpp-style name; ask() falls back to "max_new_tokens" on TypeError
    "temperature": 0.6,
    "top_p": 0.9
}
class LocalChatbot:
    """Chat wrapper gluing a local LLM to conversation memory.

    Collaborators supplied by the caller:
      llm    -- callable ``llm(prompt, **gen_args)`` returning either a plain
                string or an OpenAI-style completion dict.
      memory -- conversation store exposing ``get_formatted(separator=...)``
                and ``add(user, bot)`` (optionally ``trim_to_recent_pairs``
                and ``add_message`` as alternates).

    Module-level ``TEMPLATES`` / ``detect_intent`` provide the prompt
    templates and intent detection.
    """

    def __init__(self, llm, memory, default_template: Optional[str] = "general"):
        self.llm = llm
        self.memory = memory
        # Template key used when an intent has no dedicated template.
        self.default_template = default_template

    # ------------------------------------------------
    # Build system prompt based on intent
    # ------------------------------------------------
    def _build_system_prompt(self, intent: str) -> str:
        """Return the template for *intent*, falling back to the configured
        default template and finally to the ``"general"`` entry."""
        return TEMPLATES.get(intent, TEMPLATES.get(self.default_template, TEMPLATES["general"]))

    # ------------------------------------------------
    # Build full prompt with memory and user message
    # ------------------------------------------------
    def _build_prompt(self, user_message: str, intent: str, max_pairs: int = 12) -> str:
        """Assemble system prompt + recent history + the user turn into one prompt."""
        # Best-effort trim: memory backends without this method are tolerated.
        try:
            self.memory.trim_to_recent_pairs(max_pairs)
        except Exception:
            pass
        system_prompt = self._build_system_prompt(intent)
        history_text = self.memory.get_formatted(separator=MSG_SEPARATOR)
        parts = [
            f"System: {system_prompt}",
            history_text,
            f"User: {user_message}",
            "Assistant:"
        ]
        # Drop empty segments (e.g. no history yet) so no blank lines are emitted.
        return MSG_SEPARATOR.join([p for p in parts if p])

    # ------------------------------------------------
    # Helpers for ask()
    # ------------------------------------------------
    @staticmethod
    def _normalize_message(user_message: Any) -> str:
        """Coerce Gradio-style payloads (list of dicts, dict, plain value) to text."""
        if isinstance(user_message, list):
            joined = "\n".join(
                item.get("text", "") if isinstance(item, dict) else str(item)
                for item in user_message
            )
            return joined.strip()
        if isinstance(user_message, dict) and "text" in user_message:
            return str(user_message["text"]).strip()
        return str(user_message).strip()

    @staticmethod
    def _parse_output(output: Any) -> str:
        """Extract reply text from an LLM result.

        Handles OpenAI-style completion dicts (``choices[0]["text"]``),
        chat-completion dicts (``choices[0]["message"]["content"]``), plain
        strings, and anything else via ``str()``. Returns "" on any failure
        so the caller can substitute a fallback message.
        """
        try:
            if isinstance(output, dict) and "choices" in output:
                choice = output["choices"][0]
                text = choice.get("text")
                if text is None and isinstance(choice.get("message"), dict):
                    # Chat-completion shape: {"message": {"content": ...}}
                    text = choice["message"].get("content", "")
                return str(text or "").strip()
            if isinstance(output, str):
                return output.strip()
            return str(output).strip()
        except Exception:
            return ""

    def _store_turn(self, user_message: str, bot_reply: str) -> None:
        """Persist the exchange; tolerate either memory API, never raise."""
        try:
            self.memory.add(user_message, bot_reply)
        except Exception:
            try:
                self.memory.add_message("user", user_message)
                self.memory.add_message("assistant", bot_reply)
            except Exception:
                pass  # memory persistence is best-effort by design

    # ------------------------------------------------
    # Main ask function
    # ------------------------------------------------
    def ask(self, user_message: Any, gen_args: Optional[Dict[str, Any]] = None, intent: Optional[str] = None) -> str:
        """Answer *user_message* using memory for context.

        ``gen_args`` overrides DEFAULT_GEN_ARGS per call; passing ``intent``
        skips detection. Returns the reply text, or a canned prompt/apology
        string for empty input / generation failure.
        """
        user_message = self._normalize_message(user_message)
        if not user_message:
            return "Please enter a message."
        # Use passed intent or detect
        if intent is None:
            intent = detect_intent(user_message)
        prompt = self._build_prompt(user_message, intent, max_pairs=12)
        # Merge generation args over the defaults without mutating the constant.
        gen = DEFAULT_GEN_ARGS.copy()
        if gen_args:
            gen.update(gen_args)
        # Some backends name the limit "max_new_tokens"; retry once on TypeError.
        try:
            output = self.llm(prompt, **gen)
        except TypeError:
            alt_gen = gen.copy()
            if "max_tokens" in alt_gen:
                alt_gen["max_new_tokens"] = alt_gen.pop("max_tokens")
            output = self.llm(prompt, **alt_gen)
        bot_reply = self._parse_output(output)
        if not bot_reply:
            bot_reply = "Sorry — I couldn't generate a response. Please try again."
        self._store_turn(user_message, bot_reply)
        return bot_reply
|