import os
import time
import warnings
from html import escape
import re
# (optional) Disable experimental SSR to avoid Node server restarts
os.environ.setdefault("GRADIO_SSR_MODE", "False")
# ---- Speed-up friendly torch defaults (no-ops on CPU) ----
try:
    import torch

    torch.backends.cuda.matmul.allow_tf32 = True  # TF32 on Ampere+ GPUs
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
except Exception:
    torch = None  # torch may not be available in some CI envs
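# TF32 and "high" matmul precision trade a little float32 accuracy for faster
# matrix math on Ampere-or-newer GPUs; older GPUs and CPU-only installs simply
# ignore these flags.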
import gradio as gr
from transformers import AutoTokenizer, TextIteratorStreamer
# ---- Config ----
MODEL_ID = os.getenv("MODEL_ID", "2imi9/qwen3-1.7b-gptq-int4")
DEFAULT_SYSTEM_PROMPT = (
"You are a helpful assistant. Keep answers concise unless asked for detail."
)
# Optional compile (helps for repeated runs; set ENABLE_COMPILE=1)
ENABLE_COMPILE = os.getenv("ENABLE_COMPILE", "0") == "1"
GPU_DURATION = int(os.getenv("GPU_DURATION", "120"))  # ZeroGPU time slice per call, in seconds (Spaces)
STREAM_YIELD_INTERVAL_S = float(os.getenv("STREAM_YIELD_INTERVAL_S", "0.03"))
STREAM_YIELD_MIN_TOKENS = int(os.getenv("STREAM_YIELD_MIN_TOKENS", "8"))
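# Streaming throttle: the chat window is re-rendered once at least
# STREAM_YIELD_MIN_TOKENS tokens have accumulated or STREAM_YIELD_INTERVAL_S
# seconds have passed since the last update, whichever happens first.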
_tok = None
_mdl = None
# --- Reasoning/text filters ---
THINK_RE = re.compile(r"<think>.*?</think>", re.IGNORECASE | re.DOTALL)
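# re.DOTALL lets the pattern cross newlines inside a <think> block;
# re.IGNORECASE also catches <THINK>/<Think> variants.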
def strip_reasoning(text: str) -> str:
    """Drop <think>...</think> blocks from generated text."""
    return THINK_RE.sub("", text).strip()
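# Example (illustrative): strip_reasoning("<think>internal plan</think>Final answer")
# returns "Final answer"; text without <think> tags passes through unchanged.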
# =========================
# Model loading & inference
# =========================
def _maybe_flash_attn_impl():
    """Pick the best attention backend available for speed."""
    try:
        from transformers.utils import is_flash_attn_2_available  # type: ignore

        return "flash_attention_2" if is_flash_attn_2_available() else "eager"
    except Exception:
        return "eager"
def load_model():
    """Lazy-load tokenizer & model on first call, with fast kernels when possible."""
    global _tok, _mdl
    if _tok is not None and _mdl is not None:
        return _tok, _mdl
    try:
        from transformers import AutoModelForCausalLM

        _tok = AutoTokenizer.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            use_fast=True,
        )
        attn_impl = _maybe_flash_attn_impl()
        load_kwargs = dict(
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
        # Newer Transformers support attn_implementation
        try:
            _mdl = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                attn_implementation=attn_impl,
                **load_kwargs,
            )
        except TypeError:
            _mdl = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
        _mdl.eval()
        # Optional BetterTransformer (if available & compatible)
        try:
            from optimum.bettertransformer import BetterTransformer  # type: ignore

            _mdl = BetterTransformer.transform(_mdl, keep_original_model=False)
        except Exception:
            pass
        # Optional torch.compile for extra speed on supported stacks
        if ENABLE_COMPILE and torch is not None and hasattr(torch, "compile"):
            try:
                _mdl = torch.compile(_mdl, mode="reduce-overhead", fullgraph=True)
            except Exception:
                pass
        # Ensure pad token is set (prevents warnings/slow paths)
        if _tok.pad_token_id is None and _tok.eos_token_id is not None:
            _tok.pad_token = _tok.eos_token
    except Exception as e:
        warnings.warn(f"[MODEL LOAD ERROR] {e}")
        return None, None
    return _tok, _mdl
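# Note: loading a GPTQ checkpoint through transformers normally also needs the
# `optimum` package plus a GPTQ kernel backend (e.g. auto-gptq or gptqmodel)
# installed; if those are missing, the warning above fires and the UI shows
# "Model unavailable".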
# --- ZeroGPU scheduling: wrap the CUDA work (no-op locally) ---
try:
    import spaces

    def _gpu_wrap(fn):
        return spaces.GPU(duration=GPU_DURATION)(fn)
except Exception:  # running outside Spaces
    def _gpu_wrap(fn):
        return fn
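# On Spaces, spaces.GPU reserves a ZeroGPU slot for up to GPU_DURATION seconds
# per decorated call; outside Spaces the fallback wrapper is a plain identity.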
@_gpu_wrap
def _infer(messages, max_new_tokens, temperature, top_p):
    tok, mdl = load_model()
    if tok is None or mdl is None:
        return "Model unavailable. Check Hub access/token or hardware."
    if torch is not None:
        ctx = torch.inference_mode()
    else:
        # Fallback context manager
        class _noop:
            def __enter__(self):
                return None

            def __exit__(self, *exc):
                return False

        ctx = _noop()
    ids = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(mdl.device)
    with ctx:
        out = mdl.generate(
            ids,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=True,
            use_cache=True,
            pad_token_id=tok.pad_token_id or tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
    return tok.decode(out[0][ids.shape[-1]:], skip_special_tokens=True).strip()
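# Quick sanity check (hypothetical call; messages use the chat-template format):
#   _infer([{"role": "user", "content": "Say hi"}],
#          max_new_tokens=32, temperature=0.7, top_p=0.95)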
# =========================
# Chat helpers (messages-style)
# =========================
EXAMPLES = [
    "What are 8 good questions to get to know a stranger?",
    "Create a list of 10 unusual excuses people might use to get out of a work meeting",
    "Write a python function to reverse a string",
    "Explain special relativity in French",
    "¿Cómo le explicarías el aprendizaje automático a un extraterrestre?",
    "Summarize recent news about the North American tech job market",
    "Explain gravity to a chicken.",
]
CUSTOM_CSS = """
#app-title h1 { margin: 0 0 6px 0; }
#app-title .subtle { opacity: .8; font-size: 14px; }
#params-card { border: 1px solid var(--border-color-primary); }
.gradio-container { max-width: 1200px !important; }
#footer { text-align:center; font-size: 12px; opacity: .8; }
#chatbot { height: 640px !important; }
"""
def add_user_message(msg: str, history_msgs: list):
    if not msg:
        raise gr.Error("Please enter a message.")
    history_msgs = history_msgs or []
    history_msgs.append({"role": "user", "content": msg})
    return "", history_msgs
def stream_assistant_with_thinking(
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    history_msgs: list,
    show_thinking: bool = True,
    hide_reasoning: bool = True,
):
    """
    Streams the model response with an optional thinking placeholder and optional
    filtering of <think>...</think> content from the generated text.

    Yields tuples of (display_messages_without_system, full_history_with_system)
    so both the Chatbot and the hidden state are updated together.
    """
    history_msgs = history_msgs or []
    if not history_msgs:
        yield [], []
        return
    tok, mdl = load_model()
    if tok is None or mdl is None:
        history_msgs.append({"role": "assistant", "content": "Model unavailable."})
        yield [m for m in history_msgs if m["role"] != "system"], history_msgs
        return
    messages_for_model = [
        {"role": "system", "content": system_prompt or DEFAULT_SYSTEM_PROMPT}
    ] + history_msgs
    # Optional thinking bubble
    if show_thinking:
        history_msgs.append({"role": "assistant", "content": "🧠 Thinking..."})
        yield [m for m in history_msgs if m["role"] != "system"], history_msgs
    inputs = tok.apply_chat_template(
        messages_for_model,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(mdl.device)
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        inputs=inputs,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
        use_cache=True,
        pad_token_id=tok.pad_token_id or tok.eos_token_id,
        eos_token_id=tok.eos_token_id,
        streamer=streamer,
    )

    def _gen():
        if torch is not None:
            with torch.inference_mode():
                mdl.generate(**gen_kwargs)
        else:
            mdl.generate(**gen_kwargs)

    import threading

    t = threading.Thread(target=_gen, daemon=True)
    t.start()
    # Second assistant message (final answer); stream into it
    history_msgs.append({"role": "assistant", "content": ""})
    partial_raw = ""
    last_yield = time.monotonic()
    tokens_since_yield = 0
    for token in streamer:
        partial_raw += token
        tokens_since_yield += 1
        # Mark thinking as done the moment we get tokens (if enabled)
        if show_thinking and history_msgs[-2]["content"].startswith("🧠 Thinking..."):
            history_msgs[-2]["content"] = "🧠 Thoughts (done)"
        # Optionally remove <think> blocks on the fly
        rendered = strip_reasoning(partial_raw) if hide_reasoning else partial_raw
        # Throttle UI updates to reduce overhead while keeping it snappy
        now = time.monotonic()
        if (
            tokens_since_yield >= STREAM_YIELD_MIN_TOKENS
            or (now - last_yield) >= STREAM_YIELD_INTERVAL_S
        ):
            history_msgs[-1]["content"] = escape(rendered)
            yield [m for m in history_msgs if m["role"] != "system"], history_msgs
            last_yield = now
            tokens_since_yield = 0
    # Final flush
    final_text = strip_reasoning(partial_raw) if hide_reasoning else partial_raw
    history_msgs[-1]["content"] = escape(final_text.strip())
    yield [m for m in history_msgs if m["role"] != "system"], history_msgs
    # Ensure the background generation thread has finished before returning
    t.join()
def clear_chat():
    return [], [], DEFAULT_SYSTEM_PROMPT
# =========================
# UI
# =========================
with gr.Blocks(
    title="Qwen3 1.7B — GPTQ INT4 (Fast)", css=CUSTOM_CSS, theme=gr.themes.Soft()
) as app:
    # ----- Header -----
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(
                """
<div id='app-title'>
  <h1>Qwen3 1.7B — GPTQ INT4</h1>
  <div class='subtle'>Faster chat demo with streaming + optimized kernels</div>
</div>
""",
                elem_id="app-title",
            )
        with gr.Column(scale=1):
            model_id = gr.Textbox(value=MODEL_ID, label="Model ID", interactive=False)
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label=None,
                show_copy_button=True,
                show_share_button=False,
                elem_id="chatbot",
                height=640,
                type="messages",
                render_markdown=False,
                autoscroll=True,
            )
            with gr.Row():
                user_tb = gr.Textbox(
                    placeholder="Ask anything…",
                    show_label=False,
                    lines=2,
                    scale=5,
                    autofocus=True,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Row():
                clear_btn = gr.Button("Clear chat", variant="secondary")
            gr.Examples(examples=EXAMPLES, inputs=user_tb, label="Try one", examples_per_page=12)
        with gr.Column(scale=1):
            with gr.Group(elem_id="params-card"):
                with gr.Accordion("Parameters", open=True):
                    sys_prompt = gr.Textbox(value=DEFAULT_SYSTEM_PROMPT, lines=3, label="System prompt")
                    max_new = gr.Slider(16, 1024, value=256, step=8, label="Max new tokens")
                    temperature = gr.Slider(0.05, 1.5, value=0.7, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
                    show_thinking = gr.Checkbox(value=True, label="Show thinking bubble")
                    hide_reasoning = gr.Checkbox(value=True, label="Hide <think> content")
            with gr.Group():
                gr.Markdown(
                    """
**About this demo**

- Quantized Qwen3 for quick experiments.
- FlashAttention 2 / BetterTransformer / torch.compile when available.
- Throttled streaming for smoother UI and lower CPU overhead.
"""
                )

    # ----- State -----
    history_state = gr.State([])  # list of {role, content}

    # ----- Events -----
    def _submit(msg, history, system_prompt, _max_new, _temperature, _top_p, _show_thinking, _hide_reasoning):
        # Clear the textbox immediately
        cleared, history = add_user_message(msg, history)
        # First yield clears the textbox; Chatbot updates follow from the stream
        yield cleared, [m for m in history if m["role"] != "system"], history
        # Then stream the assistant reply (with optional thinking bubble)
        for display, hist in stream_assistant_with_thinking(
            system_prompt, _max_new, _temperature, _top_p, history, _show_thinking, _hide_reasoning
        ):
            yield "", display, hist

    user_tb.submit(
        _submit,
        [user_tb, history_state, sys_prompt, max_new, temperature, top_p, show_thinking, hide_reasoning],
        [user_tb, chatbot, history_state],
    )
    send_btn.click(
        _submit,
        [user_tb, history_state, sys_prompt, max_new, temperature, top_p, show_thinking, hide_reasoning],
        [user_tb, chatbot, history_state],
    )
    clear_btn.click(clear_chat, None, [chatbot, history_state, sys_prompt])

    gr.Markdown("<div id='footer'>Built with 🤗 Gradio · Qwen3 1.7B GPTQ INT4 (Fast)</div>", elem_id="footer")
# Expose for Spaces
demo = app
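# Hugging Face Spaces' Gradio runtime conventionally looks for a Blocks object
# named `demo` in app.py; the alias above keeps that contract.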
if __name__ == "__main__":
    try:
        import inspect

        # Build queue kwargs that your installed Gradio actually supports
        queue_sig = inspect.signature(demo.queue)
        qkw = {}
        if "api_open" in queue_sig.parameters:
            qkw["api_open"] = False
        if "max_size" in queue_sig.parameters:
            qkw["max_size"] = 20
        if "concurrency_count" in queue_sig.parameters:
            qkw["concurrency_count"] = 2
        q = demo.queue(**qkw)
        # Build launch kwargs safely across Gradio versions
        launch_sig = inspect.signature(q.launch)
        lkw = {"show_api": False}
        if "ssr_mode" in launch_sig.parameters:
            lkw["ssr_mode"] = False
        q.launch(**lkw)
    except Exception as e:
        print(f"Error: {e}")
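# Local quick-start (illustrative; the exact package set is an assumption):
#   pip install gradio transformers torch optimum auto-gptq
#   python app.py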