|
|
import os |
|
|
|
|
from typing import List, Dict, Optional, Tuple |
|
|
from datetime import date |
|
|
import json |
|
|
import hashlib |
|
|
import sqlite3 |
|
|
|
|
|
import torch |
|
|
from transformers import ( |
|
|
AutoModelForCausalLM, |
|
|
AutoTokenizer, |
|
|
BitsAndBytesConfig, |
|
|
) |
|
|
|
|
|
from retriever import InMemorySqliteRetriever |
|
|
|
|
|
def make_cache_key(model: str, |
|
|
                   messages: List[Dict[str, str]],
|
|
generation_params: dict | None = None) -> tuple[str, str]: |
|
|
""" |
|
|
    Build a deterministic cache key for one LLM call so repeated benchmarking runs can reuse cached outputs.
|
|
|
|
|
Returns (cache_key_hash, cache_key_json_string). |
|
|
cache_key_json_string is stored for debugging/inspection; the hash is the actual key. |
|
|
""" |
|
|
if generation_params is None: |
|
|
generation_params = {} |
|
|
|
|
|
key_obj = { |
|
|
"model": model, |
|
|
"messages": messages, |
|
|
"generation_params": generation_params, |
|
|
} |
|
|
key_json = json.dumps(key_obj, sort_keys=True, ensure_ascii=False) |
|
|
key_hash = hashlib.sha256(key_json.encode("utf-8")).hexdigest() |
|
|
return key_hash, key_json |
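# Example (hypothetical inputs): identical model/messages/params always produce the
# same SHA-256 hash, so a repeated benchmark call can be served from the cache:
#   key_hash, _ = make_cache_key(
#       "Qwen/Qwen3-14B",
#       [{"role": "user", "content": "What is grand larceny in Virginia?"}],
#       {"temperature": 0.0},
#   )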
|
|
|
|
|
|
|
|
|
|
class LLMChatEngine: |
|
|
""" |
|
|
    Chat wrapper around a 4-bit quantized Hugging Face Qwen3-14B model with optional thinking mode.
|
|
|
|
|
- Stores all visible messages (no thinking) in `self.messages` as {"role","content"}. |
|
|
- Splits model output into thinking vs final answer when enable_thinking=True. |
|
|
- Keeps last thinking segment in `self.last_thinking`. |
|
|
- Supports configurable context length (in tokens). |
|
|
    - Provides `clear_history()` to reset the conversation.
    - Retrieves Virginia Code sections via `InMemorySqliteRetriever` when `rag=True`.
    - Caches generations in a SQLite `llm_cache` table keyed by model, messages, and generation params.
|
|
""" |
|
|
|
|
|
|
|
|
    think_end_token_ids: Optional[List[int]] = None
    last_thinking: Optional[str] = None
|
|
|
|
|
def __init__(self, |
|
|
sqlite_path: str = "va_code.db", |
|
|
model_name: str = "Qwen/Qwen3-14B", |
|
|
enable_thinking: bool = True): |
|
|
self.model_name = model_name |
|
|
self.enable_thinking = enable_thinking |
|
|
self.retriever = InMemorySqliteRetriever(sqlite_path) |
|
|
|
|
|
self.max_context_tokens: int = 32000 |
|
|
self.device_map: str = "auto" |
|
|
self.system_prompt: str = f"""You are a helpful AI legal assistant. |
|
|
You should only answer questions about Virginia law. Don't try to answer questions about other states. |
|
|
If you're not confident about a topic, you can decline to answer. |
|
|
If the user's question is unclear, ask clarifying questions.
Always tell the user that this is only a self-research tool and that you cannot provide legal advice.
The current date is {date.today()}.

You may be given statutes from a tool. Trust the tool output, but decide whether each statute is relevant before using it.
|
|
""" |
|
|
|
|
|
        self.tokenizer: Optional[AutoTokenizer] = None  # set when the model is loaded
        self.model: Optional[AutoModelForCausalLM] = None  # set when the model is loaded
|
|
self.messages: List[Dict[str, str]] = [] |
|
|
|
|
|
|
|
|
self.cache_conn = sqlite3.connect(sqlite_path) |
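        # The response cache shares the retriever's SQLite file but lives in its own
        # `llm_cache` table, so one DB path configures both retrieval and caching.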
|
|
self._init_cache_table() |
|
|
|
|
|
        self._load_model()
|
|
|
|
|
|
|
|
|
|
|
def _init_cache_table(self): |
|
|
cur = self.cache_conn.cursor() |
|
|
cur.execute( |
|
|
""" |
|
|
CREATE TABLE IF NOT EXISTS llm_cache ( |
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
|
key_hash TEXT NOT NULL UNIQUE, |
|
|
key_json TEXT NOT NULL, |
|
|
model TEXT NOT NULL, |
|
|
output_json TEXT NOT NULL, |
|
|
created_at DATETIME DEFAULT (CURRENT_TIMESTAMP), |
|
|
last_used_at DATETIME DEFAULT (CURRENT_TIMESTAMP), |
|
|
hit_count INTEGER DEFAULT 0 |
|
|
); |
|
|
""" |
|
|
) |
|
|
cur.execute( |
|
|
"CREATE INDEX IF NOT EXISTS idx_llm_cache_model ON llm_cache(model);" |
|
|
) |
|
|
self.cache_conn.commit() |
|
|
|
|
|
def _cache_get(self, key_hash: str) -> Optional[Dict]: |
|
|
cur = self.cache_conn.cursor() |
|
|
cur.execute( |
|
|
"SELECT id, output_json FROM llm_cache WHERE key_hash = ?", |
|
|
(key_hash,) |
|
|
) |
|
|
row = cur.fetchone() |
|
|
if not row: |
|
|
return None |
|
|
|
|
|
cache_id, output_json = row |
|
|
|
|
|
cur.execute( |
|
|
""" |
|
|
UPDATE llm_cache |
|
|
SET last_used_at = CURRENT_TIMESTAMP, |
|
|
hit_count = hit_count + 1 |
|
|
WHERE id = ? |
|
|
""", |
|
|
(cache_id,) |
|
|
) |
|
|
self.cache_conn.commit() |
|
|
|
|
|
return json.loads(output_json) |
|
|
|
|
|
def _cache_put(self, key_hash: str, key_json: str, output_obj: Dict): |
|
|
cur = self.cache_conn.cursor() |
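        # Note: INSERT OR REPLACE deletes any existing row with the same key_hash,
        # which also resets its hit_count and created_at.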
|
|
cur.execute( |
|
|
""" |
|
|
INSERT OR REPLACE INTO llm_cache (key_hash, key_json, model, output_json) |
|
|
VALUES (?, ?, ?, ?) |
|
|
""", |
|
|
(key_hash, key_json, self.model_name, |
|
|
json.dumps(output_obj, ensure_ascii=False)) |
|
|
) |
|
|
self.cache_conn.commit() |
|
|
|
|
|
|
|
|
|
|
|
    def _load_model(self):
        """Load the tokenizer and the 4-bit quantized model, then seed the chat history with the system prompt."""
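        # NF4 4-bit quantization keeps the 14B weights to roughly 7-8 GiB, which is
        # why the 16 GiB GPU budget below can work (exact footprint varies by setup).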
|
|
|
|
|
quant_config = BitsAndBytesConfig( |
|
|
load_in_4bit=True, |
|
|
bnb_4bit_use_double_quant=False, |
|
|
bnb_4bit_compute_dtype=torch.bfloat16, |
|
|
bnb_4bit_quant_type="nf4", |
|
|
) |
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained( |
|
|
self.model_name, |
|
|
use_fast=True, |
|
|
) |
|
|
|
|
|
|
|
|
max_memory = {} |
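        # Per-device budgets for accelerate's device_map="auto": layers that exceed
        # the GPU budget are offloaded to CPU RAM. Tune these to your hardware.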
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
max_memory = { |
|
|
0: "16GiB", |
|
|
"cpu": "32GiB" |
|
|
} |
|
|
device_map = "auto" |
|
|
else: |
|
|
|
|
|
max_memory = {"cpu": "32GiB"} |
|
|
device_map = {"": "cpu"} |
|
|
|
|
|
self.model = AutoModelForCausalLM.from_pretrained( |
|
|
self.model_name, |
|
|
quantization_config=quant_config, |
|
|
device_map=device_map, |
|
|
max_memory=max_memory, |
|
|
dtype=torch.bfloat16 |
|
|
) |
|
|
|
|
|
|
|
|
self._init_think_end_tokens() |
|
|
|
|
|
|
|
|
if self.system_prompt: |
|
|
self.messages.append({"role": "system", "content": self.system_prompt}) |
|
|
|
|
|
|
|
|
|
|
|
def clear_history(self): |
|
|
"""Clear the chat history, re-inserting the system prompt if provided.""" |
|
|
self.messages = [] |
|
|
if self.system_prompt: |
|
|
self.messages.append({"role": "system", "content": self.system_prompt}) |
|
|
self.last_thinking = None |
|
|
|
|
|
def add_message(self, role: str, content: str): |
|
|
"""Manually append a message to the history.""" |
|
|
assert role in {"system", "user", "assistant", "tool"} |
|
|
self.messages.append({"role": role, "content": content}) |
|
|
|
|
|
    def retrieve_context(self, query: str, top_k: int = 3) -> str:
        """Retrieve the top_k most relevant Virginia Code sections and format them as a tool-output block."""
|
|
results = self.retriever.query_relevant_docs(query, top_k=top_k) |
|
|
tool_output = "" |
|
|
for result in results: |
|
|
section_num = result['doc_id'].replace('VA:', "§ ") |
|
|
tool_output += f"Virginia Code {section_num}: {result['title']} - {result['text']}\n\n\n" |
|
|
tool_output += "\n--------\n" |
|
|
return tool_output |
|
|
|
|
|
def chat( |
|
|
self, |
|
|
user_message: str, |
|
|
max_new_tokens: int = 5000, |
|
|
temperature: float = 0.7, |
|
|
top_p: float = 0.9, |
|
|
do_sample: bool = True, |
|
|
rag: bool = True |
|
|
) -> str: |
|
|
""" |
|
|
Add a user message, run the model (with cache), append assistant reply to history, |
|
|
and return the assistant's response text (without thinking). |
|
|
|
|
|
The model's chain-of-thought "thinking" is stored in `self.last_thinking` |
|
|
when available. |
|
|
""" |
|
|
|
|
|
self.add_message("user", user_message) |
|
|
|
|
|
|
|
|
if rag: |
|
|
tool_output = self.retrieve_context(user_message) |
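            # Retrieved statutes are injected as an assistant turn so they sit right
            # after the user question in the prompt; a dedicated "tool" role is a
            # possible alternative if the chat template supports it.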
|
|
tool_msg = """Here are some statutes which may be useful for answering this question. |
|
|
If a statute seems relevant, you should reference the section number in your answer, |
|
|
for example: 'According to Virginia Code § 63.2-1005...'
|
|
If the statute is not relevant, ignore it. |
|
|
\n\n""" |
|
|
self.add_message("assistant", tool_msg + tool_output) |
|
|
|
|
|
|
|
|
generation_params = { |
|
|
"max_new_tokens": max_new_tokens, |
|
|
"temperature": temperature, |
|
|
"top_p": top_p, |
|
|
"do_sample": do_sample, |
|
|
"rag": rag, |
|
|
"thinking": self.enable_thinking |
|
|
} |
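        # "rag" and "thinking" are not generate() kwargs, but folding them into the
        # cache key keeps runs with different settings from colliding in llm_cache.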
|
|
|
|
|
key_hash, key_json = make_cache_key( |
|
|
self.model_name, |
|
|
self.messages, |
|
|
generation_params, |
|
|
) |
|
|
|
|
|
cached = self._cache_get(key_hash) |
|
|
if cached is not None: |
|
|
thinking_text = cached.get("thinking", "") or "" |
|
|
visible_text = cached.get("visible", "") or "" |
|
|
self.last_thinking = thinking_text if thinking_text else None |
|
|
|
|
|
self.add_message("assistant", visible_text) |
|
|
return visible_text |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inputs = self._build_model_inputs(max_new_tokens=max_new_tokens) |
|
|
|
|
|
with torch.no_grad(): |
|
|
generated = self.model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=max_new_tokens, |
|
|
do_sample=do_sample, |
|
|
temperature=temperature, |
|
|
top_p=top_p, |
|
|
pad_token_id=self._pad_token_id, |
|
|
) |
|
|
|
|
|
|
|
|
full_ids = generated[0].tolist() |
|
|
prompt_len = inputs["input_ids"].shape[1] |
|
|
completion_ids = full_ids[prompt_len:] |
|
|
|
|
|
|
|
|
thinking_text, visible_text = self._split_thinking_and_content(completion_ids) |
|
|
|
|
|
self.last_thinking = thinking_text if thinking_text else None |
|
|
self.add_message("assistant", visible_text) |
|
|
|
|
|
|
|
|
output_obj = { |
|
|
"visible": visible_text, |
|
|
"thinking": thinking_text, |
|
|
} |
|
|
self._cache_put(key_hash, key_json, output_obj) |
|
|
|
|
|
return visible_text |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
def _pad_token_id(self) -> int: |
|
|
if self.tokenizer.pad_token_id is not None: |
|
|
return self.tokenizer.pad_token_id |
|
|
return self.tokenizer.eos_token_id |
|
|
|
|
|
def _apply_chat_template( |
|
|
self, |
|
|
messages: List[Dict[str, str]], |
|
|
add_generation_prompt: bool = True, |
|
|
) -> str: |
|
|
""" |
|
|
Use the HF chat template if available. |
|
|
When possible, enable Qwen thinking mode via `enable_thinking`. |
|
|
""" |
|
|
if hasattr(self.tokenizer, "apply_chat_template"): |
|
|
|
|
|
try: |
|
|
return self.tokenizer.apply_chat_template( |
|
|
messages, |
|
|
tokenize=False, |
|
|
add_generation_prompt=add_generation_prompt, |
|
|
enable_thinking=self.enable_thinking, |
|
|
) |
|
|
except TypeError: |
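                # Older chat templates don't accept enable_thinking; retry without it.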
|
|
|
|
|
return self.tokenizer.apply_chat_template( |
|
|
messages, |
|
|
tokenize=False, |
|
|
add_generation_prompt=add_generation_prompt, |
|
|
) |
|
|
|
|
|
|
|
|
text_parts = [] |
|
|
for m in messages: |
|
|
text_parts.append(f"{m['role'].upper()}: {m['content']}") |
|
|
if add_generation_prompt: |
|
|
text_parts.append("ASSISTANT:") |
|
|
return "\n".join(text_parts) |
|
|
|
|
|
def _build_model_inputs(self, max_new_tokens: int): |
|
|
""" |
|
|
Encode the current message history and truncate from the left |
|
|
until it fits into `max_context_tokens - max_new_tokens`. |
|
|
""" |
|
|
|
|
|
messages = list(self.messages) |
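        # Work on a copy so dropping old turns for truncation never mutates the
        # real conversation history.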
|
|
|
|
|
while True: |
|
|
prompt_text = self._apply_chat_template( |
|
|
messages, |
|
|
add_generation_prompt=True, |
|
|
) |
|
|
encoded = self.tokenizer( |
|
|
prompt_text, |
|
|
return_tensors="pt", |
|
|
add_special_tokens=False, |
|
|
) |
|
|
|
|
|
input_ids = encoded["input_ids"][0] |
|
|
|
|
|
if input_ids.shape[0] <= self.max_context_tokens - max_new_tokens: |
|
|
for k in encoded: |
|
|
encoded[k] = encoded[k].to(self.model.device) |
|
|
return encoded |
|
|
|
|
|
|
|
|
            if len(messages) <= 1:
                # Nothing left to drop: hard-truncate from the left while keeping room
                # for the new tokens we are about to generate.
                keep = max(self.max_context_tokens - max_new_tokens, 1)
                for k in encoded:
                    encoded[k] = encoded[k][:, -keep:].to(self.model.device)
                return encoded
|
|
|
|
|
drop_idx = next( |
|
|
(i for i, m in enumerate(messages) if m["role"] != "system"), 0 |
|
|
) |
|
|
messages.pop(drop_idx) |
|
|
|
|
|
|
|
|
|
|
|
def _init_think_end_tokens(self): |
|
|
""" |
|
|
Initialize token ids representing the '</think>' marker, if present. |
|
|
This is more general than hardcoding a single id like 151668. |
|
|
""" |
|
|
try: |
|
|
ids = self.tokenizer.encode("</think>", add_special_tokens=False) |
|
|
self.think_end_token_ids = ids if ids else None |
|
|
except Exception: |
|
|
self.think_end_token_ids = None |
|
|
|
|
|
def _split_thinking_and_content( |
|
|
self, output_ids: List[int] |
|
|
) -> Tuple[str, str]: |
|
|
""" |
|
|
Split raw completion token ids into (thinking_text, visible_text). |
|
|
|
|
|
- Find the last occurrence of '</think>' token sequence. |
|
|
- Tokens before that are treated as thinking. |
|
|
- Tokens from that point onward are treated as the visible answer. |
|
|
- If not found or thinking disabled, all tokens are visible answer. |
|
|
""" |
|
|
|
|
|
if not self.enable_thinking or not self.think_end_token_ids: |
|
|
visible = self.tokenizer.decode( |
|
|
output_ids, skip_special_tokens=True |
|
|
).strip("\n") |
|
|
return "", visible |
|
|
|
|
|
pattern = self.think_end_token_ids |
|
|
n = len(pattern) |
|
|
index_tokens = 0 |
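        # Scan backwards so we split on the LAST </think> occurrence, which is robust
        # if the tag appears more than once in the completion.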
|
|
|
|
|
|
|
|
for i in range(len(output_ids) - n, -1, -1): |
|
|
if output_ids[i : i + n] == pattern: |
|
|
index_tokens = i + n |
|
|
break |
|
|
|
|
|
thinking_ids = output_ids[:index_tokens] |
|
|
visible_ids = output_ids[index_tokens:] |
|
|
|
|
|
thinking_text = self.tokenizer.decode( |
|
|
thinking_ids, skip_special_tokens=True |
|
|
).strip("\n") |
|
|
visible_text = self.tokenizer.decode( |
|
|
visible_ids, skip_special_tokens=True |
|
|
).strip("\n") |
|
|
|
|
|
return thinking_text, visible_text |
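

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes a `va_code.db` built for
# `InMemorySqliteRetriever` and enough GPU/CPU memory to load Qwen3-14B in 4-bit;
# the question below is just an example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    engine = LLMChatEngine(sqlite_path="va_code.db", enable_thinking=True)

    # First call hits the model; an identical repeat would be served from llm_cache.
    answer = engine.chat("What does Virginia law say about adverse possession?")
    print("ANSWER:\n", answer)

    # The chain-of-thought (if any) is kept separately from the visible reply.
    if engine.last_thinking:
        print("\nTHINKING (first 500 chars):\n", engine.last_thinking[:500])

    engine.clear_history()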
|
|
|