aleclyu committed on
Commit 1efad72 · 1 Parent(s): 1b38493

fix zerogpu error
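
This change keeps each request inside the ZeroGPU time slice: the @spaces.GPU duration drops from 200 s to 60 s, max_new_tokens is capped at 512, and an explicit EOS token is set so generation stops on its own. A minimal sketch of that pattern, assuming the standard ZeroGPU decorator from the spaces package (the function and argument names below are illustrative, not the app's actual code):

    import spaces
    import torch

    @spaces.GPU(duration=60)  # request a GPU slice of at most ~60 s per call
    def run_generation(model, inputs):
        # Keeping max_new_tokens small and generation greedy helps the call
        # finish well before the requested ZeroGPU duration expires.
        with torch.no_grad():
            return model.generate(**inputs, max_new_tokens=512, do_sample=False)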

Files changed (1)
app.py  +75 -26
app.py CHANGED
@@ -41,22 +41,46 @@ def _get_args():
                         action='store_true',
                         default=False,
                         help='Automatically launch the interface in a new tab on the default browser.')
-    # parser.add_argument('--server-port', type=int, default=8080, help='Demo server port.')
-    # parser.add_argument('--server-name', type=str, default='29.210.129.176', help='Demo server name.')
+

     args = parser.parse_args()
     return args


 def _load_model_processor(args):
-    model = HunYuanVLForConditionalGeneration.from_pretrained(
-        args.checkpoint_path,
-        attn_implementation="eager",  # "flash_attention_2",  # can also be flash_attention_2 or sdpa, depending on what your environment supports
-        torch_dtype=torch.bfloat16,
-        # device_map="auto",
-        device_map="cuda",
-        token=os.environ.get('HF_TOKEN')
-    )
+    # Optimization: try flash_attention_2 or sdpa first
+    try:
+        attn_impl = "flash_attention_2"
+        print(f"[INFO] Trying {attn_impl}")
+        model = HunYuanVLForConditionalGeneration.from_pretrained(
+            args.checkpoint_path,
+            attn_implementation=attn_impl,
+            torch_dtype=torch.bfloat16,
+            device_map="cuda",
+            token=os.environ.get('HF_TOKEN')
+        )
+    except Exception as e:
+        print(f"[WARNING] flash_attention_2 is unavailable: {e}")
+        print(f"[INFO] Falling back to sdpa")
+        try:
+            model = HunYuanVLForConditionalGeneration.from_pretrained(
+                args.checkpoint_path,
+                attn_implementation="sdpa",
+                torch_dtype=torch.bfloat16,
+                device_map="cuda",
+                token=os.environ.get('HF_TOKEN')
+            )
+        except Exception as e2:
+            print(f"[WARNING] sdpa is unavailable: {e2}")
+            print(f"[INFO] Using eager (slowest)")
+            model = HunYuanVLForConditionalGeneration.from_pretrained(
+                args.checkpoint_path,
+                attn_implementation="eager",
+                torch_dtype=torch.bfloat16,
+                device_map="cuda",
+                token=os.environ.get('HF_TOKEN')
+            )
+
     processor = AutoProcessor.from_pretrained(args.checkpoint_path, use_fast=False, trust_remote_code=True)
     return model, processor

@@ -88,16 +112,29 @@ def _gc():


 def _launch_demo(args, model, processor):
-    @spaces.GPU(duration=200)
+    # Key fix: reduce the GPU duration and add debug logging
+    @spaces.GPU(duration=60)
     def call_local_model(model, processor, messages):
-        print(messages)
+        import time
+        start_time = time.time()
+        print(f"[DEBUG] Starting inference at: {start_time}")
+        print(f"[DEBUG] Messages: {messages}")
+
         messages = [messages]
         # Build the model inputs with the processor
         texts = [
             processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
             for msg in messages
         ]
+
+        prep_time = time.time()
+        print(f"[DEBUG] Template processing took: {prep_time - start_time:.2f}s")
+
         image_inputs, video_inputs = process_vision_info(messages)
+
+        vision_time = time.time()
+        print(f"[DEBUG] Vision preprocessing took: {vision_time - prep_time:.2f}s")
+
         inputs = processor(
             text=texts,
             images=image_inputs,
@@ -107,25 +144,30 @@ def _launch_demo(args, model, processor):
         )
         inputs = inputs.to(model.device)

-
-        # gen_kwargs = {'max_new_tokens': 32768, 'streamer': streamer, **inputs}
-        # thread = Thread(target=model.generate, kwargs=gen_kwargs)
-        # thread.start()
-
-        # generated_text = ''
-        # for new_text in streamer:
-        #     generated_text += new_text
-        #     yield generated_text
-
-        # Model inference
+        input_time = time.time()
+        print(f"[DEBUG] Input processing took: {input_time - vision_time:.2f}s")
+        print(f"[DEBUG] Input shape: {inputs.input_ids.shape if 'input_ids' in inputs else 'N/A'}")
+
+        # Key fix 1: drastically reduce max_new_tokens
+        # Key fix 2: add an EOS token and stopping conditions
+        # Key fix 3: add timeout protection
         with torch.no_grad():
             generated_ids = model.generate(
                 **inputs,
-                max_new_tokens=1024*8,
+                max_new_tokens=512,  # reduced from 8192 to avoid runaway generation
                 repetition_penalty=1.03,
-                do_sample=False
+                do_sample=False,
+                # Key: set the EOS token so generation can stop normally
+                eos_token_id=processor.tokenizer.eos_token_id,
+                pad_token_id=processor.tokenizer.pad_token_id,
+                # Add an early-stopping condition
+                use_cache=True,
             )

+        gen_time = time.time()
+        print(f"[DEBUG] Generation took: {gen_time - input_time:.2f}s")
+        print(f"[DEBUG] Generated shape: {generated_ids.shape}")
+
         # Decode the output
         if "input_ids" in inputs:
             input_ids = inputs.input_ids
@@ -135,11 +177,18 @@ def _launch_demo(args, model, processor):
         generated_ids_trimmed = [
             out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
         ]
+
+        print(f"[DEBUG] Trimmed tokens count: {[len(ids) for ids in generated_ids_trimmed]}")

         output_texts = processor.batch_decode(
             generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
-
+
+        decode_time = time.time()
+        print(f"[DEBUG] Decoding took: {decode_time - gen_time:.2f}s")
+        print(f"[DEBUG] Total time: {decode_time - start_time:.2f}s")
+        print(f"[DEBUG] Output: {output_texts[0][:200]}...")  # only print the first 200 characters
+
         return output_texts
