Jn-Huang committed
Commit fc3b3a2 · 1 Parent(s): 0600d50

Switch to transformers version - vLLM uses too much memory on T4 GPU

Files changed (3):
  1. app.py               +65 -52
  2. app_transformers.py  +0 -111
  3. app_vllm.py          +17 -5
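For context on the commit message, a rough back-of-the-envelope estimate of the memory pressure behind the switch. The parameter count and the 16 GB T4 capacity below are assumptions for illustration, not values taken from the diff:

# Rough memory estimate for Meta-Llama-3.1-8B-Instruct in fp16 on a 16 GB T4.
PARAMS = 8.0e9            # ~8B parameters (assumed)
BYTES_PER_PARAM = 2       # float16
T4_MEMORY_GB = 16         # assumed GPU capacity

weights_gb = PARAMS * BYTES_PER_PARAM / 1024**3
vllm_budget_gb = 0.7 * T4_MEMORY_GB   # budget with gpu_memory_utilization=0.7

print(f"fp16 weights alone: ~{weights_gb:.1f} GB")      # ~14.9 GB
print(f"vLLM memory budget: ~{vllm_budget_gb:.1f} GB")  # ~11.2 GB
# The weights alone already exceed the budget before vLLM pre-allocates its KV
# cache, which is consistent with the OOM this commit works around by loading
# the model with plain transformers instead.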
app.py CHANGED
@@ -1,83 +1,96 @@
-# app_vllm.py - Faster inference using vLLM
+# app.py
 import os
+import torch
 import spaces
 import gradio as gr
-from vllm import LLM, SamplingParams
-from vllm.lora.request import LoRARequest
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
 
 BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 PEFT_MODEL_ID = "befm/Be.FM-8B"
 
-def load_model():
+# Use /data for persistent storage to avoid re-downloading models
+CACHE_DIR = "/data" if os.path.exists("/data") else None
+
+USE_PEFT = True
+try:
+    from peft import PeftModel, PeftConfig  # noqa
+except Exception:
+    USE_PEFT = False
+    print("[WARN] 'peft' not installed; running base model only.")
+
+def load_model_and_tokenizer():
     if HF_TOKEN is None:
         raise RuntimeError(
             "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
             "Also ensure your account has access to the gated base model."
         )
-
-    # Initialize vLLM with PEFT support
-    llm = LLM(
-        model=BASE_MODEL_ID,
-        tokenizer=BASE_MODEL_ID,
-        enable_lora=True,
-        max_lora_rank=64,
-        dtype="float16",
-        gpu_memory_utilization=0.7,  # Reduced from 0.9 to avoid OOM on T4 GPU
-        trust_remote_code=True,
+    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    tok = AutoTokenizer.from_pretrained(
+        BASE_MODEL_ID,
+        token=HF_TOKEN,
+        cache_dir=CACHE_DIR  # Use persistent storage
     )
-
-    print(f"[INFO] vLLM loaded base model: {BASE_MODEL_ID}")
-
-    # Load PEFT adapter
-    lora_request = LoRARequest(
-        lora_name="befm",
-        lora_int_id=1,
-        lora_path=PEFT_MODEL_ID,
+    if tok.pad_token is None:
+        tok.pad_token = tok.eos_token
+
+    base = AutoModelForCausalLM.from_pretrained(
+        BASE_MODEL_ID,
+        device_map="auto" if torch.cuda.is_available() else None,
+        torch_dtype=dtype,
+        token=HF_TOKEN,
+        cache_dir=CACHE_DIR  # Use persistent storage
     )
-    print(f"[INFO] PEFT adapter prepared: {PEFT_MODEL_ID}")
-
-    return llm, lora_request
 
-# Lazy load model and tokenizer
-_llm = None
-_lora_request = None
-_tokenizer = None
-
-def get_model_and_tokenizer():
-    global _llm, _lora_request, _tokenizer
-    if _llm is None:
-        _llm, _lora_request = load_model()
-        _tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
-    return _llm, _lora_request, _tokenizer
+    print(f"[INFO] Using cache directory: {CACHE_DIR}")
+
+    if USE_PEFT:
+        try:
+            _ = PeftConfig.from_pretrained(
+                PEFT_MODEL_ID,
+                token=HF_TOKEN,
+                cache_dir=CACHE_DIR  # Use persistent storage
+            )
+            model = PeftModel.from_pretrained(
+                base,
+                PEFT_MODEL_ID,
+                token=HF_TOKEN,
+                cache_dir=CACHE_DIR  # Use persistent storage
+            )
+            print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
+            return model, tok
+        except Exception as e:
+            print(f"[WARN] Failed to load PEFT adapter: {e}")
+            return base, tok
+    return base, tok
+
+model, tokenizer = load_model_and_tokenizer()
+DEVICE = model.device
 
 @spaces.GPU
+@torch.inference_mode()
 def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9) -> str:
-    llm, lora_request, tokenizer = get_model_and_tokenizer()
-
     # Apply Llama 3.1 chat template
     prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
+    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+    enc = {k: v.to(DEVICE) for k, v in enc.items()}
 
-    sampling_params = SamplingParams(
+    input_length = enc['input_ids'].shape[1]
+    out = model.generate(
+        **enc,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
        temperature=temperature,
        top_p=top_p,
-        max_tokens=max_new_tokens,
+        pad_token_id=tokenizer.eos_token_id,
    )
-
-    # Generate with vLLM
-    outputs = llm.generate(
-        prompts=[prompt],
-        sampling_params=sampling_params,
-        lora_request=lora_request,
-    )
-
-    return outputs[0].outputs[0].text
+    # Decode only the newly generated tokens
+    return tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
 
 def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
     # Build conversation in Llama 3.1 chat format
@@ -112,8 +125,8 @@ demo = gr.ChatInterface(
         gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
         gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p"),
     ],
-    title="Be.FM-8B (vLLM) - Fast Inference",
-    description="Chat interface using vLLM for optimized inference with Meta-Llama-3.1-8B-Instruct and PEFT adapter befm/Be.FM-8B."
+    title="Be.FM-8B (PEFT) on Meta-Llama-3.1-8B-Instruct",
+    description="Chat interface using Meta-Llama-3.1-8B-Instruct with PEFT adapter befm/Be.FM-8B."
 )
 
 if __name__ == "__main__":
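To illustrate the Llama 3.1 message format that the new app.py builds in chat_fn and passes into generate_response, a minimal usage sketch; it assumes app.py is importable (importing it loads the model at module level), and the example prompt is hypothetical:

from app import generate_response  # assumes the Space's app.py is on the path

messages = [
    {"role": "system", "content": "You are Be.FM assistant."},
    {"role": "user", "content": "In one sentence, what does the Be.FM adapter add to the base model?"},
]
reply = generate_response(messages, max_new_tokens=128, temperature=0.7, top_p=0.9)
print(reply)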
app_transformers.py DELETED
@@ -1,111 +0,0 @@
-# app.py
-import os
-import torch
-import spaces
-import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
-
-BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-PEFT_MODEL_ID = "befm/Be.FM-8B"
-
-USE_PEFT = True
-try:
-    from peft import PeftModel, PeftConfig  # noqa
-except Exception:
-    USE_PEFT = False
-    print("[WARN] 'peft' not installed; running base model only.")
-
-def load_model_and_tokenizer():
-    if HF_TOKEN is None:
-        raise RuntimeError(
-            "HF_TOKEN is not set. Add it in Space → Settings → Secrets. "
-            "Also ensure your account has access to the gated base model."
-        )
-    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-    tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
-    if tok.pad_token is None:
-        tok.pad_token = tok.eos_token
-
-    base = AutoModelForCausalLM.from_pretrained(
-        BASE_MODEL_ID,
-        device_map="auto" if torch.cuda.is_available() else None,
-        torch_dtype=dtype,
-        token=HF_TOKEN,
-    )
-
-    if USE_PEFT:
-        try:
-            _ = PeftConfig.from_pretrained(PEFT_MODEL_ID, token=HF_TOKEN)
-            model = PeftModel.from_pretrained(base, PEFT_MODEL_ID, token=HF_TOKEN)
-            print(f"[INFO] Loaded PEFT adapter: {PEFT_MODEL_ID}")
-            return model, tok
-        except Exception as e:
-            print(f"[WARN] Failed to load PEFT adapter: {e}")
-            return base, tok
-    return base, tok
-
-model, tokenizer = load_model_and_tokenizer()
-DEVICE = model.device
-
-@spaces.GPU
-@torch.inference_mode()
-def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9) -> str:
-    # Apply Llama 3.1 chat template
-    prompt = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
-    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
-    enc = {k: v.to(DEVICE) for k, v in enc.items()}
-
-    input_length = enc['input_ids'].shape[1]
-    out = model.generate(
-        **enc,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
-        top_p=top_p,
-        pad_token_id=tokenizer.eos_token_id,
-    )
-    # Decode only the newly generated tokens
-    return tokenizer.decode(out[0][input_length:], skip_special_tokens=True)
-
-def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
-    # Build conversation in Llama 3.1 chat format
-    messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": system_prompt})
-
-    # History is already in dict format: [{"role": "user", "content": "..."}, ...]
-    for msg in (history or []):
-        messages.append(msg)
-
-    if message:
-        messages.append({"role": "user", "content": message})
-
-    reply = generate_response(
-        messages,
-        max_new_tokens=max_new_tokens,
-        temperature=temperature,
-        top_p=top_p,
-    )
-    return reply
-
-demo = gr.ChatInterface(
-    fn=lambda message, history, system_prompt, max_new_tokens, temperature, top_p:
-        chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p),
-    additional_inputs=[
-        gr.Textbox(label="System prompt (optional)", placeholder="You are Be.FM assistant...", lines=2),
-        gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens"),
-        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
-        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p"),
-    ],
-    title="Be.FM-8B (PEFT) on Meta-Llama-3.1-8B-Instruct",
-    description="Chat interface using Meta-Llama-3.1-8B-Instruct with PEFT adapter befm/Be.FM-8B."
-)
-
-if __name__ == "__main__":
-    demo.launch()
app_vllm.py CHANGED
@@ -11,6 +11,9 @@ HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
 BASE_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 PEFT_MODEL_ID = "befm/Be.FM-8B"
 
+# Use /data for persistent storage to avoid re-downloading models
+CACHE_DIR = "/data" if os.path.exists("/data") else None
+
 def load_model():
     if HF_TOKEN is None:
         raise RuntimeError(
@@ -25,11 +28,13 @@ def load_model():
         enable_lora=True,
         max_lora_rank=64,
         dtype="float16",
-        gpu_memory_utilization=0.9,
+        gpu_memory_utilization=0.7,  # Reduced from 0.9 to avoid OOM on T4 GPU
         trust_remote_code=True,
+        download_dir=CACHE_DIR,  # Use persistent storage
     )
 
     print(f"[INFO] vLLM loaded base model: {BASE_MODEL_ID}")
+    print(f"[INFO] Using cache directory: {CACHE_DIR}")
 
     # Load PEFT adapter
     lora_request = LoRARequest(
@@ -50,7 +55,11 @@ def get_model_and_tokenizer():
     global _llm, _lora_request, _tokenizer
     if _llm is None:
         _llm, _lora_request = load_model()
-        _tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN)
+        _tokenizer = AutoTokenizer.from_pretrained(
+            BASE_MODEL_ID,
+            token=HF_TOKEN,
+            cache_dir=CACHE_DIR  # Use persistent storage
+        )
     return _llm, _lora_request, _tokenizer
 
 @spaces.GPU
@@ -82,8 +91,11 @@ def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.9)
 def chat_fn(message, history, system_prompt, max_new_tokens, temperature, top_p):
     # Build conversation in Llama 3.1 chat format
     messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": system_prompt})
+
+    # Add system prompt (use default if not provided)
+    if not system_prompt:
+        system_prompt = "You are Be.FM, a helpful and knowledgeable AI assistant. Provide clear, accurate, and concise responses."
+    messages.append({"role": "system", "content": system_prompt})
 
     # History is already in dict format: [{"role": "user", "content": "..."}, ...]
     for msg in (history or []):
@@ -109,7 +121,7 @@ demo = gr.ChatInterface(
         gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="temperature"),
         gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p"),
     ],
-    title="Be.FM-8B (vLLM) - Fast Inference",
+    title="Be.FM-8B (vLLM)",
     description="Chat interface using vLLM for optimized inference with Meta-Llama-3.1-8B-Instruct and PEFT adapter befm/Be.FM-8B."
 )
 
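The retained vLLM path keeps enable_lora=True with max_lora_rank=64, so the adapter's LoRA rank must stay at or below 64. A hedged sketch for checking that from the adapter config; it assumes befm/Be.FM-8B is a standard LoRA adapter and that peft is installed:

from peft import PeftConfig

cfg = PeftConfig.from_pretrained("befm/Be.FM-8B")  # pass token=HF_TOKEN if the repo is gated
rank = getattr(cfg, "r", None)                     # LoraConfig exposes the LoRA rank as .r
print(f"Adapter rank: {rank}; the vLLM path requires rank <= max_lora_rank=64")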