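"""Gradio chat demo for a GPTQ INT4 quantized Qwen3 1.7B model, with throttled token streaming."""
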
import os
import time
import warnings
import uuid
from html import escape
import re

# (optional) Disable experimental SSR to avoid Node server restarts
os.environ.setdefault("GRADIO_SERVER_SSR", "false")

# ---- Speed-up friendly torch defaults (no-ops on CPU) ----
try:
    import torch

    torch.backends.cuda.matmul.allow_tf32 = True  # TF32 on Ampere+ GPUs
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
except Exception:
    torch = None  # torch may not be available in some CI envs

import gradio as gr
from transformers import AutoTokenizer, TextIteratorStreamer

# ---- Config ----
MODEL_ID = os.getenv("MODEL_ID", "2imi9/qwen3-1.7b-gptq-int4")
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful assistant. Keep answers concise unless asked for detail."
)

# Optional compile (helps for repeated runs; set ENABLE_COMPILE=1)
ENABLE_COMPILE = os.getenv("ENABLE_COMPILE", "0") == "1"
GPU_DURATION = int(os.getenv("GPU_DURATION", "120"))  # for Spaces slices
STREAM_YIELD_INTERVAL_S = float(os.getenv("STREAM_YIELD_INTERVAL_S", "0.03"))
STREAM_YIELD_MIN_TOKENS = int(os.getenv("STREAM_YIELD_MIN_TOKENS", "8"))

_tok = None
_mdl = None

# --- Reasoning/text filters ---
THINK_RE = re.compile(r"<think>.*?</think>", re.IGNORECASE | re.DOTALL)


def strip_reasoning(text: str) -> str:
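    """Remove <think>...</think> reasoning blocks from generated text."""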
    return THINK_RE.sub("", text).strip()


# =========================
# Model loading & inference
# =========================
def _maybe_flash_attn_impl():
    """Pick the best attention backend available for speed."""
    try:
        from transformers.utils import is_flash_attn_2_available  # type: ignore

        return "flash_attention_2" if is_flash_attn_2_available() else "eager"
    except Exception:
        return "eager"


def load_model():
    """Lazy-load tokenizer & model on first call, with fast kernels when possible."""
    global _tok, _mdl
    if _tok is not None and _mdl is not None:
        return _tok, _mdl
    try:
        from transformers import AutoModelForCausalLM

        _tok = AutoTokenizer.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            use_fast=True,
        )
        attn_impl = _maybe_flash_attn_impl()
        load_kwargs = dict(
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
        # Newer Transformers support attn_implementation
        try:
            _mdl = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                attn_implementation=attn_impl,
                **load_kwargs,
            )
        except TypeError:
            _mdl = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
        _mdl.eval()
        # Optional BetterTransformer (if available & compatible)
        try:
            from optimum.bettertransformer import BetterTransformer  # type: ignore

            _mdl = BetterTransformer.transform(_mdl, keep_original_model=False)
        except Exception:
            pass
        # Optional torch.compile for extra speed on supported stacks
        if ENABLE_COMPILE and torch is not None and hasattr(torch, "compile"):
            try:
                _mdl = torch.compile(_mdl, mode="reduce-overhead", fullgraph=True)
            except Exception:
                pass
        # Ensure pad token is set (prevents warnings/slow paths)
        if _tok.pad_token_id is None and _tok.eos_token_id is not None:
            _tok.pad_token = _tok.eos_token
    except Exception as e:
        warnings.warn(f"[MODEL LOAD ERROR] {e}")
        return None, None
    return _tok, _mdl


# --- ZeroGPU scheduling: wrap the CUDA work (no-op locally) ---
try:
    import spaces

    def _gpu_wrap(fn):
        return spaces.GPU(duration=GPU_DURATION)(fn)
except Exception:  # running outside Spaces

    def _gpu_wrap(fn):
        return fn
@_gpu_wrap  # assumption: the ZeroGPU wrapper is intended for this CUDA-running helper
def _infer(messages, max_new_tokens, temperature, top_p):
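    """Non-streaming generation helper: returns the decoded completion as plain text."""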
    tok, mdl = load_model()
    if tok is None or mdl is None:
        return "Model unavailable. Check Hub access/token or hardware."
    if torch is not None:
        ctx = torch.inference_mode()
    else:
        # Fallback context manager
        class _noop:
            def __enter__(self):
                return None

            def __exit__(self, *exc):
                return False

        ctx = _noop()
    ids = tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(mdl.device)
    with ctx:
        out = mdl.generate(
            ids,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=True,
            use_cache=True,
            pad_token_id=tok.pad_token_id or tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
    return tok.decode(out[0][ids.shape[-1] :], skip_special_tokens=True).strip()


# =========================
# Chat helpers (messages-style)
# =========================
EXAMPLES = [
    "What are 8 good questions to get to know a stranger?",
    "Create a list of 10 unusual excuses people might use to get out of a work meeting",
    "Write a python function to reverse a string",
    "Explain special relativity in French",
    "¿Cómo le explicarías el aprendizaje automático a un extraterrestre?",
    "Summarize recent news about the North American tech job market",
    "Explain gravity to a chicken.",
]

CUSTOM_CSS = """
#app-title h1 { margin: 0 0 6px 0; }
#app-title .subtle { opacity: .8; font-size: 14px; }
#params-card { border: 1px solid var(--border-color-primary); }
.gradio-container { max-width: 1200px !important; }
#footer { text-align:center; font-size: 12px; opacity: .8; }
#chatbot { height: 640px !important; }
"""

def add_user_message(msg: str, history_msgs: list):
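    """Validate the message, append it to the history, and clear the input box."""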
    if not msg:
        raise gr.Error("Please enter a message.")
    history_msgs = history_msgs or []
    history_msgs.append({"role": "user", "content": msg})
    return "", history_msgs


def stream_assistant_with_thinking(
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    history_msgs: list,
    show_thinking: bool = True,
    hide_reasoning: bool = True,
):
    """
    Streams the model response with an optional thinking placeholder and optional
    filtering of <think>...</think> content from the generated text.

    Yields tuples of (display_messages_without_system, full_history_with_system)
    so both the Chatbot and the hidden state are updated together.
    """
    history_msgs = history_msgs or []
    if not history_msgs:
        yield [], []
        return
    tok, mdl = load_model()
    if tok is None or mdl is None:
        history_msgs.append({"role": "assistant", "content": "Model unavailable."})
        yield [m for m in history_msgs if m["role"] != "system"], history_msgs
        return
    messages_for_model = [
        {"role": "system", "content": system_prompt or DEFAULT_SYSTEM_PROMPT}
    ] + history_msgs
    # Optional thinking bubble
    if show_thinking:
        history_msgs.append({"role": "assistant", "content": "🧠 Thinking..."})
        yield [m for m in history_msgs if m["role"] != "system"], history_msgs
    inputs = tok.apply_chat_template(
        messages_for_model,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(mdl.device)
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        inputs=inputs,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
        use_cache=True,
        pad_token_id=tok.pad_token_id or tok.eos_token_id,
        eos_token_id=tok.eos_token_id,
        streamer=streamer,
    )

    def _gen():
        if torch is not None:
            with torch.inference_mode():
                mdl.generate(**gen_kwargs)
        else:
            mdl.generate(**gen_kwargs)

    import threading

    t = threading.Thread(target=_gen, daemon=True)
    t.start()
    # Second assistant message (final answer), stream into it
    history_msgs.append({"role": "assistant", "content": ""})
    partial_raw = ""
    last_yield = time.monotonic()
    tokens_since_yield = 0
    for token in streamer:
        partial_raw += token
        tokens_since_yield += 1
        # Mark thinking as done the moment we get tokens (if enabled)
        if show_thinking and history_msgs[-2]["content"].startswith("🧠 Thinking..."):
            history_msgs[-2]["content"] = "🧠 Thoughts (done)"
        # Optionally remove <think> blocks on the fly
        rendered = strip_reasoning(partial_raw) if hide_reasoning else partial_raw
        # Throttle UI updates to reduce overhead while keeping it snappy
        now = time.monotonic()
        if (
            tokens_since_yield >= STREAM_YIELD_MIN_TOKENS
            or (now - last_yield) >= STREAM_YIELD_INTERVAL_S
        ):
            history_msgs[-1]["content"] = escape(rendered)
            yield [m for m in history_msgs if m["role"] != "system"], history_msgs
            last_yield = now
            tokens_since_yield = 0
    # Final flush
    final_text = strip_reasoning(partial_raw) if hide_reasoning else partial_raw
    history_msgs[-1]["content"] = escape(final_text.strip())
    yield [m for m in history_msgs if m["role"] != "system"], history_msgs
    t.join()
def clear_chat():
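    """Reset the Chatbot display, the history state, and the system prompt to defaults."""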
    return [], [], DEFAULT_SYSTEM_PROMPT


# =========================
# UI
# =========================
with gr.Blocks(
    title="Qwen3 1.7B — GPTQ INT4 (Fast)", css=CUSTOM_CSS, theme=gr.themes.Soft()
) as app:
    # ----- Header -----
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(
                """
                <div id='app-title'>
                  <h1>Qwen3 1.7B — GPTQ INT4</h1>
                  <div class='subtle'>Faster chat demo with streaming + optimized kernels</div>
                </div>
                """,
                elem_id="app-title",
            )
        with gr.Column(scale=1):
            model_id = gr.Textbox(value=MODEL_ID, label="Model ID", interactive=False)
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label=None,
                show_copy_button=True,
                show_share_button=False,
                elem_id="chatbot",
                height=640,
                type="messages",
                render_markdown=False,
                autoscroll=True,
            )
            with gr.Row():
                user_tb = gr.Textbox(
                    placeholder="Ask anything…",
                    show_label=False,
                    lines=2,
                    scale=5,
                    autofocus=True,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Row():
                clear_btn = gr.Button("Clear chat", variant="secondary")
            gr.Examples(examples=EXAMPLES, inputs=user_tb, label="Try one", examples_per_page=12)
        with gr.Column(scale=1):
            with gr.Group(elem_id="params-card"):
                with gr.Accordion("Parameters", open=True):
                    sys_prompt = gr.Textbox(value=DEFAULT_SYSTEM_PROMPT, lines=3, label="System prompt")
                    max_new = gr.Slider(16, 1024, value=256, step=8, label="Max new tokens")
                    temperature = gr.Slider(0.05, 1.5, value=0.7, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
                    show_thinking = gr.Checkbox(value=True, label="Show thinking bubble")
                    hide_reasoning = gr.Checkbox(value=True, label="Hide <think> content")
            with gr.Group():
                gr.Markdown(
                    """
                    **About this demo**
                    • Quantized Qwen3 for quick experiments.
                    • FlashAttention 2 / BetterTransformer / torch.compile when available.
                    • Throttled streaming for smoother UI and lower CPU overhead.
                    """
                )

    # ----- State -----
    history_state = gr.State([])  # list of {role, content}

    # ----- Events -----
    def _submit(msg, history, system_prompt, _max_new, _temperature, _top_p=0.95, _show_thinking=True, _hide_reasoning=True):
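        """Append the user message, clear the textbox, then stream the assistant reply."""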
        # clear text box immediately
        cleared, history = add_user_message(msg, history)
        # first yield clears the textbox; Chatbot updates will follow from the stream
        yield cleared, [m for m in history if m["role"] != "system"], history
        # then stream assistant with thinking
        for display, hist in stream_assistant_with_thinking(
            system_prompt, _max_new, _temperature, _top_p, history, _show_thinking, _hide_reasoning
        ):
            yield "", display, hist

    user_tb.submit(
        _submit,
        [user_tb, history_state, sys_prompt, max_new, temperature, top_p, show_thinking, hide_reasoning],
        [user_tb, chatbot, history_state],
    )
    send_btn.click(
        _submit,
        [user_tb, history_state, sys_prompt, max_new, temperature, top_p, show_thinking, hide_reasoning],
        [user_tb, chatbot, history_state],
    )
    clear_btn.click(clear_chat, None, [chatbot, history_state, sys_prompt])

    gr.Markdown("<div id='footer'>Built with 🤗 Gradio · Qwen3 1.7B GPTQ INT4 (Fast)</div>", elem_id="footer")

# Expose for Spaces
demo = app

if __name__ == "__main__":
    try:
        import inspect

        # Build queue kwargs that your installed Gradio actually supports
        queue_sig = inspect.signature(demo.queue)
        qkw = {}
        if "api_open" in queue_sig.parameters:
            qkw["api_open"] = False
        if "max_size" in queue_sig.parameters:
            qkw["max_size"] = 20
        if "concurrency_count" in queue_sig.parameters:
            qkw["concurrency_count"] = 2
        q = demo.queue(**qkw)
        # Build launch kwargs safely across Gradio versions
        launch_sig = inspect.signature(q.launch)
        lkw = {"show_api": False}
        if "ssr_mode" in launch_sig.parameters:
            lkw["ssr_mode"] = False
        q.launch(**lkw)
    except Exception as e:
        print(f"Error: {e}")