import os
import torch
import numpy as np
from PIL import Image
import spaces
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info # 请确保该模块在你的环境可用
from transformers import HunYuanVLForConditionalGeneration
import gradio as gr
from argparse import ArgumentParser
import copy
import requests
from io import BytesIO
import tempfile
import hashlib
import gc
def _get_args():
parser = ArgumentParser()
parser.add_argument('-c',
'--checkpoint-path',
type=str,
default='tencent/HunyuanOCR',
help='Checkpoint name or path, default to %(default)r')
parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
parser.add_argument('--flash-attn2',
action='store_true',
default=False,
help='Enable flash_attention_2 when loading the model.')
parser.add_argument('--share',
action='store_true',
default=False,
help='Create a publicly shareable link for the interface.')
parser.add_argument('--inbrowser',
action='store_true',
default=False,
help='Automatically launch the interface in a new tab on the default browser.')
args = parser.parse_args()
return args
def _load_model_processor(args):
    """Load the HunyuanOCR model and its processor from ``args.checkpoint_path``.

    Defaults target a ZeroGPU environment: eager attention with
    ``device_map="auto"``, because the weights are loaded before a GPU is
    attached (inference later runs inside an ``@spaces.GPU`` function).
    The ``--flash-attn2`` and ``--cpu-only`` options defined by ``_get_args``
    override those defaults when given; with neither flag the load call is
    identical to the ZeroGPU default.

    Args:
        args: parsed CLI namespace; reads ``checkpoint_path`` and, when
            present, ``flash_attn2`` and ``cpu_only``.

    Returns:
        tuple: ``(model, processor)``.
    """
    hf_token = os.environ.get('HF_TOKEN')
    use_flash = getattr(args, 'flash_attn2', False)
    cpu_only = getattr(args, 'cpu_only', False)
    # flash_attention_2 needs the model on a CUDA device at load time, so it
    # is only used when explicitly requested (not suitable for ZeroGPU).
    attn_implementation = "flash_attention_2" if use_flash else "eager"
    device_map = "cpu" if cpu_only else "auto"
    if use_flash or cpu_only:
        print(f"[INFO] 加载模型(attn_implementation={attn_implementation}, device_map={device_map})")
    else:
        print(f"[INFO] 加载模型(ZeroGPU 环境使用 eager 模式)")
    model = HunYuanVLForConditionalGeneration.from_pretrained(
        args.checkpoint_path,
        attn_implementation=attn_implementation,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        token=hf_token,
    )
    # Pass the same token so a gated/private checkpoint also works for the processor.
    processor = AutoProcessor.from_pretrained(
        args.checkpoint_path,
        use_fast=False,
        trust_remote_code=True,
        token=hf_token,
    )
    print(f"[INFO] 模型加载完成")
    return model, processor
def _parse_text(text):
"""解析文本,处理特殊格式"""
# if text is None:
# return text
text = text.replace("|', '', text)
# return text
return text
def _gc():
"""垃圾回收"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _launch_demo(args, model, processor):
# 关键修复:移除 model 和 processor 参数,使用闭包访问
@spaces.GPU(duration=60)
def call_local_model(messages):
    """Run one greedy generation pass for a single conversation.

    Uses ``model``, ``processor`` and ``args`` from the enclosing
    ``_launch_demo`` scope.

    Args:
        messages: one conversation, i.e. a list of chat-message dicts
            (``{'role': ..., 'content': [...]}``) as built by ``predict``.

    Returns:
        list[str]: decoded completion(s); a single element here because the
        conversation is wrapped into a batch of size 1.
    """
    import time
    start_time = time.time()
    print(f"[DEBUG] ========== 开始推理 ==========")

    model_device = next(model.parameters()).device
    print(f"[DEBUG] Model device: {model_device}")
    # Only move the model when CUDA actually exists and --cpu-only was not
    # requested: an unconditional model.cuda() raises on CPU-only hosts.
    wants_gpu = torch.cuda.is_available() and not getattr(args, 'cpu_only', False)
    if str(model_device) == 'cpu' and wants_gpu:
        print(f"[ERROR] 模型在 CPU 上!尝试移动到 GPU...")
        model.cuda()
        model_device = next(model.parameters()).device
        print(f"[DEBUG] Model device after cuda(): {model_device}")

    batch = [messages]  # the processor APIs operate on a batch of conversations
    texts = [
        processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        for conversation in batch
    ]
    print(f"[DEBUG] 模板构建完成,耗时: {time.time() - start_time:.2f}s")
    image_inputs, video_inputs = process_vision_info(batch)
    print(f"[DEBUG] 图像处理完成,耗时: {time.time() - start_time:.2f}s")
    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Put the inputs on the device holding the model's (first) parameters so
    # generate() never mixes CPU and CUDA tensors.
    inputs = inputs.to(model_device)
    print(f"[DEBUG] 输入准备完成,耗时: {time.time() - start_time:.2f}s")
    print(f"[DEBUG] Input IDs shape: {inputs.input_ids.shape}")
    print(f"[DEBUG] Input device: {inputs.input_ids.device}")

    gen_start = time.time()
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            repetition_penalty=1.03,
            do_sample=False,
            eos_token_id=processor.tokenizer.eos_token_id,
            pad_token_id=processor.tokenizer.pad_token_id,
        )
    gen_time = time.time() - gen_start
    print(f"[DEBUG] ========== 生成完成 ==========")
    print(f"[DEBUG] 生成耗时: {gen_time:.2f}s")
    print(f"[DEBUG] Output shape: {generated_ids.shape}")

    # generate() returns prompt + completion; keep only the new tokens.
    input_ids = inputs.input_ids
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
    ]
    actual_tokens = len(generated_ids_trimmed[0])
    print(f"[DEBUG] 实际生成 token 数: {actual_tokens}")
    print(f"[DEBUG] 每 token 耗时: {gen_time/actual_tokens if actual_tokens > 0 else 0:.3f}s")
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    total_time = time.time() - start_time
    print(f"[DEBUG] ========== 全部完成 ==========")
    print(f"[DEBUG] 总耗时: {total_time:.2f}s")
    print(f"[DEBUG] 输出长度: {len(output_texts[0])} 字符")
    print(f"[DEBUG] 输出预览: {output_texts[0][:100]}...")
    return output_texts
def create_predict_fn():
    """Build the Gradio callback that answers the latest chat turn.

    Returns:
        A generator function ``predict(_chatbot, task_history)``.
        ``_chatbot`` holds display pairs; ``task_history`` holds raw
        ``(query, answer)`` pairs where a query is either a text string or a
        1-tuple ``(file_path,)`` for an uploaded image (as built by
        ``add_text`` / ``add_file``). Both lists are updated in place and the
        chatbot list is yielded once.
    """

    def _build_messages(history):
        """Convert ``(query, answer)`` pairs into chat-template messages.

        Image entries are accumulated and attached to the next text query,
        so one user message carries its images plus its question; the
        ``answer`` slot of image entries is ignored.
        """
        messages = []
        pending = []  # parts of the user turn currently being assembled
        for query_item, answer in history:
            if isinstance(query_item, (tuple, list)):
                img_path = query_item[0]
                # Remote URLs are passed through; local uploads become absolute paths.
                if not img_path.startswith(('http://', 'https://')):
                    img_path = os.path.abspath(img_path)
                pending.append({'type': 'image', 'image': img_path})
                continue
            pending.append({'type': 'text', 'text': query_item})
            messages.append({'role': 'user', 'content': pending})
            messages.append({'role': 'assistant', 'content': [{'type': 'text', 'text': answer}]})
            pending = []
        if pending:
            # Trailing image(s) without text form the current user turn.
            messages.append({'role': 'user', 'content': pending})
        elif messages:
            # Drop the placeholder assistant turn for the answer being generated.
            messages.pop()
        return messages

    def predict(_chatbot, task_history):
        # predict is a generator: a `return value` here would be discarded,
        # so every exit path yields the chat state before returning.
        if not _chatbot or not task_history:
            yield _chatbot
            return
        chat_query = _chatbot[-1][0]
        query = task_history[-1][0]
        if not chat_query:
            _chatbot.pop()
            task_history.pop()
            yield _chatbot
            return
        print('User: ', query)
        messages = _build_messages(copy.deepcopy(task_history))
        response_list = call_local_model(messages)
        response = response_list[0] if response_list else ""
        # Image turns are displayed as their (path,) tuple; only text is parsed.
        if isinstance(chat_query, (tuple, list)):
            display_query = chat_query
        else:
            display_query = _parse_text(chat_query)
        # NOTE(review): _remove_image_special is not defined in the visible
        # part of this file — confirm it exists elsewhere.
        _chatbot[-1] = (display_query, _remove_image_special(_parse_text(response)))
        full_response = _parse_text(response)
        task_history[-1] = (query, full_response)
        print('HunyuanOCR: ' + _parse_text(full_response))
        yield _chatbot

    return predict
def create_regenerate_fn():
    """Build the Gradio callback that re-answers the most recent query.

    Returns:
        A generator function ``regenerate(_chatbot, task_history)`` that
        clears the last answer in both histories and then streams the
        chatbot states produced by ``predict`` (looked up from the enclosing
        scope when called).
    """

    def regenerate(_chatbot, task_history):
        # regenerate is a generator: `return _chatbot` would silently drop
        # the value, so early exits yield the unchanged chat first.
        if not task_history or not _chatbot:
            yield _chatbot
            return
        last_query, last_answer = task_history[-1]
        if last_answer is None:
            yield _chatbot
            return
        task_history[-1] = (last_query, None)
        chatbot_item = _chatbot.pop(-1)
        if chatbot_item[0] is None:
            if _chatbot:
                _chatbot[-1] = (_chatbot[-1][0], None)
        else:
            _chatbot.append((chatbot_item[0], None))
        yield from predict(_chatbot, task_history)

    return regenerate
# Instantiate both Gradio callbacks; `regenerate` resolves `predict` only
# when it is invoked, so creation order does not matter.
predict, regenerate = create_predict_fn(), create_regenerate_fn()
def add_text(history, task_history, text):
    """Append a user text turn to the display and raw histories.

    Args:
        history: display pairs (``None`` is treated as empty).
        task_history: raw ``(query, answer)`` pairs (``None`` is treated as empty).
        text: the submitted textbox value.

    Returns:
        ``(history, task_history, '')``: new lists (the inputs are not
        mutated) and an empty string that clears the textbox.
    """
    display = [] if history is None else history
    raw = [] if task_history is None else task_history
    display = display + [(_parse_text(text), None)]
    raw = raw + [(text, None)]
    return display, raw, ''
def add_file(history, task_history, file):
    """Append an uploaded file as a ``((path,), None)`` turn to both histories.

    Returns:
        ``(history, task_history)`` as new lists; the inputs are not mutated
        and ``None`` inputs are treated as empty.
    """
    entry = ((file.name,), None)
    display = [] if history is None else history
    raw = [] if task_history is None else task_history
    return display + [entry], raw + [entry]
def download_url_image(url):
    """Download an image URL into a cached file in the system temp directory.

    The cache file is named after the MD5 of the URL, so repeated calls for
    the same URL reuse the earlier download.

    Args:
        url: http(s) URL of the image.

    Returns:
        The local cache path on success, or ``url`` itself if anything fails
        (deliberate best-effort fallback; the error is printed).
    """
    partial_path = None
    try:
        url_hash = hashlib.md5(url.encode()).hexdigest()
        temp_dir = tempfile.gettempdir()
        temp_path = os.path.join(temp_dir, f"hyocr_demo_{url_hash}.png")
        if os.path.exists(temp_path):
            return temp_path
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        # Write to a unique sibling file, then atomically rename it. A plain
        # open(temp_path, 'wb') interrupted mid-write would leave a truncated
        # file that the exists() check above treats as a valid cache hit.
        fd, partial_path = tempfile.mkstemp(dir=temp_dir, prefix="hyocr_demo_", suffix=".part")
        with os.fdopen(fd, 'wb') as f:
            f.write(response.content)
        os.replace(partial_path, temp_path)
        partial_path = None
        return temp_path
    except Exception as e:
        print(f"下载图片失败: {url}, 错误: {e}")
        return url  # 失败时返回原 URL
    finally:
        if partial_path is not None:
            try:
                os.remove(partial_path)
            except OSError:
                # Cleanup is best-effort; the download error was already reported.
                pass
def reset_user_input():
    """Return a Gradio update that empties the input textbox."""
    cleared_value = ''
    return gr.update(value=cleared_value)
def reset_state(_chatbot, task_history):
    """Empty both histories in place, free memory, and return an empty chat list."""
    for history_list in (task_history, _chatbot):
        history_list.clear()
    _gc()
    return []
# Example images shown in the demo, keyed by task name. These are pre-signed
# COS links (see the q-sign-* query parameters); the q-sign-time range
# presumably bounds how long each signature stays valid — TODO confirm, and
# consider switching to bundled local files (e.g. "examples/spotting.jpg",
# "examples/parsing.jpg", "examples/ie.jpg", "examples/vqa.jpg",
# "examples/translation.jpg") so the examples do not expire.
EXAMPLE_IMAGES = dict(
    spotting="https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/23cc43af9376b948f3febaf4ce854a8a.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763523817%3B1794627877&q-key-time=1763523817%3B1794627877&q-header-list=host&q-url-param-list=&q-signature=8ebd6a9d3ed7eba73bb783c337349db9c29972e2",
    parsing="https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/c4997ebd1be9f7c3e002fabba8b46cb7.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763523818%3B1794627878&q-key-time=1763523818%3B1794627878&q-header-list=host&q-url-param-list=&q-signature=d2cd12be4c7902821c8c82203e4642624046911a",
    ie="https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/7c67c0f78e4423d51644a325da1f8e85.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763523818%3B1794627878&q-key-time=1763523818%3B1794627878&q-header-list=host&q-url-param-list=&q-signature=803648f3253706f654faf1423869fd9e00e7056e",
    vqa="https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/fea0865d1c70c53aaa2ab91cd0e787f5.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763523818%3B1794627878&q-key-time=1763523818%3B1794627878&q-header-list=host&q-url-param-list=&q-signature=a92b94e298a11aea130d730d3b16ee761acc3f4c",
    translation="https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/d1af99d35e9db9e820ebebb5bc68993a.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763967603%3B1795071663&q-key-time=1763967603%3B1795071663&q-header-list=host&q-url-param-list=&q-signature=a57080c0b3d4c76ea74b88c6291f9004241c9d49",
)
with gr.Blocks(css="""
body {
background: #f5f7fa;
}
.gradio-container {
max-width: 100% !important;
padding: 0 40px !important;
}
.header-section {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 30px 0;
margin: -20px -40px 30px -40px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.header-content {
max-width: 1600px;
margin: 0 auto;
padding: 0 40px;
display: flex;
align-items: center;
gap: 20px;
}
.header-logo {
height: 60px;
}
.header-text h1 {
color: white;
font-size: 32px;
font-weight: bold;
margin: 0 0 5px 0;
}
.header-text p {
color: rgba(255,255,255,0.9);
margin: 0;
font-size: 14px;
}
.main-container {
max-width: 1800px;
margin: 0 auto;
}
.chatbot {
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08) !important;
border-radius: 12px !important;
border: 1px solid #e5e7eb !important;
background: white !important;
}
.input-panel {
background: white;
padding: 20px;
border-radius: 12px;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
border: 1px solid #e5e7eb;
}
.input-box textarea {
border: 2px solid #e5e7eb !important;
border-radius: 8px !important;
font-size: 14px !important;
}
.input-box textarea:focus {
border-color: #667eea !important;
}
.btn-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
color: white !important;
font-weight: 500 !important;
padding: 10px 24px !important;
font-size: 14px !important;
}
.btn-primary:hover {
transform: translateY(-1px) !important;
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
}
.btn-secondary {
background: white !important;
border: 2px solid #667eea !important;
color: #667eea !important;
padding: 8px 20px !important;
font-size: 14px !important;
}
.btn-secondary:hover {
background: #f0f4ff !important;
}
.example-grid {
display: grid;
grid-template-columns: repeat(4, 1fr);
gap: 20px;
margin-top: 30px;
}
.example-card {
background: white;
border-radius: 12px;
overflow: hidden;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
border: 1px solid #e5e7eb;
transition: all 0.3s ease;
}
.example-card:hover {
transform: translateY(-4px);
box-shadow: 0 8px 20px rgba(102, 126, 234, 0.15);
border-color: #667eea;
}
.example-image-wrapper {
width: 100%;
height: 180px;
overflow: hidden;
background: #f5f7fa;
}
.example-image-wrapper img {
width: 100%;
height: 100%;
object-fit: cover;
}
.example-btn {
width: 100% !important;
white-space: pre-wrap !important;
text-align: left !important;
padding: 16px !important;
background: white !important;
border: none !important;
border-top: 1px solid #e5e7eb !important;
color: #1f2937 !important;
font-size: 14px !important;
line-height: 1.6 !important;
transition: all 0.3s ease !important;
font-weight: 500 !important;
}
.example-btn:hover {
background: #f9fafb !important;
color: #667eea !important;
}
.feature-section {
background: white;
padding: 24px;
border-radius: 12px;
margin-top: 30px;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
border: 1px solid #e5e7eb;
}
.section-title {
font-size: 18px;
font-weight: 600;
color: #1f2937;
margin-bottom: 20px;
padding-bottom: 12px;
border-bottom: 2px solid #e5e7eb;
}
""") as demo:
# 顶部导航栏
gr.HTML("""
Powered by Tencent Hunyuan Team
© 2025 Tencent Hunyuan Team. All rights reserved.
本系统基于 HunyuanOCR 构建 | 仅供学习研究使用