aleclyu committed
Commit d87f42b · 1 Parent(s): acfce9f

fix zerogpu error

Files changed (1)
  1. app.py +19 -15
app.py CHANGED
@@ -90,13 +90,21 @@ def _gc():


 def _launch_demo(args, model, processor):
-    # Key: cut duration to 30 seconds; if it still times out, something else is wrong
-    @spaces.GPU(duration=30)
-    def call_local_model(model, processor, messages):
+    # Key fix: remove the model/processor parameters and reach them through the closure
+    @spaces.GPU(duration=60)
+    def call_local_model(messages):
         import time
         start_time = time.time()
         print(f"[DEBUG] ========== inference start ==========")
-        print(f"[DEBUG] time: {time.strftime('%Y-%m-%d %H:%M:%S')}")
+
+        # Key: check that the model is actually on the GPU
+        model_device = next(model.parameters()).device
+        print(f"[DEBUG] Model device: {model_device}")
+
+        if str(model_device) == 'cpu':
+            print(f"[ERROR] Model is on CPU! Trying to move it to the GPU...")
+            model.cuda()
+            print(f"[DEBUG] Model device after cuda(): {next(model.parameters()).device}")

         messages = [messages]

@@ -117,26 +125,22 @@ def _launch_demo(args, model, processor):
             padding=True,
             return_tensors="pt",
         )
-        inputs = inputs.to(model.device)
+        # Make sure the inputs are on the GPU
+        inputs = inputs.to('cuda' if torch.cuda.is_available() else 'cpu')
         print(f"[DEBUG] inputs ready, elapsed: {time.time() - start_time:.2f}s")
         print(f"[DEBUG] Input IDs shape: {inputs.input_ids.shape}")
-        print(f"[DEBUG] Device: {model.device}")
+        print(f"[DEBUG] Input device: {inputs.input_ids.device}")

-        # Key optimization: aggressively trimmed generation parameters
+        # Generate
         gen_start = time.time()
         with torch.no_grad():
             generated_ids = model.generate(
                 **inputs,
-                max_new_tokens=512,  # down from 8192 to 512, to avoid runaway generation
+                max_new_tokens=256,
                 repetition_penalty=1.03,
                 do_sample=False,
                 eos_token_id=processor.tokenizer.eos_token_id,
                 pad_token_id=processor.tokenizer.pad_token_id,
-                use_cache=True,
-                # Key: add a length penalty to encourage short output
-                length_penalty=0.8,
-                # Add early stopping
-                early_stopping=True,
             )

         gen_time = time.time() - gen_start
@@ -201,8 +205,8 @@ def _launch_demo(args, model, processor):
         content = []
         messages.pop()

-        # Call the model to get the response
-        response_list = call_local_model(model, processor, messages)
+        # Call the model to get the response (changed: model and processor are no longer passed in)
+        response_list = call_local_model(messages)
         response = response_list[0] if response_list else ""

         _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))
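
What actually changed: on a ZeroGPU Space, a function decorated with @spaces.GPU runs on a GPU worker, and (as this commit's own comment suggests) its arguments cross that call boundary, so the model and processor are now captured from the enclosing _launch_demo scope instead of being passed as parameters. A minimal sketch of that closure pattern, assuming the spaces and transformers packages; the model id and function name below are illustrative, not this app's own:

    import spaces
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL_ID = "gpt2"  # hypothetical model id, for illustration only

    # Load once at startup; ZeroGPU attaches a GPU only while a
    # @spaces.GPU-decorated function is running.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

    @spaces.GPU(duration=60)  # same duration the commit settles on
    def generate(prompt: str) -> str:
        # model/tokenizer come from the closure, not from the argument list
        model.to("cuda")
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=256, do_sample=False)
        return tokenizer.decode(out[0], skip_special_tokens=True)

Under that reading, the device checks the commit adds inside call_local_model are diagnostics; the substantive fix is that nothing heavyweight crosses the decorated function's argument list.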