Spaces:
Sleeping
Sleeping
| # src/chatbot.py | |
| from typing import Dict, Any, Optional | |
| from src.templates import TEMPLATES | |
| from src.intent import detect_intent | |
# Separator placed between the system prompt, history, and user turns.
MSG_SEPARATOR = "\n"

# Sampling parameters sent to the LLM whenever the caller does not
# override them (see LocalChatbot.ask).
DEFAULT_GEN_ARGS = dict(
    max_tokens=350,
    temperature=0.6,
    top_p=0.9,
)
class LocalChatbot:
    """Local-LLM chatbot that routes intents to prompt templates and keeps memory.

    The `llm` is any callable of the form ``llm(prompt, **gen_args)`` returning
    either a llama.cpp-style completion dict (``{"choices": [{"text": ...}]}``),
    an OpenAI chat-style dict (``{"choices": [{"message": {"content": ...}}]}``),
    or a plain string.  The `memory` object is duck-typed: it should expose
    ``trim_to_recent_pairs``, ``get_formatted`` and ``add`` (or the
    ``add_message(role, text)`` variant) — every memory call is best-effort.
    """

    def __init__(self, llm, memory, default_template: Optional[str] = "general"):
        self.llm = llm                # callable: llm(prompt, **gen_args)
        self.memory = memory          # conversation store (duck-typed, see class docstring)
        self.default_template = default_template  # TEMPLATES key used when an intent is unknown

    # ------------------------------------------------
    # Build system prompt based on intent
    # ------------------------------------------------
    def _build_system_prompt(self, intent: str) -> str:
        """Return the template for *intent*, falling back to the configured
        default template and finally to ``TEMPLATES["general"]``."""
        return TEMPLATES.get(intent, TEMPLATES.get(self.default_template, TEMPLATES["general"]))

    # ------------------------------------------------
    # Build full prompt with memory and user message
    # ------------------------------------------------
    def _build_prompt(self, user_message: str, intent: str, max_pairs: int = 12) -> str:
        """Assemble ``System / history / User / Assistant:`` prompt text.

        Trims memory to the most recent *max_pairs* exchanges first (best-effort:
        a memory implementation without that method is simply left untrimmed).
        Empty sections (e.g. no history yet) are dropped from the join.
        """
        try:
            self.memory.trim_to_recent_pairs(max_pairs)
        except Exception:
            pass  # deliberate: trimming is an optimization, never fatal
        system_prompt = self._build_system_prompt(intent)
        history_text = self.memory.get_formatted(separator=MSG_SEPARATOR)
        parts = [
            f"System: {system_prompt}",
            history_text,
            f"User: {user_message}",
            "Assistant:"
        ]
        # Skip falsy parts so an empty history does not leave a blank line.
        return MSG_SEPARATOR.join([p for p in parts if p])

    # ------------------------------------------------
    # Main ask function
    # ------------------------------------------------
    def ask(self, user_message: Any, gen_args: Optional[Dict[str, Any]] = None,
            intent: Optional[str] = None, max_pairs: int = 12) -> str:
        """Answer *user_message* and store the exchange in memory.

        Args:
            user_message: str, Gradio multimodal dict (``{"text": ...}``), or a
                list of such items (joined with newlines).
            gen_args: overrides merged on top of DEFAULT_GEN_ARGS.
            intent: template key; auto-detected via ``detect_intent`` when None.
            max_pairs: how many recent history pairs to keep in the prompt
                (new parameter, defaults to the previous hard-coded 12).

        Returns:
            The assistant reply, or a short apology string when generation
            produced nothing usable.
        """
        # Extract text if passed from Gradio (multimodal payloads arrive as
        # dicts or lists of dicts with a "text" key).
        if isinstance(user_message, list):
            user_message = "\n".join(
                [item.get("text", "") if isinstance(item, dict) else str(item)
                 for item in user_message]
            )
        elif isinstance(user_message, dict) and "text" in user_message:
            user_message = user_message["text"]
        user_message = str(user_message).strip()
        if not user_message:
            return "Please enter a message."

        # Use passed intent or detect
        if intent is None:
            intent = detect_intent(user_message)

        # Build prompt
        prompt = self._build_prompt(user_message, intent, max_pairs=max_pairs)

        # Merge generation args (defaults first so callers can override any key)
        gen = DEFAULT_GEN_ARGS.copy()
        if gen_args:
            gen.update(gen_args)

        # Call LLM; some backends use "max_new_tokens" instead of "max_tokens",
        # so retry once with the renamed key on a TypeError.
        try:
            output = self.llm(prompt, **gen)
        except TypeError:
            alt_gen = gen.copy()
            if "max_tokens" in alt_gen:
                alt_gen["max_new_tokens"] = alt_gen.pop("max_tokens")
            output = self.llm(prompt, **alt_gen)

        # Parse output: completion-style dict, chat-style dict, or raw string.
        bot_reply = ""
        try:
            if isinstance(output, dict) and "choices" in output:
                choice = output["choices"][0]
                bot_reply = choice.get("text", "").strip()
                if not bot_reply:
                    # Chat-completion shape: choices[0]["message"]["content"]
                    message = choice.get("message")
                    if isinstance(message, dict):
                        bot_reply = str(message.get("content", "")).strip()
            elif isinstance(output, str):
                bot_reply = output.strip()
            else:
                bot_reply = str(output).strip()
        except Exception:
            bot_reply = ""
        if not bot_reply:
            # Fixed mojibake: the original string contained a corrupted dash ("β").
            bot_reply = "Sorry — I couldn't generate a response. Please try again."

        # Store memory (best-effort: try the pair API, then the per-role API).
        try:
            self.memory.add(user_message, bot_reply)
        except Exception:
            try:
                self.memory.add_message("user", user_message)
                self.memory.add_message("assistant", bot_reply)
            except Exception:
                pass
        return bot_reply