fix zerogpu error
app.py CHANGED
@@ -90,13 +90,21 @@ def _gc():
 
 
 def _launch_demo(args, model, processor):
-    # …
-    @spaces.GPU(duration=…)
-    def call_local_model(…):
+    # Key fix: drop the model and processor parameters; access them via closure
+    @spaces.GPU(duration=60)
+    def call_local_model(messages):
         import time
         start_time = time.time()
         print(f"[DEBUG] ========== Starting inference ==========")
-        …
+
+        # Key: check that the model is actually on the GPU
+        model_device = next(model.parameters()).device
+        print(f"[DEBUG] Model device: {model_device}")
+
+        if str(model_device) == 'cpu':
+            print(f"[ERROR] Model is on CPU! Trying to move it to the GPU...")
+            model.cuda()
+            print(f"[DEBUG] Model device after cuda(): {next(model.parameters()).device}")
 
         messages = [messages]
 
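For context, the shape of this fix: on a ZeroGPU Space, the `@spaces.GPU`-decorated function should receive only plain data, with heavyweight objects like the model and processor captured from the enclosing scope. A minimal, self-contained sketch of that pattern (the checkpoint name and decoding flow are illustrative, not this app's actual code):

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint, not the app's model
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

@spaces.GPU(duration=60)  # a GPU is attached only while this call runs
def call_local_model(messages):
    # Only plain data crosses the boundary; model/tokenizer come from the closure.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=256, do_sample=False)
    # Strip the prompt tokens and return a list, matching the call site below.
    return tokenizer.batch_decode(
        output_ids[:, input_ids.shape[-1]:], skip_special_tokens=True
    )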
@@ -117,26 +125,22 @@ def _launch_demo(args, model, processor):
             padding=True,
             return_tensors="pt",
         )
-        …
+        # Make sure the inputs are on the GPU
+        inputs = inputs.to('cuda' if torch.cuda.is_available() else 'cpu')
         print(f"[DEBUG] Input preparation done, took {time.time() - start_time:.2f}s")
         print(f"[DEBUG] Input IDs shape: {inputs.input_ids.shape}")
-        print(f"[DEBUG] …
+        print(f"[DEBUG] Input device: {inputs.input_ids.device}")
 
-        # …
+        # Generate
         gen_start = time.time()
         with torch.no_grad():
             generated_ids = model.generate(
                 **inputs,
-                max_new_tokens=…,
+                max_new_tokens=256,
                 repetition_penalty=1.03,
                 do_sample=False,
                 eos_token_id=processor.tokenizer.eos_token_id,
                 pad_token_id=processor.tokenizer.pad_token_id,
-                use_cache=True,
-                # Key: add a length penalty to encourage shorter outputs
-                length_penalty=0.8,
-                # Add early stopping
-                early_stopping=True,
             )
 
         gen_time = time.time() - gen_start
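Worth noting on the removed generation arguments: in transformers, `length_penalty` and `early_stopping` only influence beam search, so with `do_sample=False` and the default `num_beams=1` (greedy decoding) they are ignored, which is presumably why this commit drops them. A hedged sketch of what it would take for them to have an effect, with `model` and `inputs` assumed from the surrounding code:

with torch.no_grad():
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=256,
        num_beams=4,          # beam search is what length_penalty/early_stopping act on
        length_penalty=0.8,   # values < 1.0 favor shorter beams
        early_stopping=True,  # stop once enough finished beam candidates exist
    )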
@@ -201,8 +205,8 @@ def _launch_demo(args, model, processor):
                 content = []
         messages.pop()
 
-        # …
-        response_list = call_local_model(…)
+        # Call the model for a response (changed: model and processor are no longer passed)
+        response_list = call_local_model(messages)
         response = response_list[0] if response_list else ""
 
         _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))
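A hypothetical call against the sketch above, mirroring the new call site (the messages payload follows the usual chat format and is illustrative only):

messages = [
    {"role": "user", "content": "Describe the picture in one sentence."},
]
response_list = call_local_model(messages)
response = response_list[0] if response_list else ""
print(response)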