# virginia-legal-rag-lexva / chat_engine.py
import os
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple
from datetime import date
import json
import hashlib
import sqlite3
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from retriever import InMemorySqliteRetriever
def make_cache_key(model: str,
                   messages: List[Dict[str, str]],
                   generation_params: dict | None = None) -> tuple[str, str]:
"""
Supports prompt caching to speed up benchmarking tasks
Returns (cache_key_hash, cache_key_json_string).
cache_key_json_string is stored for debugging/inspection; the hash is the actual key.
"""
if generation_params is None:
generation_params = {}
key_obj = {
"model": model,
"messages": messages,
"generation_params": generation_params,
}
key_json = json.dumps(key_obj, sort_keys=True, ensure_ascii=False)
key_hash = hashlib.sha256(key_json.encode("utf-8")).hexdigest()
return key_hash, key_json
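# Illustrative sketch (not part of the engine): identical model / messages /
# generation params always hash to the same key, so repeated benchmark prompts
# can be served from the sqlite cache instead of re-running generation.
#
#   h1, _ = make_cache_key("Qwen/Qwen3-14B",
#                          [{"role": "user", "content": "hi"}],
#                          {"temperature": 0.7})
#   h2, _ = make_cache_key("Qwen/Qwen3-14B",
#                          [{"role": "user", "content": "hi"}],
#                          {"temperature": 0.7})
#   assert h1 == h2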
@dataclass
class LLMChatEngine:
"""
Chat wrapper around a Hugging Face Qwen3 14B 4-bit model with thinking support.
- Stores all visible messages (no thinking) in `self.messages` as {"role","content"}.
- Splits model output into thinking vs final answer when enable_thinking=True.
- Keeps last thinking segment in `self.last_thinking`.
- Supports configurable context length (in tokens).
- Provides `clear_history()` to reset the conversation.
"""
# Thinking-related state
think_end_token_ids: Optional[List[int]] = field(default=None, init=False)
last_thinking: Optional[str] = field(default=None, init=False)
def __init__(self,
sqlite_path: str = "va_code.db",
model_name: str = "Qwen/Qwen3-14B",
enable_thinking: bool = True):
self.model_name = model_name
self.enable_thinking = enable_thinking
self.retriever = InMemorySqliteRetriever(sqlite_path)
self.max_context_tokens: int = 32000
self.device_map: str = "auto"
        self.system_prompt: str = f"""You are a helpful AI legal assistant.
You should only answer questions about Virginia law. Don't try to answer questions about other states.
If you're not confident about a topic, you can decline to answer.
If the user's question is unclear, please ask clarifying questions.
Always tell the user that this is only a self-research tool, and that you cannot provide legal advice.
The current date is {date.today()}.
You may be given statutes from a tool. Trust the tool outputs' content, but decide whether they are relevant before using them.
"""
        # Populated in __post_init__ once the tokenizer and model are loaded.
        self.tokenizer: Optional[AutoTokenizer] = None
        self.model: Optional[AutoModelForCausalLM] = None
self.messages: List[Dict[str, str]] = []
# --- cache setup ---
self.cache_conn = sqlite3.connect(sqlite_path)
self._init_cache_table()
self.__post_init__()
# --------- Cache helpers ---------
def _init_cache_table(self):
cur = self.cache_conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS llm_cache (
id INTEGER PRIMARY KEY AUTOINCREMENT,
key_hash TEXT NOT NULL UNIQUE,
key_json TEXT NOT NULL,
model TEXT NOT NULL,
output_json TEXT NOT NULL,
created_at DATETIME DEFAULT (CURRENT_TIMESTAMP),
last_used_at DATETIME DEFAULT (CURRENT_TIMESTAMP),
hit_count INTEGER DEFAULT 0
);
"""
)
cur.execute(
"CREATE INDEX IF NOT EXISTS idx_llm_cache_model ON llm_cache(model);"
)
self.cache_conn.commit()
def _cache_get(self, key_hash: str) -> Optional[Dict]:
cur = self.cache_conn.cursor()
cur.execute(
"SELECT id, output_json FROM llm_cache WHERE key_hash = ?",
(key_hash,)
)
row = cur.fetchone()
if not row:
return None
cache_id, output_json = row
# best-effort stats update
cur.execute(
"""
UPDATE llm_cache
SET last_used_at = CURRENT_TIMESTAMP,
hit_count = hit_count + 1
WHERE id = ?
""",
(cache_id,)
)
self.cache_conn.commit()
return json.loads(output_json)
def _cache_put(self, key_hash: str, key_json: str, output_obj: Dict):
cur = self.cache_conn.cursor()
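        # Note: INSERT OR REPLACE deletes and re-creates the row when key_hash
        # already exists, which also resets created_at and hit_count for that entry.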
cur.execute(
"""
INSERT OR REPLACE INTO llm_cache (key_hash, key_json, model, output_json)
VALUES (?, ?, ?, ?)
""",
(key_hash, key_json, self.model_name,
json.dumps(output_obj, ensure_ascii=False))
)
self.cache_conn.commit()
# --------- Model init ---------
def __post_init__(self):
# 4-bit quantization config (bitsandbytes)
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
use_fast=True,
)
# If VRAM is tight, offload to CPU
max_memory = {}
if torch.cuda.is_available():
max_memory = {
0: "16GiB",
"cpu": "32GiB"
}
device_map = "auto"
else:
# Fallback: CPU only
max_memory = {"cpu": "32GiB"}
device_map = {"": "cpu"}
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
quantization_config=quant_config,
device_map=device_map, # spans GPU + CPU
max_memory=max_memory, # triggers offloading once GPU limit is hit
dtype=torch.bfloat16
)
# Precompute the token ids corresponding to "</think>" if possible
self._init_think_end_tokens()
# Initialize with a system message
if self.system_prompt:
self.messages.append({"role": "system", "content": self.system_prompt})
# --------- Public API ---------
def clear_history(self):
"""Clear the chat history, re-inserting the system prompt if provided."""
self.messages = []
if self.system_prompt:
self.messages.append({"role": "system", "content": self.system_prompt})
self.last_thinking = None
def add_message(self, role: str, content: str):
"""Manually append a message to the history."""
        if role not in {"system", "user", "assistant", "tool"}:
            raise ValueError(f"Unsupported role: {role!r}")
self.messages.append({"role": role, "content": content})
    def retrieve_context(self, query: str, top_k: int = 3) -> str:
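        """
        Fetch the top_k most relevant Virginia Code sections for `query` and
        format them as plain text for the model. Each entry is rendered as
        (illustrative format only):

            Virginia Code § <section>: <title> - <text>

        followed by a dashed separator line.
        """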
results = self.retriever.query_relevant_docs(query, top_k=top_k)
tool_output = ""
for result in results:
section_num = result['doc_id'].replace('VA:', "§ ")
tool_output += f"Virginia Code {section_num}: {result['title']} - {result['text']}\n\n\n"
tool_output += "\n--------\n"
return tool_output
def chat(
self,
user_message: str,
max_new_tokens: int = 5000,
temperature: float = 0.7,
top_p: float = 0.9,
do_sample: bool = True,
rag: bool = True
) -> str:
"""
Add a user message, run the model (with cache), append assistant reply to history,
and return the assistant's response text (without thinking).
The model's chain-of-thought "thinking" is stored in `self.last_thinking`
when available.
"""
# Add user message
self.add_message("user", user_message)
# Optional RAG context
if rag:
tool_output = self.retrieve_context(user_message)
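            # The retrieved statutes are injected as an assistant-role message so
            # they flow through the standard chat template; the accompanying prompt
            # tells the model to judge relevance before citing them.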
tool_msg = """Here are some statutes which may be useful for answering this question.
If a statute seems relevant, you should reference the section number in your answer,
like, 'According to Virginia Code § 63.2-1005...'
If the statute is not relevant, ignore it.
\n\n"""
self.add_message("assistant", tool_msg + tool_output)
# ----- Cache check -----
generation_params = {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
"rag": rag,
"thinking": self.enable_thinking
}
key_hash, key_json = make_cache_key(
self.model_name,
self.messages,
generation_params,
)
cached = self._cache_get(key_hash)
if cached is not None:
thinking_text = cached.get("thinking", "") or ""
visible_text = cached.get("visible", "") or ""
self.last_thinking = thinking_text if thinking_text else None
# Reinsert assistant message into the conversation history
self.add_message("assistant", visible_text)
return visible_text
# ----- Cache miss: run the model -----
# Build model inputs from chat history, applying truncation
inputs = self._build_model_inputs(max_new_tokens=max_new_tokens)
with torch.no_grad():
generated = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
pad_token_id=self._pad_token_id,
)
# Full sequence ids, including prompt and completion
full_ids = generated[0].tolist()
prompt_len = inputs["input_ids"].shape[1]
completion_ids = full_ids[prompt_len:]
# Split into thinking vs visible content
thinking_text, visible_text = self._split_thinking_and_content(completion_ids)
self.last_thinking = thinking_text if thinking_text else None
self.add_message("assistant", visible_text)
# Store in cache
output_obj = {
"visible": visible_text,
"thinking": thinking_text,
}
self._cache_put(key_hash, key_json, output_obj)
return visible_text
# --------- Internal helpers ---------
@property
def _pad_token_id(self) -> int:
if self.tokenizer.pad_token_id is not None:
return self.tokenizer.pad_token_id
return self.tokenizer.eos_token_id
def _apply_chat_template(
self,
messages: List[Dict[str, str]],
add_generation_prompt: bool = True,
) -> str:
"""
Use the HF chat template if available.
When possible, enable Qwen thinking mode via `enable_thinking`.
"""
if hasattr(self.tokenizer, "apply_chat_template"):
# Try to call with enable_thinking for Qwen-style tokenizers
try:
return self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=add_generation_prompt,
enable_thinking=self.enable_thinking,
)
except TypeError:
# Fallback if tokenizer does not support enable_thinking
return self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)
# Fallback: simple role: content format
text_parts = []
for m in messages:
text_parts.append(f"{m['role'].upper()}: {m['content']}")
if add_generation_prompt:
text_parts.append("ASSISTANT:")
return "\n".join(text_parts)
def _build_model_inputs(self, max_new_tokens: int):
"""
Encode the current message history and truncate from the left
until it fits into `max_context_tokens - max_new_tokens`.
"""
# Work on a copy so we can pop from the front for truncation
messages = list(self.messages)
while True:
prompt_text = self._apply_chat_template(
messages,
add_generation_prompt=True,
)
encoded = self.tokenizer(
prompt_text,
return_tensors="pt",
add_special_tokens=False,
)
input_ids = encoded["input_ids"][0]
# Leave room for generation
if input_ids.shape[0] <= self.max_context_tokens - max_new_tokens:
for k in encoded:
encoded[k] = encoded[k].to(self.model.device)
return encoded
# If too long, drop the earliest non-system message
if len(messages) <= 1:
# Can't shrink further; just keep last max_context_tokens tokens
for k in encoded:
encoded[k] = encoded[k][:, -self.max_context_tokens :].to(
self.model.device
)
return encoded
drop_idx = next(
(i for i, m in enumerate(messages) if m["role"] != "system"), 0
)
messages.pop(drop_idx)
# --------- Thinking support ---------
def _init_think_end_tokens(self):
"""
Initialize token ids representing the '</think>' marker, if present.
This is more general than hardcoding a single id like 151668.
"""
try:
ids = self.tokenizer.encode("</think>", add_special_tokens=False)
self.think_end_token_ids = ids if ids else None
except Exception:
self.think_end_token_ids = None
def _split_thinking_and_content(
self, output_ids: List[int]
) -> Tuple[str, str]:
"""
Split raw completion token ids into (thinking_text, visible_text).
- Find the last occurrence of '</think>' token sequence.
- Tokens before that are treated as thinking.
- Tokens from that point onward are treated as the visible answer.
- If not found or thinking disabled, all tokens are visible answer.
"""
# If thinking is disabled or marker unavailable, treat all as visible content
if not self.enable_thinking or not self.think_end_token_ids:
visible = self.tokenizer.decode(
output_ids, skip_special_tokens=True
).strip("\n")
return "", visible
pattern = self.think_end_token_ids
n = len(pattern)
index_tokens = 0 # default: no thinking
# Find last occurrence of pattern in output_ids (reverse search)
for i in range(len(output_ids) - n, -1, -1):
if output_ids[i : i + n] == pattern:
index_tokens = i + n
break
thinking_ids = output_ids[:index_tokens]
visible_ids = output_ids[index_tokens:]
thinking_text = self.tokenizer.decode(
thinking_ids, skip_special_tokens=True
).strip("\n")
visible_text = self.tokenizer.decode(
visible_ids, skip_special_tokens=True
).strip("\n")
return thinking_text, visible_text
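if __name__ == "__main__":
    # Minimal usage sketch (assumes "va_code.db" exists alongside this file and
    # that enough GPU/CPU memory is available to load Qwen3-14B in 4-bit).
    engine = LLMChatEngine(sqlite_path="va_code.db")
    reply = engine.chat("Can my landlord keep my security deposit in Virginia?")
    print(reply)
    if engine.last_thinking:
        print("\n--- model thinking ---")
        print(engine.last_thinking)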