import os
import torch
import numpy as np
from PIL import Image
import spaces
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info  # make sure this module is available in your environment
from transformers import HunYuanVLForConditionalGeneration
import gradio as gr
from argparse import ArgumentParser
import copy
import requests
from io import BytesIO
import tempfile
import hashlib
import gc


# Example images for the "quick examples" cards, keyed by task.  These look like
# signed COS URLs (q-sign-time / q-signature params), so they may expire; local
# paths such as "examples/spotting.jpg" can be substituted.
EXAMPLE_IMAGES = {
    "spotting": "https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/23cc43af9376b948f3febaf4ce854a8a.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763523817%3B1794627877&q-key-time=1763523817%3B1794627877&q-header-list=host&q-url-param-list=&q-signature=8ebd6a9d3ed7eba73bb783c337349db9c29972e2",
    "parsing": "https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/c4997ebd1be9f7c3e002fabba8b46cb7.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763523818%3B1794627878&q-key-time=1763523818%3B1794627878&q-header-list=host&q-url-param-list=&q-signature=d2cd12be4c7902821c8c82203e4642624046911a",
    "ie": "https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/7c67c0f78e4423d51644a325da1f8e85.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763523818%3B1794627878&q-key-time=1763523818%3B1794627878&q-header-list=host&q-url-param-list=&q-signature=803648f3253706f654faf1423869fd9e00e7056e",
    "vqa": "https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/fea0865d1c70c53aaa2ab91cd0e787f5.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763523818%3B1794627878&q-key-time=1763523818%3B1794627878&q-header-list=host&q-url-param-list=&q-signature=a92b94e298a11aea130d730d3b16ee761acc3f4c",
    "translation": "https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/d1af99d35e9db9e820ebebb5bc68993a.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763967603%3B1795071663&q-key-time=1763967603%3B1795071663&q-header-list=host&q-url-param-list=&q-signature=a57080c0b3d4c76ea74b88c6291f9004241c9d49",
}

# One entry per example card:
# (key into EXAMPLE_IMAGES, card caption, button label, prompt loaded into the textbox)
_EXAMPLES = (
    ("spotting", "文字检测识别", "🔍 文字检测和识别",
     "检测并识别图片中的文字,将文本坐标格式化输出。"),
    ("parsing", "文档解析", "📋 文档解析",
     "提取文档图片中正文的所有信息用markdown 格式表示,其中页眉、页脚部分忽略,表格用html 格式表达,文档中公式用latex 格式表示,按照阅读顺序组织进行解析。"),
    ("ie", "信息抽取", "🎯 信息抽取",
     "提取图片中的:['单价', '上车时间','发票号码', '省前缀', '总金额', '发票代码', '下车时间', '里程数'] 的字段内容,并且按照JSON格式返回。"),
    ("vqa", "视觉问答", "💬 视觉问答",
     "What is the highest life expectancy at birth of male?"),
    ("translation", "图片翻译", "🌐 图片翻译",
     "将图中文字翻译为中文。"),
)

# Avatar shown for the assistant side of the chat (avatar_images=(user, bot)).
_BOT_AVATAR_URL = "https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/6ef6928b21b323b2b00115f86a779d8f.png?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763450355%3B1794554415&q-key-time=1763450355%3B1794554415&q-header-list=host&q-url-param-list=&q-signature=41328696dc34571324aa18c791c1196192e729c6"

_CSS = """
body { background: #f5f7fa; }
.gradio-container {
    max-width: 100% !important;
    padding: 0 40px !important;
}
.header-section {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 30px 0;
    margin: -20px -40px 30px -40px;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.header-content {
    max-width: 1600px;
    margin: 0 auto;
    padding: 0 40px;
    display: flex;
    align-items: center;
    gap: 20px;
}
.header-logo { height: 60px; }
.header-text h1 {
    color: white;
    font-size: 32px;
    font-weight: bold;
    margin: 0 0 5px 0;
}
.header-text p {
    color: rgba(255,255,255,0.9);
    margin: 0;
    font-size: 14px;
}
.main-container {
    max-width: 1800px;
    margin: 0 auto;
}
.chatbot {
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08) !important;
    border-radius: 12px !important;
    border: 1px solid #e5e7eb !important;
    background: white !important;
}
.input-panel {
    background: white;
    padding: 20px;
    border-radius: 12px;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
    border: 1px solid #e5e7eb;
}
.input-box textarea {
    border: 2px solid #e5e7eb !important;
    border-radius: 8px !important;
    font-size: 14px !important;
}
.input-box textarea:focus { border-color: #667eea !important; }
.btn-primary {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    border: none !important;
    color: white !important;
    font-weight: 500 !important;
    padding: 10px 24px !important;
    font-size: 14px !important;
}
.btn-primary:hover {
    transform: translateY(-1px) !important;
    box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
}
.btn-secondary {
    background: white !important;
    border: 2px solid #667eea !important;
    color: #667eea !important;
    padding: 8px 20px !important;
    font-size: 14px !important;
}
.btn-secondary:hover { background: #f0f4ff !important; }
.example-grid {
    display: grid;
    grid-template-columns: repeat(4, 1fr);
    gap: 20px;
    margin-top: 30px;
}
.example-card {
    background: white;
    border-radius: 12px;
    overflow: hidden;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
    border: 1px solid #e5e7eb;
    transition: all 0.3s ease;
}
.example-card:hover {
    transform: translateY(-4px);
    box-shadow: 0 8px 20px rgba(102, 126, 234, 0.15);
    border-color: #667eea;
}
.example-image-wrapper {
    width: 100%;
    height: 180px;
    overflow: hidden;
    background: #f5f7fa;
}
.example-image-wrapper img {
    width: 100%;
    height: 100%;
    object-fit: cover;
}
.example-btn {
    width: 100% !important;
    white-space: pre-wrap !important;
    text-align: left !important;
    padding: 16px !important;
    background: white !important;
    border: none !important;
    border-top: 1px solid #e5e7eb !important;
    color: #1f2937 !important;
    font-size: 14px !important;
    line-height: 1.6 !important;
    transition: all 0.3s ease !important;
    font-weight: 500 !important;
}
.example-btn:hover {
    background: #f9fafb !important;
    color: #667eea !important;
}
.feature-section {
    background: white;
    padding: 24px;
    border-radius: 12px;
    margin-top: 30px;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
    border: 1px solid #e5e7eb;
}
.section-title {
    font-size: 18px;
    font-weight: 600;
    color: #1f2937;
    margin-bottom: 20px;
    padding-bottom: 12px;
    border-bottom: 2px solid #e5e7eb;
}
"""


def _get_args():
    """Parse the demo's command-line arguments."""
    parser = ArgumentParser()
    parser.add_argument('-c', '--checkpoint-path', type=str, default='tencent/HunyuanOCR',
                        help='Checkpoint name or path, default to %(default)r')
    parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
    parser.add_argument('--flash-attn2', action='store_true', default=False,
                        help='Enable flash_attention_2 when loading the model.')
    parser.add_argument('--share', action='store_true', default=False,
                        help='Create a publicly shareable link for the interface.')
    parser.add_argument('--inbrowser', action='store_true', default=False,
                        help='Automatically launch the interface in a new tab on the default browser.')
    args = parser.parse_args()
    return args


def _load_model_processor(args):
    """Load the HunyuanOCR model and its processor.

    Defaults (no flags) are the ZeroGPU setup: eager attention and
    ``device_map="auto"``; the ``@spaces.GPU``-decorated call then runs on the
    GPU. ``--flash-attn2`` selects FlashAttention-2 and ``--cpu-only`` pins the
    model to the CPU (both flags were previously parsed but ignored).

    Returns:
        ``(model, processor)``.
    """
    # ZeroGPU requires eager attention because the model starts on the CPU.
    attn_impl = "flash_attention_2" if args.flash_attn2 else "eager"
    device_map = "cpu" if args.cpu_only else "auto"
    print(f"[INFO] 加载模型 (attn_implementation={attn_impl}, device_map={device_map})")
    model = HunYuanVLForConditionalGeneration.from_pretrained(
        args.checkpoint_path,
        attn_implementation=attn_impl,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        token=os.environ.get('HF_TOKEN'),
    )
    processor = AutoProcessor.from_pretrained(args.checkpoint_path, use_fast=False, trust_remote_code=True)
    print(f"[INFO] 模型加载完成")
    return model, processor


def _parse_text(text):
    """Normalize user/model text before it is shown in the chatbot.

    ``None`` is passed through unchanged.
    """
    if text is None:
        return text
    # NOTE(review): both search strings below are empty, so these replace()
    # calls are currently no-ops -- confirm which special tokens (if any) were
    # meant to be stripped here.
    return text.replace("", "").replace("", "")


def _remove_image_special(text):
    """Strip image special tokens from model output.

    Currently an identity function (placeholder for future token stripping).
    """
    return text


def _gc():
    """Run Python garbage collection and release cached CUDA memory if CUDA exists."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _download_url_image(url, timeout=10):
    """Download ``url`` into the system temp dir and return the local path.

    The file name is derived from the MD5 of the URL, so repeated calls reuse the
    cached copy. On any network or filesystem error the original URL is returned.
    """
    url_hash = hashlib.md5(url.encode()).hexdigest()
    cache_path = os.path.join(tempfile.gettempdir(), f"hyocr_demo_{url_hash}.png")
    if os.path.exists(cache_path):
        return cache_path
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        # Write to a unique partial file and atomically rename it into place, so
        # an interrupted download never leaves a truncated file in the cache.
        fd, partial_path = tempfile.mkstemp(dir=os.path.dirname(cache_path),
                                            prefix="hyocr_demo_", suffix=".part")
        try:
            with os.fdopen(fd, 'wb') as f:
                f.write(response.content)
            os.replace(partial_path, cache_path)
        finally:
            # After a successful os.replace the partial file is gone; otherwise clean it up.
            if os.path.exists(partial_path):
                os.remove(partial_path)
    except (requests.RequestException, OSError) as e:
        print(f"下载图片失败: {url}, 错误: {e}")
        return url  # fall back to the original URL
    return cache_path


def _make_example_loader(task, prompt):
    """Build a click handler that starts a fresh conversation with example ``task``.

    The handler ignores the current chat, downloads the example image, and returns
    ``(chatbot, task_history, textbox)`` values.
    """
    def load_example(history, task_hist):
        image_path = _download_url_image(EXAMPLE_IMAGES[task])
        return [((image_path,), None)], [((image_path,), None)], prompt
    return load_example


def _add_text(history, task_history, text):
    """Append a user text turn; the trailing ``''`` clears the input textbox."""
    history = history if history is not None else []
    task_history = task_history if task_history is not None else []
    history = history + [(_parse_text(text), None)]
    task_history = task_history + [(text, None)]
    return history, task_history, ''


def _add_file(history, task_history, file):
    """Append an uploaded image as a ``(path,)`` user turn."""
    history = history if history is not None else []
    task_history = task_history if task_history is not None else []
    # Depending on the Gradio version the upload is a path string or an object
    # exposing ``.name``; accept both.
    file_path = getattr(file, 'name', file)
    history = history + [((file_path,), None)]
    task_history = task_history + [((file_path,), None)]
    return history, task_history


def _reset_state(_chatbot, task_history):
    """Clear both histories in place, free memory, and return an empty chatbot."""
    task_history.clear()
    _chatbot.clear()
    _gc()
    return []


def _build_messages(task_history):
    """Convert ``task_history`` into chat-template messages.

    Image entries (tuples/lists whose first item is a URL or local path) are
    buffered and attached to the next text entry, so each user turn carries its
    images together with its question. The trailing assistant placeholder for the
    pending turn is dropped.
    """
    messages = []
    content = []
    for q, a in task_history:
        if isinstance(q, (tuple, list)):
            img_path = q[0]
            if not img_path.startswith(('http://', 'https://')):
                img_path = os.path.abspath(img_path)
            content.append({'type': 'image', 'image': img_path})
        else:
            content.append({'type': 'text', 'text': q})
            messages.append({'role': 'user', 'content': content})
            messages.append({'role': 'assistant', 'content': [{'type': 'text', 'text': a}]})
            content = []
    if messages:
        messages.pop()
    return messages


def _launch_demo(args, model, processor):
    """Build the Gradio UI around ``model``/``processor`` and launch it."""

    @spaces.GPU(duration=60)
    def call_local_model(messages):
        """Generate a reply for one conversation; returns the decoded texts (list of str)."""
        import time
        start_time = time.time()
        print(f"[DEBUG] ========== 开始推理 ==========")

        model_device = next(model.parameters()).device
        print(f"[DEBUG] Model device: {model_device}")
        # The model should already be on the GPU here. Move it only when CUDA
        # exists and CPU mode was not requested; model.cuda() raises otherwise.
        if model_device.type == 'cpu' and torch.cuda.is_available() and not args.cpu_only:
            print(f"[ERROR] 模型在 CPU 上!尝试移动到 GPU...")
            model.cuda()
            model_device = next(model.parameters()).device
            print(f"[DEBUG] Model device after cuda(): {model_device}")

        batch = [messages]
        texts = [
            processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
            for conversation in batch
        ]
        print(f"[DEBUG] 模板构建完成,耗时: {time.time() - start_time:.2f}s")
        image_inputs, video_inputs = process_vision_info(batch)
        print(f"[DEBUG] 图像处理完成,耗时: {time.time() - start_time:.2f}s")
        inputs = processor(
            text=texts,
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Put inputs on the model's device (not unconditionally on CUDA), so a
        # CPU-resident model still works.
        inputs = inputs.to(model_device)
        print(f"[DEBUG] 输入准备完成,耗时: {time.time() - start_time:.2f}s")
        print(f"[DEBUG] Input IDs shape: {inputs.input_ids.shape}")
        print(f"[DEBUG] Input device: {inputs.input_ids.device}")

        gen_start = time.time()
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=256,
                repetition_penalty=1.03,
                do_sample=False,
                eos_token_id=processor.tokenizer.eos_token_id,
                pad_token_id=processor.tokenizer.pad_token_id,
            )
        gen_time = time.time() - gen_start
        print(f"[DEBUG] ========== 生成完成 ==========")
        print(f"[DEBUG] 生成耗时: {gen_time:.2f}s")
        print(f"[DEBUG] Output shape: {generated_ids.shape}")

        # Keep only the newly generated tokens (drop the echoed prompt).
        input_ids = inputs.input_ids if "input_ids" in inputs else inputs.inputs
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
        ]
        actual_tokens = len(generated_ids_trimmed[0])
        print(f"[DEBUG] 实际生成 token 数: {actual_tokens}")
        print(f"[DEBUG] 每 token 耗时: {gen_time/actual_tokens if actual_tokens > 0 else 0:.3f}s")

        output_texts = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        total_time = time.time() - start_time
        print(f"[DEBUG] ========== 全部完成 ==========")
        print(f"[DEBUG] 总耗时: {total_time:.2f}s")
        print(f"[DEBUG] 输出长度: {len(output_texts[0])} 字符")
        print(f"[DEBUG] 输出预览: {output_texts[0][:100]}...")
        return output_texts

    def predict(_chatbot, task_history):
        """Answer the latest user turn and yield the updated chatbot (generator)."""
        chat_query = _chatbot[-1][0]
        query = task_history[-1][0]
        if len(chat_query) == 0:
            # Empty submission: drop the placeholder turn. This is a generator,
            # so the update must be yielded; a bare ``return value`` never reaches Gradio.
            _chatbot.pop()
            task_history.pop()
            yield _chatbot
            return
        print('User: ', query)

        messages = _build_messages(copy.deepcopy(task_history))
        response_list = call_local_model(messages)
        response = response_list[0] if response_list else ""

        full_response = _parse_text(response)
        _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(full_response))
        task_history[-1] = (query, full_response)
        print('HunyuanOCR: ' + _parse_text(full_response))
        yield _chatbot

    def regenerate(_chatbot, task_history):
        """Discard the last answer and re-run ``predict`` on the same turn (generator)."""
        if not task_history:
            yield _chatbot
            return
        item = task_history[-1]
        if item[1] is None:
            yield _chatbot
            return
        task_history[-1] = (item[0], None)
        chatbot_item = _chatbot.pop(-1)
        if chatbot_item[0] is None:
            _chatbot[-1] = (_chatbot[-1][0], None)
        else:
            _chatbot.append((chatbot_item[0], None))
        yield from predict(_chatbot, task_history)

    with gr.Blocks(css=_CSS) as demo:
        # NOTE(review): the gr.HTML snippets below contain only text, while the CSS
        # defines classes (.header-section, .section-title, .feature-section, ...)
        # that no visible element uses -- the markup may have been lost; confirm.

        # Header bar
        gr.HTML("""

HunyuanOCR

Powered by Tencent Hunyuan Team

""")

        with gr.Column(elem_classes=["main-container"]):
            # Chat area (full width)
            chatbot = gr.Chatbot(
                label='💬 对话窗口',
                height=600,
                bubble_full_width=False,
                layout="bubble",
                show_copy_button=True,
                avatar_images=(None, _BOT_AVATAR_URL),
                elem_classes=["chatbot"],
            )

            # Input control panel (full width)
            with gr.Group(elem_classes=["input-panel"]):
                query = gr.Textbox(
                    lines=2,
                    label='💭 输入您的问题',
                    placeholder='请先上传图片,然后输入问题。例如:检测并识别图片中的文字,将文本坐标格式化输出。',
                    elem_classes=["input-box"],
                    show_label=False,
                )
                with gr.Row():
                    addfile_btn = gr.UploadButton('📁 上传图片', file_types=['image'], elem_classes=["btn-secondary"])
                    submit_btn = gr.Button('🚀 发送消息', variant="primary", elem_classes=["btn-primary"], scale=3)
                    regen_btn = gr.Button('🔄 重新生成', elem_classes=["btn-secondary"])
                    empty_bin = gr.Button('🗑️ 清空对话', elem_classes=["btn-secondary"])

            # Example section: one card per entry in _EXAMPLES
            gr.HTML("""
📚 快速体验示例 - 点击下方卡片快速加载
""")
            example_buttons = []
            with gr.Row():
                for _task, caption, button_label, _prompt in _EXAMPLES:
                    with gr.Column(scale=1):
                        with gr.Group(elem_classes=["example-card"]):
                            gr.HTML(f"\n{caption}\n")
                            example_buttons.append(gr.Button(button_label, elem_classes=["example-btn"]))

            task_history = gr.State([])

            # Event wiring
            for button, (task, _caption, _label, prompt) in zip(example_buttons, _EXAMPLES):
                button.click(_make_example_loader(task, prompt),
                             [chatbot, task_history], [chatbot, task_history, query])

            # _add_text returns three values; wiring ``query`` as the third output
            # clears the textbox (previously that value had no matching output).
            submit_btn.click(
                _add_text, [chatbot, task_history, query], [chatbot, task_history, query]
            ).then(predict, [chatbot, task_history], [chatbot], show_progress=True)
            empty_bin.click(_reset_state, [chatbot, task_history], [chatbot], show_progress=True)
            regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
            addfile_btn.upload(_add_file, [chatbot, task_history, addfile_btn],
                               [chatbot, task_history], show_progress=True)

            # Feature description section
            with gr.Row():
                with gr.Column(scale=1):
                    gr.HTML("""
✨ 核心功能
""")
                with gr.Column(scale=1):
                    gr.HTML("""
💡 使用建议
""")

        # Footer / copyright
        gr.HTML("""

© 2025 Tencent Hunyuan Team. All rights reserved.

本系统基于 HunyuanOCR 构建 | 仅供学习研究使用

""")

    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
    )


def main():
    """Entry point: parse args, load the model, and start the demo."""
    args = _get_args()
    model, processor = _load_model_processor(args)
    _launch_demo(args, model, processor)


if __name__ == '__main__':
    main()