File size: 3,660 Bytes
699e0ab
c67c6e5
699e0ab
 
a36268e
699e0ab
a36268e
3c5bfb7
0352cc7
a36268e
 
 
0352cc7
 
699e0ab
fc1751b
 
699e0ab
 
 
a36268e
 
 
3c5bfb7
 
 
a36268e
 
 
3c5bfb7
699e0ab
 
fc1751b
699e0ab
 
 
d1b906b
3c5bfb7
a36268e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
699e0ab
 
a36268e
 
 
699e0ab
fc1751b
 
699e0ab
a36268e
0352cc7
 
 
fc1751b
a36268e
fc1751b
 
 
 
 
 
 
 
a36268e
fc1751b
 
 
 
 
 
 
 
 
 
 
 
 
 
a36268e
fc1751b
 
 
 
 
 
 
 
 
53aad97
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# src/chatbot.py

import logging
from typing import Dict, Any, Optional

from src.templates import TEMPLATES
from src.intent import detect_intent

# Joins the sections of every prompt (system line, history, user line,
# assistant cue) and is also handed to the memory when formatting history.
MSG_SEPARATOR = "\n"

# Generation keyword arguments sent to the LLM on every call; any ``gen_args``
# passed to ``LocalChatbot.ask`` are merged on top of a copy of these.
DEFAULT_GEN_ARGS = dict(
    max_tokens=350,
    temperature=0.6,
    top_p=0.9,
)

class LocalChatbot:
    """Chat front-end that turns user messages into prompts for a local LLM.

    Each call to :meth:`ask` builds a plain-text prompt made of a
    ``"System: <template>"`` line, the formatted conversation history, a
    ``"User: <message>"`` line and a trailing ``"Assistant:"`` cue, joined by
    ``MSG_SEPARATOR``. It sends that prompt to ``llm`` and records the
    exchange in ``memory``.

    Args:
        llm: Callable invoked as ``llm(prompt, **generation_kwargs)``.
        memory: Conversation store. ``get_formatted(separator=...)`` is
            required. ``trim_to_recent_pairs(n)`` is called when present.
            Exchanges are stored with ``add(user, reply)``, falling back to
            ``add_message(role, text)``.
        default_template: ``TEMPLATES`` key used when the intent has no
            template of its own.
    """

    # Failures in best-effort steps (memory trim/store, output parsing) are
    # reported here instead of being silently discarded.
    _log = logging.getLogger(__name__)

    def __init__(self, llm, memory, default_template: Optional[str] = "general"):
        self.llm = llm
        self.memory = memory
        self.default_template = default_template

    # ------------------------------------------------
    # Build system prompt based on intent
    # ------------------------------------------------
    def _build_system_prompt(self, intent: str) -> str:
        """Return the system template for *intent*.

        Falls back to the ``default_template`` entry, then to ``"general"``.
        The fallbacks are looked up lazily, so ``TEMPLATES["general"]`` is only
        required when neither *intent* nor ``default_template`` is present.
        """
        if intent in TEMPLATES:
            return TEMPLATES[intent]
        if self.default_template in TEMPLATES:
            return TEMPLATES[self.default_template]
        return TEMPLATES["general"]

    # ------------------------------------------------
    # Build full prompt with memory and user message
    # ------------------------------------------------
    def _build_prompt(self, user_message: str, intent: str, max_pairs: int = 12) -> str:
        """Assemble the full prompt text for one turn.

        Args:
            user_message: Cleaned user text for this turn.
            intent: Key used to pick the system template.
            max_pairs: Passed to ``memory.trim_to_recent_pairs`` (when the
                memory has that method) before the history is read.

        Returns:
            System line, formatted history, user line and ``"Assistant:"``
            cue joined by ``MSG_SEPARATOR``; empty sections are skipped.
        """
        # Trimming is best-effort: memories without the method are skipped
        # quietly, but a trim that raises is logged rather than hidden.
        trim = getattr(self.memory, "trim_to_recent_pairs", None)
        if callable(trim):
            try:
                trim(max_pairs)
            except Exception:
                self._log.warning(
                    "Could not trim chat memory to %d pairs", max_pairs, exc_info=True
                )

        system_prompt = self._build_system_prompt(intent)
        history_text = self.memory.get_formatted(separator=MSG_SEPARATOR)

        parts = [
            f"System: {system_prompt}",
            history_text,
            f"User: {user_message}",
            "Assistant:",
        ]
        return MSG_SEPARATOR.join(p for p in parts if p)

    # ------------------------------------------------
    # Input / output helpers
    # ------------------------------------------------
    @staticmethod
    def _normalize_message(user_message: Any) -> str:
        """Reduce a raw chat input to stripped plain text.

        Accepts a plain value, a dict with a ``"text"`` key, or a list mixing
        such dicts and plain values (the shapes Gradio components may pass).
        ``None`` at any level becomes an empty string, never the text "None".
        """

        def as_text(value: Any) -> str:
            return "" if value is None else str(value)

        if isinstance(user_message, list):
            user_message = "\n".join(
                as_text(item.get("text")) if isinstance(item, dict) else as_text(item)
                for item in user_message
            )
        elif isinstance(user_message, dict) and "text" in user_message:
            user_message = user_message["text"]

        return as_text(user_message).strip()

    def _call_llm(self, prompt: str, gen: Dict[str, Any]) -> Any:
        """Invoke the LLM with *gen*, retrying once with ``max_new_tokens``.

        Some backends take ``max_new_tokens`` instead of ``max_tokens`` and
        reject the latter with a TypeError. On that error the call is retried
        with the key renamed. A ``max_new_tokens`` the caller already supplied
        is kept rather than overwritten. If there is no ``max_tokens`` to
        rename, the TypeError is re-raised instead of repeating the same call.
        """
        try:
            return self.llm(prompt, **gen)
        except TypeError:
            if "max_tokens" not in gen:
                raise
            alt_gen = dict(gen)
            max_tokens = alt_gen.pop("max_tokens")
            alt_gen.setdefault("max_new_tokens", max_tokens)
            return self.llm(prompt, **alt_gen)

    @classmethod
    def _extract_reply(cls, output: Any, prompt: str = "") -> str:
        """Pull the generated text out of an LLM result.

        Handles:
        - completion dicts (``output["choices"][0]["text"]``);
        - chat dicts (``output["choices"][0]["message"]["content"]``);
        - lists such as ``[{"generated_text": ...}]``, where a leading copy of
          *prompt* (echoed by some backends) is removed;
        - plain strings.

        Any other object is converted with ``str``. Returns ``""`` for
        ``None``, empty or unparseable output.
        """
        try:
            if isinstance(output, (list, tuple)):
                output = output[0] if output else None
            if output is None:
                return ""

            if isinstance(output, dict):
                if "choices" in output:
                    choice = output["choices"][0]
                    text = choice.get("text")
                    message = choice.get("message")
                    if text is None and isinstance(message, dict):
                        text = message.get("content")
                elif "generated_text" in output:
                    text = output["generated_text"]
                    # Full-text generators return prompt + continuation.
                    if prompt and isinstance(text, str) and text.startswith(prompt):
                        text = text[len(prompt):]
                else:
                    text = str(output)
            else:
                text = str(output)

            return "" if text is None else str(text).strip()
        except Exception:
            # Best-effort: the caller substitutes an apology for an empty reply.
            cls._log.warning("Could not parse LLM output", exc_info=True)
            return ""

    def _remember(self, user_message: str, bot_reply: str) -> None:
        """Store one exchange in memory (best-effort, failures are logged)."""
        try:
            self.memory.add(user_message, bot_reply)
            return
        except Exception:
            # Expected for memories that only offer the per-message API.
            self._log.debug("memory.add failed; trying add_message", exc_info=True)
        try:
            self.memory.add_message("user", user_message)
            self.memory.add_message("assistant", bot_reply)
        except Exception:
            self._log.warning("Could not store the exchange in chat memory", exc_info=True)

    # ------------------------------------------------
    # Main ask function
    # ------------------------------------------------
    def ask(
        self,
        user_message: Any,
        gen_args: Optional[Dict[str, Any]] = None,
        intent: Optional[str] = None,
    ) -> str:
        """Generate a reply to *user_message* and record the exchange.

        Args:
            user_message: Raw input: plain text, a ``{"text": ...}`` dict, or
                a list of such items.
            gen_args: Generation kwargs merged over ``DEFAULT_GEN_ARGS``.
            intent: Template key; detected from the message when omitted.

        Returns:
            The model's reply, ``"Please enter a message."`` for empty input,
            or an apology when no reply text could be extracted.

        Raises:
            Whatever the LLM call raises. A TypeError is first retried once
            with ``max_tokens`` renamed to ``max_new_tokens``.
        """
        user_message = self._normalize_message(user_message)
        if not user_message:
            return "Please enter a message."

        if intent is None:
            intent = detect_intent(user_message)

        prompt = self._build_prompt(user_message, intent, max_pairs=12)

        gen = DEFAULT_GEN_ARGS.copy()
        if gen_args:
            gen.update(gen_args)

        output = self._call_llm(prompt, gen)

        bot_reply = self._extract_reply(output, prompt)
        if not bot_reply:
            bot_reply = "Sorry — I couldn't generate a response. Please try again."

        self._remember(user_message, bot_reply)
        return bot_reply