# app_vllm.py - Faster inference using vLLM
import os

import spaces
import gradio as gr
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from transformers import AutoTokenizer

HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
PEFT_MODEL_ID = "befm/Be.FM-8B"

# Use /data for persistent storage to avoid re-downloading models
CACHE_DIR = "/data" if os.path.exists("/data") else None


def load_model():
    if HF_TOKEN is None:
        raise RuntimeError(
            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
            "Also ensure your account has access to the gated base model."
        )

    # Initialize vLLM with LoRA (PEFT adapter) support
    llm = LLM(
        model=BASE_MODEL_ID,
        tokenizer=BASE_MODEL_ID,
        enable_lora=True,
        max_lora_rank=64,
        dtype="float16",
        gpu_memory_utilization=0.7,  # Reduced from 0.9 to avoid OOM on a T4 GPU
        trust_remote_code=True,
        download_dir=CACHE_DIR,  # Use persistent storage
    )
    print(f"[INFO] vLLM loaded base model: {BASE_MODEL_ID}")
    print(f"[INFO] Using cache directory: {CACHE_DIR}")

    # Prepare the PEFT adapter as a vLLM LoRA request
    lora_request = LoRARequest(
        lora_name="befm",
        lora_int_id=1,
        lora_path=PEFT_MODEL_ID,
    )
    print(f"[INFO] PEFT adapter prepared: {PEFT_MODEL_ID}")
    return llm, lora_request


# Lazily loaded model and tokenizer
_llm = None
_lora_request = None
_tokenizer = None


def get_model_and_tokenizer():
    global _llm, _lora_request, _tokenizer
    if _llm is None:
        _llm, _lora_request = load_model()
        _tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL_ID,
            token=HF_TOKEN,
            cache_dir=CACHE_DIR,  # Use persistent storage
        )
    return _llm, _lora_request, _tokenizer


@spaces.GPU
def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9) -> str:
    llm, lora_request, tokenizer = get_model_and_tokenizer()

    # Apply the Llama 3.1 chat template
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_new_tokens,
    )

    # Generate with vLLM, applying the LoRA adapter
    outputs = llm.generate(
        prompts=[prompt],
        sampling_params=sampling_params,
        lora_request=lora_request,
    )
    return outputs[0].outputs[0].text


def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
    # Build the conversation in Llama 3.1 chat format
    messages = []

    # Add the system prompt (use a default if none is provided)
    if not system_prompt:
        system_prompt = (
            "You are Be.FM, a helpful and knowledgeable AI assistant. "
            "Provide clear, accurate, and concise responses."
        )
    messages.append({"role": "system", "content": system_prompt})

    # History is already in dict format: [{"role": "user", "content": "..."}, ...]
    for msg in (history or []):
        messages.append(msg)

    if message:
        messages.append({"role": "user", "content": message})

    reply = generate_response(
        messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return reply


demo = gr.ChatInterface(
    fn=chat_fn,
    type="messages",  # history arrives as a list of {"role", "content"} dicts, as chat_fn expects
    additional_inputs=[
        gr.Textbox(label="System prompt (optional)", placeholder="You are Be.FM assistant...", lines=2),
        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p"),
    ],
    title="Be.FM-8B (vLLM)",
    description="Chat interface using vLLM for optimized inference with Meta-Llama-3.1-8B-Instruct and PEFT adapter befm/Be.FM-8B.",
)

if __name__ == "__main__":
    demo.launch()
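
# Usage note (a minimal sketch, assuming a CUDA-capable machine or Space with access
# to the gated base model; the exact hardware and Gradio port are not specified here):
#   export HF_TOKEN=<your token>
#   python app_vllm.py
# demo.launch() then prints a local URL for the Gradio chat UI.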