Zhiqi (Eli) Wang committed · Commit 795902f · 0 parents
Files changed (7)
  1. README.md +76 -0
  2. app.py +1130 -0
  3. docs/app.js +61 -0
  4. docs/config.js +4 -0
  5. docs/index.html +53 -0
  6. docs/styles.css +195 -0
  7. requirements.txt +5 -0
README.md ADDED
@@ -0,0 +1,76 @@
---
title: SPADE
emoji: 💻
colorFrom: red
colorTo: gray
sdk: gradio
sdk_version: 6.5.1
app_file: app.py
pinned: false
license: mit
---

## What this repo provides

1. `app.py`: Hugging Face Space backend pipeline (`x -> detector -> explainer`)
2. `docs/`: GitHub Pages demo frontend that calls the Space API

## Backend (Hugging Face Space)

This repo is configured as a Gradio Space (`sdk: gradio`).

### API endpoint

- Endpoint name: `/pipeline`
- Input: `text` (string)
- Output: `[detector_output_json, explainer_markdown]`

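As a reference for consumers of this endpoint, here is a minimal sketch of unpacking the two outputs on the client side. It assumes the detector JSON follows the `features`/`scam` schema that `app.py` asks its detector model to produce; the `parse_pipeline_output` helper and the sample payload are illustrative, not part of the repo:

```python
import json

def parse_pipeline_output(outputs):
    # outputs is the [detector_output_json, explainer_markdown] pair from /pipeline
    detector_raw, explanation_md = outputs
    detector = json.loads(detector_raw)
    return detector.get("scam"), detector.get("features", {}), explanation_md

scam, features, explanation = parse_pipeline_output([
    '{"features": {"Urgency and Scarcity": "act now"}, "scam": 1}',
    'Urgency: "act now" -> pressure to act before thinking.',
])
print(scam)  # 1
```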
### Replace with your paper models

In `app.py`, replace:

- `run_detector(text)` with your detector inference
- `run_explainer(text, detector_output)` with your explainer inference

The pipeline order is already implemented in `pipeline(text)`.

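The wiring is simple enough to sketch. The stub bodies below are placeholders for your own inference code; only the function names and the detector-then-explainer order come from `app.py`:

```python
def run_detector(text: str) -> str:
    # Placeholder: swap in your detector inference; return a JSON string
    # shaped like {"features": {...}, "scam": 0 or 1}.
    return '{"features": {}, "scam": 0}'

def run_explainer(text: str, detector_output: str) -> str:
    # Placeholder: swap in your explainer inference; return markdown.
    return "No manipulation cues detected."

def pipeline(text: str) -> tuple[str, str]:
    # x -> detector -> explainer: the explainer is conditioned on the detector output.
    detector_output = run_detector(text)
    explanation = run_explainer(text, detector_output)
    return detector_output, explanation

detector_json, explanation_md = pipeline("You have won a prize!")
```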
## Frontend (GitHub Project Page)

The static site lives in `docs/`:

- `docs/index.html`
- `docs/app.js`
- `docs/config.js`
- `docs/styles.css`

### Configure Space ID

Edit `docs/config.js`:

```js
window.SPADE_CONFIG = {
  spaceId: "your-username/your-space-name",
};
```

### Enable GitHub Pages

In GitHub repo settings:

1. Open `Settings -> Pages`
2. Set source to `Deploy from a branch`
3. Select branch `main` and folder `/docs`
4. Save

Your demo will be published at:

`https://<your-github-username>.github.io/<repo-name>/`

## Local run (optional)

```bash
pip install -r requirements.txt
python app.py
```

Then serve `docs/` with a local static server and open `index.html` in your browser if needed.
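Any static server works for previewing the frontend; one option is Python's built-in `http.server` (the port number is arbitrary):

```shell
# Serve docs/ at http://localhost:8000/ (Ctrl-C to stop)
python -m http.server 8000 --directory docs
```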
app.py ADDED
@@ -0,0 +1,1130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from __future__ import annotations

import hashlib
import json
import logging
import os
import re
import time
from functools import lru_cache
from pathlib import Path
from queue import Empty
from threading import Thread
from typing import Any, Dict, Iterator, Tuple

import gradio as gr
import torch
from transformers import pipeline as hf_pipeline
from transformers import TextIteratorStreamer
from transformers import __version__ as transformers_version


DETECTOR_MODEL_ID = os.getenv(
    "DETECTOR_MODEL_ID",
    "ZhiqiEliWang/qwen3_0.6b_psyscam_romance_ephishllm",
)
EXPLAINER_MODEL_ID = os.getenv(
    "EXPLAINER_MODEL_ID",
    "ZhiqiEliWang/qwen3_0.6b_explainer",
)
STOP_TOKEN = "<|im_end|>"
NUM_CTX = 4096
TEMPERATURE = 0.6
TOP_K = 20
TOP_P = 0.95
MAX_NEW_TOKENS_DETECTOR = int(os.getenv("MAX_NEW_TOKENS_DETECTOR", "2048"))
MAX_NEW_TOKENS_EXPLAINER = int(os.getenv("MAX_NEW_TOKENS_EXPLAINER", "512"))
USER_PLACEHOLDER = "<<__SPADE_USER_PROMPT__>>"


def _default_kv_cache_dir() -> Path:
    data_dir = Path("/data")
    if data_dir.exists() and os.access(data_dir, os.W_OK):
        return data_dir / "spade_kv_cache"
    return Path("/tmp/spade_kv_cache")


KV_CACHE_DIR = Path(os.getenv("KV_CACHE_DIR", str(_default_kv_cache_dir())))
ENABLE_DISK_KV_CACHE = os.getenv("ENABLE_DISK_KV_CACHE", "1") == "1"
WARMUP_ON_STARTUP = os.getenv("WARMUP_ON_STARTUP", "1") == "1"
KV_CACHE_SCHEMA_VERSION = os.getenv("KV_CACHE_SCHEMA_VERSION", "2")
FORCE_CLEAN_KV_CACHE_ON_STARTUP = os.getenv("FORCE_CLEAN_KV_CACHE_ON_STARTUP", "1") == "1"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("spade")
_PROMPT_TEMPLATE_CACHE: Dict[Tuple[int, str, bool], str] = {}
_PREFIX_KV_CACHE: Dict[Tuple[int, str, bool], Tuple[Any, int, str]] = {}

CUSTOM_CSS = """
:root {
  --bg: #f4f5f7;
  --surface: #ffffff;
  --text: #111827;
  --muted: #6b7280;
  --accent: #245fa8;
  --accent-hover: #1f4f8a;
  --border: #d9dde3;
  --focus: #245fa8;
  --focus-ring: rgba(36, 95, 168, 0.18);
}

html, body, .gradio-container {
  background: var(--bg) !important;
  color: var(--text) !important;
}

.gradio-container {
  font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important;
  font-size: 15px !important;
  line-height: 1.45 !important;
  color-scheme: light !important;
  --body-background-fill: var(--bg) !important;
  --body-background-fill-dark: var(--bg) !important;
  --body-text-color: var(--text) !important;
  --body-text-color-dark: var(--text) !important;
  --body-text-color-subdued: var(--muted) !important;
  --body-text-color-subdued-dark: var(--muted) !important;
  --background-fill-primary: var(--bg) !important;
  --background-fill-primary-dark: var(--bg) !important;
  --background-fill-secondary: var(--surface) !important;
  --background-fill-secondary-dark: var(--surface) !important;
  --block-background-fill: var(--surface) !important;
  --block-background-fill-dark: var(--surface) !important;
  --block-border-color: var(--border) !important;
  --block-border-color-dark: var(--border) !important;
  --block-title-text-color: var(--text) !important;
  --block-title-text-color-dark: var(--text) !important;
  --block-label-text-color: var(--text) !important;
  --block-label-text-color-dark: var(--text) !important;
  --input-background-fill: #ffffff !important;
  --input-background-fill-dark: #ffffff !important;
  --input-border-color: var(--border) !important;
  --input-border-color-dark: var(--border) !important;
  --input-placeholder-color: var(--muted) !important;
  --input-placeholder-color-dark: var(--muted) !important;
  --button-primary-background-fill: var(--accent) !important;
  --button-primary-background-fill-dark: var(--accent) !important;
  --button-primary-background-fill-hover: var(--accent-hover) !important;
  --button-primary-background-fill-hover-dark: var(--accent-hover) !important;
  --button-primary-border-color: var(--accent) !important;
  --button-primary-border-color-dark: var(--accent) !important;
  --button-primary-text-color: #ffffff !important;
  --button-primary-text-color-dark: #ffffff !important;
  --code-background-fill: #fbfbfc !important;
  --code-background-fill-dark: #fbfbfc !important;
}

body.dark .gradio-container,
.dark .gradio-container,
[data-theme="dark"] .gradio-container {
  color-scheme: light !important;
}

.gradio-container h1,
.gradio-container h2,
.gradio-container h3,
.gradio-container label {
  letter-spacing: -0.01em;
}

#app-shell {
  max-width: 1080px;
  margin: 0 auto;
  padding: 2rem 1rem 2.5rem;
}

.hero {
  margin-bottom: 1rem;
  padding: 0;
}

.hero h1 {
  margin: 0 0 0.5rem;
  font-size: 30px;
  font-weight: 600;
  color: var(--text) !important;
}

.hero-subtitle {
  margin: 0;
  color: var(--muted) !important;
  font-size: 15px;
}

.hero-meta {
  display: grid;
  gap: 0.35rem;
  margin-top: 0.85rem;
}

.hero-meta p {
  margin: 0;
  color: var(--muted) !important;
  font-size: 14px;
}

.hero-meta span {
  color: var(--text) !important;
  font-weight: 600;
  margin-right: 0.4rem;
}

.hero-meta code {
  border: 1px solid var(--border);
  border-radius: 8px;
  padding: 0.08rem 0.32rem;
  background: #f7f8fa;
  color: #1f2937;
}

.section-card {
  border: 1px solid var(--border) !important;
  border-radius: 12px !important;
  background: var(--surface) !important;
  box-shadow: 0 1px 2px rgba(17, 24, 39, 0.04) !important;
  padding: 0.9rem !important;
  margin-top: 0.95rem;
}

.input-card,
.examples-card {
  margin-top: 1rem;
}

#run-btn {
  margin-top: 0.6rem;
  border: 1px solid var(--accent) !important;
  background: var(--accent) !important;
  color: #fff !important;
  border-radius: 8px !important;
  font-weight: 600 !important;
  min-height: 40px !important;
}

#run-btn:hover {
  background: var(--accent-hover) !important;
}

#run-btn:focus-visible {
  outline: none !important;
  box-shadow: 0 0 0 3px var(--focus-ring) !important;
}

#outputs-row {
  gap: 1rem;
}

.output-left,
.output-right {
  min-height: 300px;
}

.output-left .cm-editor,
.output-left .cm-scroller,
.output-right .prose,
.output-right .markdown {
  background: #fbfbfc !important;
  border: 1px solid var(--border) !important;
  border-radius: 8px !important;
  color: var(--text) !important;
}

.gradio-container .prose,
.gradio-container .prose *,
.gradio-container .markdown,
.gradio-container .markdown * {
  color: var(--text) !important;
}

.output-left .wrap,
.output-right .wrap {
  min-height: 240px;
}

.output-left label,
.output-right label {
  color: var(--text) !important;
  font-weight: 600 !important;
  font-size: 15px !important;
}

.examples-card .label-wrap span {
  color: var(--text) !important;
  font-weight: 600 !important;
  font-size: 15px !important;
}

.examples-card .dataset-item {
  color: var(--muted) !important;
  border: 1px solid var(--border) !important;
  border-radius: 8px !important;
  background: #fafbfc !important;
}

.gradio-container textarea,
.gradio-container input[type="text"] {
  border: 1px solid var(--border) !important;
  border-radius: 8px !important;
  background: #ffffff !important;
  color: var(--text) !important;
}

.gradio-container textarea:focus,
.gradio-container input[type="text"]:focus {
  border-color: var(--focus) !important;
  box-shadow: 0 0 0 3px var(--focus-ring) !important;
}

.gradio-container .message code,
.gradio-container code {
  font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace !important;
}

@media (max-width: 900px) {
  #app-shell {
    padding: 1.15rem 0.75rem 1.5rem;
  }

  .hero h1 {
    font-size: 24px;
  }

  #outputs-row {
    display: flex;
    flex-direction: column;
  }
}
"""

DETECTOR_SYSTEM_PROMPT = """You are an expert in psychological manipulation and fraud detection.\n\nTASK: Analyze the message through the lens of persuasion techniques to determine if it's a scam.\n\nANALYTICAL FRAMEWORK - Psychological Techniques (PTs):\nThese are common persuasion methods used in both legitimate communication and scams. \nThe key is HOW they're deployed - legitimately or deceptively.\n\nAuthority and Impersonation: Authority and Impersonation | Tend to obey authorities and credible individuals | Person claimed to be calling for Finance America, claiming our home warranty was expired\nPhantom Riches: Phantom Riches | Visceral triggers of desire that override rationality | Your phone Number was randomly selected from the US database and you have won 18,087.71\nFear and Intimidation: Fear and Intimidation | Fear of loss and penalties | You will be arrested!\nLiking: Liking | Preference for saying \u201cyes\u201d to people they like | I am always available to help, and it\u2019s my pleasure to answer any questions you may have\nUrgency and Scarcity: Urgency and Scarcity | Sense of urgency and scarcity assign more value to items | We are currently in urgent need of 100 employees\nPretext and Trust: Pretext and Trust | Tendency to trust credible individuals | This is an urgent message for [MY NAME]. I\u2019m calling regarding a complaint scheduled to be filed out of [Our County Name]\nReciprocity: Reciprocity | Tendency to feel obliged to repay favors from others | We will send you a check to purchase equipment such as new apple laptop and iphone 14 and software\nConsistency: Consistency | Tendency to behave consistently with past behaviors | Starts with small asks (fill a form) and escalate to big asks (invest money)\nSocial Proof: Social Proof | Tendency to refer majority\u2019s behavior to guide own actions | Your resume has been recommended by many online recruitment companies\n\n\nANALYSIS METHOD:\nFor each PT you identify, ask:\n1. Is this technique present? (What specific evidence?)\n2. What is the apparent intent? (Inform, persuade, or deceive?)\n3. Is there verification possible? (Can claims be checked?)\n4. What action is requested? (Reasonable vs suspicious?)\n\nCLASSIFICATION PRINCIPLE:\nA scam typically combines multiple PTs to create a deceptive narrative that:\n- Cannot be verified through official channels\n- Requests irreversible actions (money, credentials)\n- Benefits from victim's emotional response over logical thinking\n\nLegitimate messages may use PTs but:\n- Can be verified independently\n- Follow normal business practices\n- Allow time for consideration\n\nAnalyze the message below. Output JSON with:\n- 'features': {PT_name: evidence_snippet} for all PTs (empty string if absent)\n- 'scam': 1 if deceptive pattern detected, 0 if legitimate use of persuasion\n"""
EXPLAINER_SYSTEM_PROMPT = """You are an expert at explaining scam detection decisions. Given a message with extracted psychological cues (PTs) and a scam classification, generate a concise explanation.

Output format:
- Write 2–3 cue lines: <Cue>: "<≤3-word quote>" → <plain meaning>.
- End with one Summary sentence describing the manipulation mechanism (no advice).

Allowed cues: Authority, Fear, Urgency, Pretext, Consistency, Reciprocity, Liking, Phantom Riches, Social Proof.

Output only the explanation, no extra text."""

@lru_cache(maxsize=1)
def get_models() -> Tuple[Any, Any]:
    has_cuda = torch.cuda.is_available()
    logger.info("Loading models. CUDA available: %s", has_cuda)

    pipeline_kwargs: Dict[str, Any] = {"trust_remote_code": True}
    if has_cuda:
        pipeline_kwargs["device_map"] = "auto"
        pipeline_kwargs["torch_dtype"] = torch.float16
        logger.info("Using CUDA device: %s", torch.cuda.get_device_name(0))
    else:
        # Explicit CPU mode to keep behavior stable on non-GPU Spaces.
        pipeline_kwargs["device"] = -1
        logger.info("Using CPU mode.")

    detector = hf_pipeline(
        "text-generation",
        model=DETECTOR_MODEL_ID,
        **pipeline_kwargs,
    )
    explainer = hf_pipeline(
        "text-generation",
        model=EXPLAINER_MODEL_ID,
        **pipeline_kwargs,
    )
    logger.info("Models loaded.")
    return detector, explainer


def _extract_text(generation_output: Any) -> str:
    if isinstance(generation_output, list) and generation_output:
        first = generation_output[0]
        if isinstance(first, dict):
            generated = first.get("generated_text", "")
            if isinstance(generated, list) and generated:
                last = generated[-1]
                if isinstance(last, dict):
                    return str(last.get("content", "")).strip()
            return str(generated).strip()
    if isinstance(generation_output, str):
        return generation_output.strip()
    return str(generation_output).strip()


def _extract_json_object(text: str) -> Dict[str, Any]:
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match:
        return {}
    candidate = match.group(0)
    try:
        parsed = json.loads(candidate)
        if isinstance(parsed, dict):
            return parsed
    except json.JSONDecodeError:
        pass
    return {}


def _build_prompt(generator: Any, system_prompt: str, user_prompt: str, thinking: bool) -> str:
    tokenizer = generator.tokenizer
    cache_key = (id(tokenizer), system_prompt, thinking)
    template = _PROMPT_TEMPLATE_CACHE.get(cache_key)

    if template is None:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": USER_PLACEHOLDER},
        ]
        template = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=thinking,
        )
        _PROMPT_TEMPLATE_CACHE[cache_key] = template
        logger.info("Cached prompt template for tokenizer id=%s", id(tokenizer))

    return template.replace(USER_PLACEHOLDER, user_prompt, 1)


def _move_inputs_to_generator_device(generator: Any, encoded: Any) -> Any:
    device = getattr(generator, "device", None)
    if device is None:
        return encoded
    # device_map="auto" commonly uses cuda:0 as entry device for inputs.
    if str(device) == "cpu":
        return encoded
    return encoded.to(device)


def _get_generator_device(generator: Any) -> torch.device:
    device = getattr(generator, "device", None)
    if device is None:
        return torch.device("cpu")
    return torch.device(str(device))


def _tensor_tree_map(obj: Any, fn: Any) -> Any:
    if torch.is_tensor(obj):
        return fn(obj)
    if isinstance(obj, tuple):
        return tuple(_tensor_tree_map(item, fn) for item in obj)
    if isinstance(obj, list):
        return [_tensor_tree_map(item, fn) for item in obj]
    if isinstance(obj, dict):
        return {k: _tensor_tree_map(v, fn) for k, v in obj.items()}
    return obj


def _move_past_key_values_to_device(past_key_values: Any, device: torch.device) -> Any:
    return _tensor_tree_map(past_key_values, lambda t: t.to(device))


def _cpu_clone_past_key_values(past_key_values: Any) -> Any:
    return _tensor_tree_map(past_key_values, lambda t: t.detach().to("cpu"))


def _sha256_text(text: str) -> str:
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def _sanitize_key(key: str) -> str:
    return re.sub(r"[^a-zA-Z0-9._-]+", "_", key)


def _get_prompt_parts(generator: Any, system_prompt: str, thinking: bool) -> Tuple[str, str]:
    tokenizer = generator.tokenizer
    cache_key = (id(tokenizer), system_prompt, thinking)
    template = _PROMPT_TEMPLATE_CACHE.get(cache_key)

    if template is None:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": USER_PLACEHOLDER},
        ]
        template = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=thinking,
        )
        _PROMPT_TEMPLATE_CACHE[cache_key] = template
        logger.info("Cached prompt template for tokenizer id=%s", id(tokenizer))

    if USER_PLACEHOLDER not in template:
        return template, ""
    prefix, suffix = template.split(USER_PLACEHOLDER, 1)
    return prefix, suffix


def _disk_cache_paths(generator: Any, system_prompt: str, thinking: bool) -> Tuple[Path, Path]:
    model_name = getattr(generator.model.config, "_name_or_path", "unknown_model")
    tokenizer_name = getattr(generator.tokenizer, "name_or_path", "unknown_tokenizer")
    prompt_hash = _sha256_text(system_prompt)
    thinking_tag = "thinking1" if thinking else "thinking0"
    schema_tag = f"schema{KV_CACHE_SCHEMA_VERSION}"
    base_name = _sanitize_key(
        f"{model_name}__{tokenizer_name}__{schema_tag}__{thinking_tag}__{prompt_hash[:16]}"
    )
    return KV_CACHE_DIR / f"{base_name}.pt", KV_CACHE_DIR / f"{base_name}.meta.json"


def _load_prefix_kv_from_disk(
    generator: Any,
    system_prompt: str,
    prefix_hash: str,
    suffix_hash: str,
    thinking: bool,
) -> Tuple[Any, int] | None:
    if not ENABLE_DISK_KV_CACHE:
        return None

    pt_path, meta_path = _disk_cache_paths(generator, system_prompt, thinking)
    if not pt_path.exists() or not meta_path.exists():
        return None

    try:
        with meta_path.open("r", encoding="utf-8") as f:
            meta = json.load(f)

        expected = {
            "kv_cache_schema_version": KV_CACHE_SCHEMA_VERSION,
            "transformers_version": transformers_version,
            "system_prompt_hash": _sha256_text(system_prompt),
            "prefix_hash": prefix_hash,
            "suffix_hash": suffix_hash,
            "thinking": thinking,
        }
        for key, expected_value in expected.items():
            if meta.get(key) != expected_value:
                logger.info("[DEBUG] Disk KV cache metadata mismatch on %s; rebuilding cache.", key)
                return None

        payload = torch.load(pt_path, map_location="cpu")
        past_key_values = payload.get("past_key_values")
        prefix_len = int(payload.get("prefix_len", 0))
        if past_key_values is None or prefix_len <= 0:
            return None

        runtime_device = _get_generator_device(generator)
        past_key_values = _move_past_key_values_to_device(past_key_values, runtime_device)
        logger.info("[DEBUG] Loaded prefix KV cache from disk: %s", pt_path)
        return past_key_values, prefix_len
    except Exception as exc:
        logger.warning("[DEBUG] Failed to load disk KV cache (%s): %s", pt_path, exc)
        return None


def _save_prefix_kv_to_disk(
    generator: Any,
    system_prompt: str,
    prefix_hash: str,
    suffix_hash: str,
    past_key_values: Any,
    prefix_len: int,
    thinking: bool,
) -> None:
    if not ENABLE_DISK_KV_CACHE:
        return

    pt_path, meta_path = _disk_cache_paths(generator, system_prompt, thinking)
    model_name = getattr(generator.model.config, "_name_or_path", "unknown_model")
    tokenizer_name = getattr(generator.tokenizer, "name_or_path", "unknown_tokenizer")

    try:
        KV_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        cpu_past = _cpu_clone_past_key_values(past_key_values)
        payload = {
            "past_key_values": cpu_past,
            "prefix_len": prefix_len,
        }
        meta = {
            "created_at_unix": int(time.time()),
            "kv_cache_schema_version": KV_CACHE_SCHEMA_VERSION,
            "transformers_version": transformers_version,
            "model_name_or_path": model_name,
            "tokenizer_name_or_path": tokenizer_name,
            "system_prompt_hash": _sha256_text(system_prompt),
            "prefix_hash": prefix_hash,
            "suffix_hash": suffix_hash,
            "thinking": thinking,
        }
        torch.save(payload, pt_path)
        with meta_path.open("w", encoding="utf-8") as f:
            json.dump(meta, f)
        logger.info("[DEBUG] Saved prefix KV cache to disk: %s", pt_path)
    except Exception as exc:
        logger.warning("[DEBUG] Failed to save disk KV cache (%s): %s", pt_path, exc)


def _get_prefix_kv(generator: Any, system_prompt: str, thinking: bool) -> Tuple[Any, int, str]:
    model = generator.model
    tokenizer = generator.tokenizer
    cache_key = (id(model), system_prompt, thinking)
    cached = _PREFIX_KV_CACHE.get(cache_key)
    if cached is not None:
        logger.info(
            "[DEBUG] Prefix KV cache source=memory_hit model_id=%s thinking=%s prefix_tokens=%s",
            id(model),
            thinking,
            cached[1],
        )
        return cached

    prefix, suffix = _get_prompt_parts(generator, system_prompt, thinking)
    prefix_hash = _sha256_text(prefix)
    suffix_hash = _sha256_text(suffix)
    logger.info(
        "[DEBUG] Prefix KV cache source=memory_miss model_id=%s thinking=%s; checking disk",
        id(model),
        thinking,
    )

    disk_cache = _load_prefix_kv_from_disk(
        generator,
        system_prompt,
        prefix_hash,
        suffix_hash,
        thinking,
    )
    if disk_cache is not None:
        past_key_values, prefix_len = disk_cache
        _PREFIX_KV_CACHE[cache_key] = (past_key_values, prefix_len, suffix)
        logger.info(
            "[DEBUG] Prefix KV cache source=disk_hit model_id=%s thinking=%s prefix_tokens=%s",
            id(model),
            thinking,
            prefix_len,
        )
        return _PREFIX_KV_CACHE[cache_key]

    logger.info(
        "[DEBUG] Prefix KV cache source=disk_miss model_id=%s thinking=%s; recomputing",
        id(model),
        thinking,
    )

    encoded_prefix = tokenizer(prefix, return_tensors="pt")
    encoded_prefix = _move_inputs_to_generator_device(generator, encoded_prefix)

    with torch.inference_mode():
        outputs = model(**encoded_prefix, use_cache=True)

    prefix_len = int(encoded_prefix["input_ids"].shape[1])
    past_key_values = outputs.past_key_values
    _PREFIX_KV_CACHE[cache_key] = (past_key_values, prefix_len, suffix)
    _save_prefix_kv_to_disk(
        generator=generator,
        system_prompt=system_prompt,
        prefix_hash=prefix_hash,
        suffix_hash=suffix_hash,
        past_key_values=past_key_values,
        prefix_len=prefix_len,
        thinking=thinking,
    )
    logger.info(
        "[DEBUG] Prefix KV cache source=recompute model_id=%s thinking=%s prefix_tokens=%s",
        id(model),
        thinking,
        prefix_len,
    )
    return _PREFIX_KV_CACHE[cache_key]


def _resolve_eos_ids(generator: Any) -> Any:
    tokenizer = generator.tokenizer
    eos_ids = []

    default_eos = getattr(tokenizer, "eos_token_id", None)
    if default_eos is not None:
        eos_ids.append(default_eos)

    stop_id = tokenizer.convert_tokens_to_ids(STOP_TOKEN)
    unk_id = getattr(tokenizer, "unk_token_id", None)
    if stop_id is not None and stop_id >= 0 and stop_id != unk_id and stop_id not in eos_ids:
        eos_ids.append(stop_id)

    if not eos_ids:
        return None
    if len(eos_ids) == 1:
        return eos_ids[0]
    return eos_ids


def _generate_text_stream(
    generator: Any,
    system_prompt: str,
    user_prompt: str,
    max_new_tokens: int,
    thinking: bool,
    task_name: str = "generation",
    force_full_prompt: bool = False,
) -> Iterator[str]:
    eos_ids = _resolve_eos_ids(generator)
    pad_token_id = getattr(generator.tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = getattr(generator.tokenizer, "eos_token_id", None)

    t0 = time.perf_counter()
    generate_kwargs: Dict[str, Any]
    path_label = "prefix_kv"
    kv_cache_applied = True

    logger.info(
        "[DEBUG] [%s] Generation started (max_new_tokens=%s, thinking=%s)",
        task_name,
        max_new_tokens,
        thinking,
    )

    try:
        if force_full_prompt:
            raise RuntimeError("force_full_prompt enabled")

        past_key_values, prefix_len, suffix = _get_prefix_kv(generator, system_prompt, thinking)
        dynamic_prompt = f"{user_prompt}{suffix}"
        encoded_dynamic = generator.tokenizer(dynamic_prompt, return_tensors="pt")
        encoded_dynamic = _move_inputs_to_generator_device(generator, encoded_dynamic)
        dynamic_len = int(encoded_dynamic["input_ids"].shape[1])
        if dynamic_len <= 0:
            raise RuntimeError("Dynamic prompt tokenized to 0 tokens on prefix-KV path.")

        attention_mask = torch.ones(
            (1, prefix_len + dynamic_len),
            dtype=encoded_dynamic["attention_mask"].dtype,
            device=encoded_dynamic["attention_mask"].device,
        )
        cache_position = torch.arange(
            prefix_len,
            prefix_len + dynamic_len,
            dtype=torch.long,
            device=encoded_dynamic["input_ids"].device,
        )
        logger.info(
            "[DEBUG] [%s] Prefix-KV input lengths: prefix_tokens=%s dynamic_tokens=%s cache_position_len=%s",
            task_name,
            prefix_len,
            dynamic_len,
            int(cache_position.numel()),
        )

        generate_kwargs = {
            "input_ids": encoded_dynamic["input_ids"],
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "cache_position": cache_position,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "temperature": TEMPERATURE,
            "top_k": TOP_K,
            "top_p": TOP_P,
            "use_cache": True,
            "eos_token_id": eos_ids,
            "pad_token_id": pad_token_id,
        }
    except Exception as exc:
        path_label = "full_prompt"
        kv_cache_applied = False
        if not force_full_prompt:
            logger.warning("[DEBUG] KV-cache path failed, falling back to full prompt path: %s", exc)
        prompt = _build_prompt(generator, system_prompt, user_prompt, thinking)
        encoded_prompt = generator.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=NUM_CTX,
        )
        encoded_prompt = _move_inputs_to_generator_device(generator, encoded_prompt)
        generate_kwargs = {
            "input_ids": encoded_prompt["input_ids"],
            "attention_mask": encoded_prompt.get("attention_mask"),
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "temperature": TEMPERATURE,
            "top_k": TOP_K,
            "top_p": TOP_P,
            "use_cache": True,
            "eos_token_id": eos_ids,
            "pad_token_id": pad_token_id,
        }

    streamer = TextIteratorStreamer(
        generator.tokenizer,
        skip_prompt=True,
        skip_special_tokens=False,
        timeout=1.0,
    )
    generate_kwargs["streamer"] = streamer

    generation_error: Dict[str, Exception] = {}

    def _worker() -> None:
        try:
            with torch.inference_mode():
                generator.model.generate(**generate_kwargs)
        except Exception as exc:
            generation_error["error"] = exc

    worker = Thread(target=_worker, daemon=True)
    worker.start()

    text = ""
    first_token_latency_ms: float | None = None
    stop_seen = False
    while True:
        try:
            chunk = next(streamer)
        except StopIteration:
            break
        except Empty:
            if worker.is_alive():
                continue
            break

        if stop_seen:
            continue

        if first_token_latency_ms is None:
            first_token_latency_ms = (time.perf_counter() - t0) * 1000.0

        text += chunk
        if STOP_TOKEN in text:
            text = text.split(STOP_TOKEN, 1)[0]
            stop_seen = True
        yield text.strip()

    worker.join()
    if "error" in generation_error:
        elapsed = time.perf_counter() - t0
        logger.error(
            "[DEBUG] [%s] Generation failed after %.2fs (kv_cache_applied=%s, path=%s, first_token_latency_ms=%s, output_chars=%s): %s",
            task_name,
            elapsed,
            kv_cache_applied,
            path_label,
            f"{first_token_latency_ms:.1f}" if first_token_latency_ms is not None else "none",
            len(text),
            generation_error["error"],
        )
        if kv_cache_applied and not force_full_prompt:
            logger.warning("[DEBUG] [%s] Retrying generation with full_prompt path.", task_name)
            yield from _generate_text_stream(
                generator=generator,
                system_prompt=system_prompt,
814
+ user_prompt=user_prompt,
815
+ max_new_tokens=max_new_tokens,
816
+ thinking=thinking,
817
+ task_name=f"{task_name}:full_prompt_retry",
818
+ force_full_prompt=True,
819
+ )
820
+ return
821
+ raise generation_error["error"]
822
+
823
+ elapsed = time.perf_counter() - t0
824
+ logger.info(
825
+ "[DEBUG] [%s] Generation complete in %.2fs (max_new_tokens=%s, kv_cache_applied=%s, path=%s, first_token_latency_ms=%s, output_chars=%s)",
826
+ task_name,
827
+ elapsed,
828
+ max_new_tokens,
829
+ kv_cache_applied,
830
+ path_label,
831
+ f"{first_token_latency_ms:.1f}" if first_token_latency_ms is not None else "none",
832
+ len(text),
833
+ )
834
+
835
+
836
+ def _generate_text(
837
+ generator: Any,
838
+ system_prompt: str,
839
+ user_prompt: str,
840
+ max_new_tokens: int,
841
+ thinking: bool,
842
+ task_name: str = "generation",
843
+ ) -> str:
844
+ final = ""
845
+ for partial in _generate_text_stream(
846
+ generator,
847
+ system_prompt,
848
+ user_prompt,
849
+ max_new_tokens,
850
+ thinking,
851
+ task_name=task_name,
852
+ ):
853
+ final = partial
854
+ return final
855
+
856
+
857
+ def _build_detector_output(raw: str, empty_input: bool = False) -> Dict[str, Any]:
858
+ if empty_input:
859
+ return {
860
+ "label": "invalid_input",
861
+ "score": 0.0,
862
+ "reasoning": "Input text is empty.",
863
+ "raw_output": "",
864
+ }
865
+
866
+ parsed = _extract_json_object(raw)
867
+ if parsed:
868
+ parsed["raw_output"] = raw
869
+ logger.info("Detector step completed with valid JSON.")
870
+ return parsed
871
+
872
+ logger.info("Detector step completed without valid JSON.")
873
+ return {
874
+ "label": "unknown",
875
+ "score": None,
876
+ "reasoning": "Detector did not return valid JSON.",
877
+ "raw_output": raw,
878
+ }
879
+
880
+
881
+ def run_detector_stream(text: str, task_name: str = "detector") -> Iterator[str]:
882
+ cleaned = text.strip()
883
+ if not cleaned:
884
+ return
885
+
886
+ detector, _ = get_models()
887
+ user_prompt = f"Message: {cleaned}"
888
+ yield from _generate_text_stream(
889
+ detector,
890
+ system_prompt=DETECTOR_SYSTEM_PROMPT,
891
+ user_prompt=user_prompt,
892
+ max_new_tokens=MAX_NEW_TOKENS_DETECTOR,
893
+ thinking=True,
894
+ task_name=task_name,
895
+ )
896
+
897
+
898
+ def run_detector(text: str) -> Dict[str, Any]:
899
+ logger.info("Detector step started.")
900
+ cleaned = text.strip()
901
+ if not cleaned:
902
+ logger.info("Detector step skipped: empty input.")
903
+ return _build_detector_output(raw="", empty_input=True)
904
+
905
+ raw = ""
906
+ for partial in run_detector_stream(cleaned, task_name="detector"):
907
+ raw = partial
908
+ return _build_detector_output(raw=raw, empty_input=False)
909
+
910
+
911
+ def _fallback_explanation(detector_output: Dict[str, Any]) -> str:
912
+ scam = detector_output.get("scam", detector_output.get("label", "unknown"))
913
+ features = detector_output.get("features", {})
914
+ non_empty_cues = []
915
+ if isinstance(features, dict):
916
+ for cue, evidence in features.items():
917
+ if str(evidence).strip():
918
+ non_empty_cues.append((cue, str(evidence).strip()))
919
+
920
+ lines = [f"Summary: detector predicts {scam}."]
921
+ if non_empty_cues:
922
+ for cue, evidence in non_empty_cues[:3]:
923
+ lines.append(f"{cue}: {evidence[:120]}")
924
+ else:
925
+ lines.append("No strong cues were provided by the detector output.")
926
+ return "\n".join(lines)
927
+
928
+
929
+ def run_explainer_stream(
930
+ detector_output: Dict[str, Any],
931
+ simplified_prompt: bool = False,
932
+ task_name: str = "explainer",
933
+ ) -> Iterator[str]:
934
+ _, explainer = get_models()
935
+ user_prompt = (
936
+ json.dumps(detector_output, ensure_ascii=True)
937
+ if simplified_prompt
938
+ else json.dumps(detector_output, ensure_ascii=True, indent=2)
939
+ )
940
+ max_tokens = (
941
+ max(96, min(256, MAX_NEW_TOKENS_EXPLAINER))
942
+ if simplified_prompt
943
+ else MAX_NEW_TOKENS_EXPLAINER
944
+ )
945
+ yield from _generate_text_stream(
946
+ explainer,
947
+ system_prompt=EXPLAINER_SYSTEM_PROMPT,
948
+ user_prompt=user_prompt,
949
+ max_new_tokens=max_tokens,
950
+ thinking=False,
951
+ task_name=task_name,
952
+ )
953
+
954
+
955
+ def run_explainer(text: str, detector_output: Dict[str, Any]) -> str:
956
+ logger.info("Explainer step started.")
957
+ del text # explainer user prompt should be detector output only
958
+
959
+ explanation = ""
960
+ for partial in run_explainer_stream(detector_output, simplified_prompt=False, task_name="explainer"):
961
+ explanation = partial
962
+ if explanation.strip():
963
+ return explanation
964
+
965
+ logger.warning("Explainer returned empty text; retrying with simplified prompt.")
966
+ retry = ""
967
+ for partial in run_explainer_stream(detector_output, simplified_prompt=True, task_name="explainer_retry"):
968
+ retry = partial
969
+ if retry.strip():
970
+ return retry
971
+
972
+ logger.warning("Explainer retry also empty; using deterministic fallback explanation.")
973
+ return _fallback_explanation(detector_output)
974
+
975
+
976
+ def pipeline(text: str) -> Iterator[Tuple[str, str]]:
977
+ req_id = f"req-{int(time.time() * 1000)}"
978
+ started = time.perf_counter()
979
+ logger.info("[%s] Pipeline started.", req_id)
980
+
981
+ detector_render = ""
982
+ explainer_render = ""
983
+ yield detector_render, explainer_render
984
+
985
+ cleaned = text.strip()
986
+ if not cleaned:
987
+ detector_output = _build_detector_output(raw="", empty_input=True)
988
+ detector_render = json.dumps(detector_output, ensure_ascii=True, indent=2)
989
+ explainer_render = _fallback_explanation(detector_output)
990
+ yield detector_render, explainer_render
991
+ elapsed = time.perf_counter() - started
992
+ logger.info("[%s] Pipeline finished in %.2fs", req_id, elapsed)
993
+ return
994
+
995
+ logger.info("[%s] Detector stream started.", req_id)
996
+ detector_raw = ""
997
+ for partial in run_detector_stream(cleaned, task_name=f"{req_id}:detector"):
998
+ detector_raw = partial
999
+ detector_render = detector_raw
1000
+ yield detector_render, explainer_render
1001
+
1002
+ detector_output = _build_detector_output(raw=detector_raw, empty_input=False)
1003
+ detector_render = json.dumps(detector_output, ensure_ascii=True, indent=2)
1004
+ yield detector_render, explainer_render
1005
+
1006
+ logger.info("[%s] Explainer stream started.", req_id)
1007
+ explanation = ""
1008
+ for partial in run_explainer_stream(
1009
+ detector_output,
1010
+ simplified_prompt=False,
1011
+ task_name=f"{req_id}:explainer",
1012
+ ):
1013
+ explanation = partial
1014
+ explainer_render = explanation
1015
+ yield detector_render, explainer_render
1016
+
1017
+ if not explanation.strip():
1018
+ logger.warning("[%s] Explainer empty; retrying with simplified prompt.", req_id)
1019
+ retry = ""
1020
+ for partial in run_explainer_stream(
1021
+ detector_output,
1022
+ simplified_prompt=True,
1023
+ task_name=f"{req_id}:explainer_retry",
1024
+ ):
1025
+ retry = partial
1026
+ explainer_render = retry
1027
+ yield detector_render, explainer_render
1028
+ explanation = retry
1029
+
1030
+ if not explanation.strip():
1031
+ logger.warning("[%s] Explainer still empty; using fallback.", req_id)
1032
+ explanation = _fallback_explanation(detector_output)
1033
+ explainer_render = explanation
1034
+ yield detector_render, explainer_render
1035
+
1036
+ yield detector_render, explainer_render
1037
+ elapsed = time.perf_counter() - started
1038
+ logger.info("[%s] Pipeline finished in %.2fs", req_id, elapsed)
1039
+
1040
+
1041
+ def _force_clean_kv_cache_dir() -> None:
1042
+ if not FORCE_CLEAN_KV_CACHE_ON_STARTUP:
1043
+ return
1044
+ if not KV_CACHE_DIR.exists():
1045
+ logger.info("[DEBUG] KV cache clean skipped: directory not found (%s).", KV_CACHE_DIR)
1046
+ return
1047
+
1048
+ removed = 0
1049
+ for path in KV_CACHE_DIR.glob("*"):
1050
+ if path.suffix not in {".pt", ".json"}:
1051
+ continue
1052
+ try:
1053
+ path.unlink(missing_ok=True)
1054
+ removed += 1
1055
+ except Exception as exc:
1056
+ logger.warning("[DEBUG] Failed to remove KV cache file %s: %s", path, exc)
1057
+ logger.info("[DEBUG] Force-cleaned KV cache files on startup: removed=%s dir=%s", removed, KV_CACHE_DIR)
1058
+
1059
+
1060
+ def warmup_prefix_kv_cache() -> None:
1061
+ if not WARMUP_ON_STARTUP:
1062
+ logger.info("Startup warmup disabled (WARMUP_ON_STARTUP=0).")
1063
+ return
1064
+
1065
+ _force_clean_kv_cache_dir()
1066
+ logger.info("Startup warmup started. KV cache dir: %s", KV_CACHE_DIR)
1067
+ try:
1068
+ detector, explainer = get_models()
1069
+ _get_prefix_kv(detector, DETECTOR_SYSTEM_PROMPT, thinking=True)
1070
+ _get_prefix_kv(explainer, EXPLAINER_SYSTEM_PROMPT, thinking=False)
1071
+ logger.info("Startup warmup completed.")
1072
+ except Exception as exc:
1073
+ # Keep service available even if warmup fails.
1074
+ logger.warning("Startup warmup failed: %s", exc)
1075
+
1076
+
1077
+ with gr.Blocks(title="SPADE Demo API", css=CUSTOM_CSS) as demo:
1078
+ with gr.Column(elem_id="app-shell"):
1079
+ gr.Markdown(
1080
+ f"""
1081
+ <div class="hero">
1082
+ <h1>SPADE Detector + Explainer</h1>
1083
+ <p class="hero-subtitle">
1084
+ A paper demo for psychological scam detection and explanation.
1085
+ The detector output appears at bottom-left and the explainer output at bottom-right.
1086
+ </p>
1087
+ <div class="hero-meta">
1088
+ <p><span>Detector model:</span> <code>{DETECTOR_MODEL_ID}</code></p>
1089
+ <p><span>Explainer model:</span> <code>{EXPLAINER_MODEL_ID}</code></p>
1090
+ <p><span>Runtime note:</span> this Space is currently running on CPU, so inference is slower than GPU.</p>
1091
+ </div>
1092
+ </div>
1093
+ """
1094
+ )
1095
+
1096
+ with gr.Group(elem_classes=["section-card", "input-card"]):
1097
+ input_text = gr.Textbox(
1098
+ label="Input message x",
1099
+ lines=6,
1100
+ placeholder="Paste or type a message to analyze...",
1101
+ )
1102
+ run_btn = gr.Button("Run Pipeline", elem_id="run-btn")
1103
+
1104
+ with gr.Row(elem_id="outputs-row", equal_height=True):
1105
+ with gr.Column(scale=1, min_width=360):
1106
+ with gr.Group(elem_classes=["section-card", "output-left"]):
1107
+ detector_json = gr.Code(label="Detector Output", language="json")
1108
+ with gr.Column(scale=1, min_width=360):
1109
+ with gr.Group(elem_classes=["section-card", "output-right"]):
1110
+ explainer_md = gr.Markdown(label="Explainer Output")
1111
+
1112
+ with gr.Group(elem_classes=["section-card", "examples-card"]):
1113
+ gr.Examples(
1114
+ examples=[
1115
+ "this is Oscar Walden with location services contacting you in reference to a pending claim being issued against your name requesting a signature I do need to make your work phone number QJR19680 is finalized there are no longer being an opportunity to contact the office processing your claim this sort of location requires a signature service to take place at your home worker just due to the Sonia and we're getting this matter I'm providing with the filing parties information one last time the number to contact is 877-595-5588 if the filing party isn't contacted I have no choice but to move forward with your order location you need to be available to provide a signature",
1116
+ ],
1117
+ inputs=input_text,
1118
+ )
1119
+
1120
+ run_btn.click(
1121
+ fn=pipeline,
1122
+ inputs=input_text,
1123
+ outputs=[detector_json, explainer_md],
1124
+ api_name="pipeline",
1125
+ )
1126
+
1127
+
1128
+ if __name__ == "__main__":
1129
+ warmup_prefix_kv_cache()
1130
+ demo.queue().launch()
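
`_build_detector_output` above relies on `_extract_json_object`, which is defined earlier in `app.py` (outside this diff hunk). A minimal sketch of that kind of fallback parsing (scanning the raw model text for the first balanced `{...}` object) could look like the following; the name and exact behavior are illustrative, not the actual helper:

```python
import json


def extract_json_object(raw: str):
    """Return the first parseable {...} object found in raw text, else None.

    Simplified sketch: brace-depth matching without handling braces
    inside JSON string values.
    """
    start = raw.find("{")
    while start != -1:
        depth = 0
        for i in range(start, len(raw)):
            if raw[i] == "{":
                depth += 1
            elif raw[i] == "}":
                depth -= 1
                if depth == 0:
                    try:
                        return json.loads(raw[start : i + 1])
                    except json.JSONDecodeError:
                        break  # not valid JSON; try the next opening brace
        start = raw.find("{", start + 1)
    return None


print(extract_json_object('reasoning text {"label": "scam", "score": 0.92} trailing'))
# -> {'label': 'scam', 'score': 0.92}
```

This is why the detector can stream free-form "thinking" text before the JSON block and still produce a structured result at the end.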
docs/app.js ADDED
@@ -0,0 +1,61 @@
+ import { Client } from "https://cdn.jsdelivr.net/npm/@gradio/client/+esm";
+
+ const spaceInput = document.getElementById("spaceId");
+ const messageInput = document.getElementById("message");
+ const runBtn = document.getElementById("runBtn");
+ const detectorOutput = document.getElementById("detectorOutput");
+ const explainerOutput = document.getElementById("explainerOutput");
+ const statusEl = document.getElementById("status");
+
+ const initialSpaceId = window.SPADE_CONFIG?.spaceId || "";
+ spaceInput.value = initialSpaceId;
+
+ let clientCache = null;
+ let cachedSpaceId = null;
+
+ function setStatus(message, isError = false) {
+   statusEl.textContent = message;
+   statusEl.className = isError ? "status error" : "status";
+ }
+
+ async function getClient(spaceId) {
+   if (clientCache && cachedSpaceId === spaceId) {
+     return clientCache;
+   }
+   clientCache = await Client.connect(spaceId);
+   cachedSpaceId = spaceId;
+   return clientCache;
+ }
+
+ async function runPipeline() {
+   const spaceId = spaceInput.value.trim();
+   const text = messageInput.value.trim();
+
+   if (!spaceId) {
+     setStatus("Please set your Hugging Face Space ID.", true);
+     return;
+   }
+   if (!text) {
+     setStatus("Please enter input text.", true);
+     return;
+   }
+
+   runBtn.disabled = true;
+   setStatus("Running...");
+
+   try {
+     const app = await getClient(spaceId);
+     const result = await app.predict("/pipeline", { text });
+     const [detector, explanation] = result.data;
+
+     // The detector output arrives as a JSON string from gr.Code;
+     // only stringify it when the API returns an object instead.
+     detectorOutput.textContent =
+       typeof detector === "string" ? detector : JSON.stringify(detector, null, 2);
+     explainerOutput.textContent = explanation;
+     setStatus("Done.");
+   } catch (err) {
+     setStatus(`Request failed: ${err.message || String(err)}`, true);
+   } finally {
+     runBtn.disabled = false;
+   }
+ }
+
+ runBtn.addEventListener("click", runPipeline);
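
The same `/pipeline` endpoint this browser client calls can also be exercised from Python. A sketch under stated assumptions: the `SPADE_SPACE_ID` environment variable and the example message are placeholders, the live call only runs when that variable is set, and the `gradio_client` call shape (keyword input plus `api_name`) follows its standard usage:

```python
import json
import os


def unpack_pipeline_result(data):
    """Split the two-element /pipeline payload into (detector_dict, explainer_text)."""
    detector_raw, explainer_md = data
    if isinstance(detector_raw, str):
        try:
            detector = json.loads(detector_raw)
        except json.JSONDecodeError:
            # Detector text was not valid JSON; keep it raw.
            detector = {"raw_output": detector_raw}
    else:
        detector = detector_raw
    return detector, explainer_md


# Optional live call: set SPADE_SPACE_ID (e.g. "your-username/your-space-name").
space_id = os.environ.get("SPADE_SPACE_ID")
if space_id:
    from gradio_client import Client  # pip install gradio_client

    client = Client(space_id)
    result = client.predict(text="Example message", api_name="/pipeline")
    detector, explanation = unpack_pipeline_result(result)
    print(explanation)
```

`unpack_pipeline_result` mirrors what the JS above does with `result.data`: the first element is the detector JSON, the second the explainer markdown.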
docs/config.js ADDED
@@ -0,0 +1,4 @@
+ window.SPADE_CONFIG = {
+   // Example: "eliwang2332/spade-demo"
+   spaceId: "",
+ };
docs/index.html ADDED
@@ -0,0 +1,53 @@
+ <!doctype html>
+ <html lang="en">
+ <head>
+   <meta charset="UTF-8" />
+   <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+   <title>SPADE Demo</title>
+   <link rel="stylesheet" href="./styles.css" />
+ </head>
+ <body>
+   <main class="page">
+     <header class="hero">
+       <p class="kicker">SPADE Demo</p>
+       <h1>Detector + Explainer</h1>
+       <p class="subtitle">Detector output feeds directly into explainer output.</p>
+     </header>
+
+     <section class="card setup-card">
+       <div class="section-head">
+         <h2>Connection</h2>
+         <p class="hint">Set a default in <code>/docs/config.js</code>.</p>
+       </div>
+       <label for="spaceId">Hugging Face Space ID</label>
+       <input id="spaceId" type="text" placeholder="your-username/your-space-name" />
+     </section>
+
+     <section class="card run-card">
+       <div class="section-head">
+         <h2>Input Message</h2>
+       </div>
+       <label for="message">Input message x</label>
+       <textarea id="message" rows="8" placeholder="Type a message..."></textarea>
+       <button id="runBtn" type="button">Run Detector + Explainer</button>
+     </section>
+
+     <section class="outputs-grid">
+       <article class="card output-card">
+         <h2>Detector Output</h2>
+         <pre id="detectorOutput">{}</pre>
+       </article>
+
+       <article class="card output-card">
+         <h2>Explainer Output</h2>
+         <div id="explainerOutput" class="explainer">No output yet.</div>
+       </article>
+     </section>
+
+     <p id="status" class="status"></p>
+   </main>
+
+   <script src="./config.js"></script>
+   <script type="module" src="./app.js"></script>
+ </body>
+ </html>
docs/styles.css ADDED
@@ -0,0 +1,195 @@
+ :root {
+   --bg: #f4f5f7;
+   --surface: #ffffff;
+   --text: #111827;
+   --muted: #6b7280;
+   --accent: #245fa8;
+   --accent-hover: #1f4f8a;
+   --border: #d9dde3;
+   --danger: #b42318;
+   --focus: #245fa8;
+   --focus-ring: rgba(36, 95, 168, 0.18);
+ }
+
+ * {
+   box-sizing: border-box;
+ }
+
+ body {
+   margin: 0;
+   background: var(--bg);
+   color: var(--text);
+   font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
+   font-size: 15px;
+   line-height: 1.45;
+ }
+
+ .page {
+   max-width: 1040px;
+   margin: 0 auto;
+   padding: 32px 16px 40px;
+ }
+
+ .hero {
+   margin-bottom: 12px;
+ }
+
+ .kicker {
+   margin: 0 0 8px;
+   color: var(--muted);
+   font-size: 13px;
+   font-weight: 600;
+   letter-spacing: -0.01em;
+ }
+
+ h1 {
+   margin: 0;
+   font-size: 30px;
+   font-weight: 600;
+   letter-spacing: -0.01em;
+ }
+
+ .subtitle {
+   margin: 10px 0 0;
+   color: var(--muted);
+ }
+
+ .card {
+   background: var(--surface);
+   border: 1px solid var(--border);
+   border-radius: 12px;
+   padding: 16px;
+   margin-top: 16px;
+ }
+
+ .section-head {
+   display: flex;
+   align-items: baseline;
+   justify-content: space-between;
+   gap: 12px;
+   flex-wrap: wrap;
+   margin-bottom: 10px;
+ }
+
+ h2 {
+   margin: 0;
+   font-size: 16px;
+   font-weight: 600;
+   letter-spacing: -0.01em;
+ }
+
+ label {
+   display: block;
+   margin-bottom: 8px;
+   font-size: 14px;
+   font-weight: 600;
+   letter-spacing: -0.01em;
+ }
+
+ input,
+ textarea {
+   width: 100%;
+   border: 1px solid var(--border);
+   border-radius: 8px;
+   padding: 10px 12px;
+   font: inherit;
+   color: var(--text);
+   background: #fff;
+ }
+
+ textarea {
+   resize: vertical;
+   min-height: 140px;
+ }
+
+ input:focus,
+ textarea:focus,
+ button:focus-visible {
+   outline: none;
+   border-color: var(--focus);
+   box-shadow: 0 0 0 3px var(--focus-ring);
+ }
+
+ button {
+   margin-top: 12px;
+   border: 1px solid var(--accent);
+   border-radius: 8px;
+   min-height: 40px;
+   padding: 9px 14px;
+   color: #fff;
+   background: var(--accent);
+   font: inherit;
+   font-weight: 600;
+   cursor: pointer;
+ }
+
+ button:hover {
+   background: var(--accent-hover);
+ }
+
+ button:disabled {
+   opacity: 0.72;
+   cursor: default;
+ }
+
+ .outputs-grid {
+   margin-top: 4px;
+   display: grid;
+   grid-template-columns: 1fr 1fr;
+   gap: 16px;
+ }
+
+ .output-card {
+   min-height: 320px;
+ }
+
+ pre,
+ .explainer {
+   margin: 12px 0 0;
+   padding: 12px;
+   border-radius: 8px;
+   border: 1px solid var(--border);
+   background: #fbfbfc;
+   white-space: pre-wrap;
+   word-break: break-word;
+ }
+
+ pre,
+ .explainer {
+   min-height: 244px;
+ }
+
+ .hint,
+ .status {
+   font-size: 14px;
+   color: var(--muted);
+ }
+
+ .status {
+   margin: 16px 2px 0;
+ }
+
+ .status.error {
+   color: var(--danger);
+ }
+
+ code {
+   font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
+ }
+
+ @media (max-width: 900px) {
+   .page {
+     padding: 20px 12px 28px;
+   }
+
+   h1 {
+     font-size: 24px;
+   }
+
+   .outputs-grid {
+     grid-template-columns: 1fr;
+   }
+ }
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio==6.5.1
+ transformers>=4.45.0
+ torch>=2.2.0
+ accelerate>=0.30.0
+ safetensors>=0.4.3