|
|
import os |
|
|
import base64 |
|
|
import requests |
|
|
from io import BytesIO |
|
|
from typing import List, Union |
|
|
|
|
|
from PIL import Image |
|
|
import pypdfium2 as pdfium |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Hugging Face access token, read from the Space secret named ``ocr_model``.
# Surrounding whitespace is stripped: a token pasted with a trailing newline or
# space would otherwise make ``requests`` reject the Authorization header.
HF_API_TOKEN = (os.environ.get("ocr_model") or "").strip()

# Fail fast at startup. A missing OR empty secret would otherwise send
# "Bearer " and every inference request would be rejected.
if not HF_API_TOKEN:
    raise RuntimeError(
        "环境变量 ocr_model 未设置,请在 Space 的 Settings -> Variables 中添加一个名为 ocr_model 的 Secret。"
    )

# Hugging Face model repository used for OCR.
MODEL_ID = "tencent/HunyuanOCR"

# Inference API endpoint for MODEL_ID.
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"

# Authentication header sent with every inference request.
HEADERS = {"Authorization": f"Bearer {HF_API_TOKEN}"}
|
|
|
|
|
|
|
|
# Pillow modes that the PNG encoder can write as-is. Other modes (e.g. CMYK,
# YCbCr, LAB, HSV) make ``Image.save(..., format="PNG")`` raise.
_PNG_WRITABLE_MODES = frozenset(
    {"1", "L", "LA", "I", "I;16", "I;16B", "P", "RGB", "RGBA"}
)

# Non-PNG-writable modes that carry an alpha band; these become RGBA so
# transparency is kept. Every other unsupported mode becomes RGB.
_ALPHA_FALLBACK_MODES = ("PA", "La", "RGBa")


def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 string of its PNG bytes.

    Images whose mode PNG cannot store (for example CMYK scans) are first
    converted to RGBA if they carry alpha, otherwise to RGB. Without this,
    encoding would fail.

    Args:
        image: Image to encode. The caller's object is not modified
            (``convert`` returns a new image).

    Returns:
        The base64-encoded PNG data as a ``str``.
    """
    if image.mode not in _PNG_WRITABLE_MODES:
        target_mode = "RGBA" if image.mode in _ALPHA_FALLBACK_MODES else "RGB"
        image = image.convert(target_mode)

    buffer = BytesIO()
    image.save(buffer, format="PNG")
    # base64 output is pure ASCII, so decoding it to str is lossless.
    return base64.b64encode(buffer.getvalue()).decode("ascii")
|
|
|
|
|
|
|
|
# Response keys checked, in priority order, for the recognized text.
_TEXT_KEYS = ("generated_text", "text", "output", "label")


def _extract_text(data) -> str:
    """Return the recognized text from a decoded JSON inference response.

    A non-empty list is reduced to its first element. If the resulting record
    is a dict, the first string value found under ``_TEXT_KEYS`` is returned
    with surrounding whitespace stripped. Otherwise the record is converted
    with ``str()`` as a fallback.
    """
    record = data[0] if isinstance(data, list) and data else data
    if isinstance(record, dict):
        for key in _TEXT_KEYS:
            value = record.get(key)
            if isinstance(value, str):
                return value.strip()
    return str(record)


def call_ocr_model(image: Image.Image) -> str:
    """Run HunyuanOCR on a single image through the HF Inference API.

    Args:
        image: The image to recognize. It is sent base64-encoded as PNG.

    Returns:
        The recognized text. Request, HTTP and JSON-decoding failures are
        returned as bracketed error strings instead of being raised, so that
        one failing page does not abort a multi-page PDF. Any exception from
        ``image_to_base64`` is not caught here and propagates to the caller.
    """
    # NOTE(review): the ``{"inputs": {"image": <b64>}}`` request shape is an
    # assumption about this model's endpoint; confirm against the model card.
    payload = {"inputs": {"image": image_to_base64(image)}}

    try:
        response = requests.post(API_URL, headers=HEADERS, json=payload, timeout=120)
        response.raise_for_status()
    except requests.HTTPError as e:
        # The API explains failures (model loading, bad input, auth) in the
        # response body; raise_for_status() drops it, so surface it here.
        body = e.response.text[:1000] if e.response is not None else ""
        return f"[调用模型出错] {type(e).__name__}: {e}\n原始返回:{body}"
    except Exception as e:
        # Deliberately broad: any other request failure is reported as text.
        return f"[调用模型出错] {type(e).__name__}: {e}"

    try:
        data = response.json()
    except ValueError as e:
        # requests' JSONDecodeError subclasses ValueError in every version.
        return f"[解析返回结果出错] {type(e).__name__}: {e}\n原始返回:{response.text[:1000]}"

    return _extract_text(data)
|
|
|
|
|
|
|
|
def pdf_to_images(pdf_bytes: bytes, dpi: int = 200) -> List[Image.Image]:
    """Render every page of a PDF into a PIL image.

    Args:
        pdf_bytes: Raw contents of the PDF file.
        dpi: Target resolution. PDF user space has 72 units per inch, so
            pages are rendered at scale ``dpi / 72``.

    Returns:
        One image per page, in page order. The list is empty for a document
        with no pages.

    The PDFium page and document handles are closed explicitly before
    returning, including when rendering fails, rather than being left to
    garbage collection.
    """
    scale = dpi / 72
    pdf = pdfium.PdfDocument(pdf_bytes)
    try:
        images: List[Image.Image] = []
        for index in range(len(pdf)):
            page = pdf[index]
            try:
                images.append(page.render(scale=scale).to_pil())
            finally:
                page.close()
        return images
    finally:
        pdf.close()
|
|
|
|
|
|
|
|
def run_ocr(file: Union[bytes, None], image: Union[Image.Image, None]) -> str:
    """Gradio entry point: OCR an uploaded PDF, an uploaded image, or both.

    - With a PDF (``file``), every page is recognized and labelled with its
      1-based page number.
    - With an image, it is recognized too. If PDF pages were also processed,
      the image result gets its own heading.
    - With neither, a prompt asking for an upload is returned.

    Returns:
        All results joined by newlines. An error message is returned instead
        if the PDF cannot be parsed or has no pages.
    """
    if image is None and file is None:
        return "请上传 PDF 文件或图片。"

    sections = []

    if file is not None:
        try:
            rendered_pages = pdf_to_images(file)
        except Exception as err:
            return f"[解析 PDF 出错] {type(err).__name__}: {err}"

        if len(rendered_pages) == 0:
            return "PDF 中未检测到页面。"

        sections.extend(
            f"===== 第 {page_no} 页 =====\n{call_ocr_model(page)}\n"
            for page_no, page in enumerate(rendered_pages, start=1)
        )

    if image is not None:
        image_text = call_ocr_model(image)
        # Only add a heading when it must be told apart from PDF page output.
        heading = "===== 图片识别结果 =====\n" if sections else ""
        sections.append(heading + image_text)

    return "\n".join(sections)
|
|
|
|
|
|
|
|
# Introductory Markdown shown at the top of the page.
_INTRO_MARKDOWN = f"""# 文档 OCR Demo(HunyuanOCR)
使用模型:`{MODEL_ID}`

你可以:
- 上传 **PDF 文件**(多页会逐页识别,并按页分隔)
- 或上传 **单张图片**(截图、拍照等)
"""

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: the two optional inputs and the trigger button.
        with gr.Column():
            pdf_input = gr.File(
                label="上传 PDF 文件(可选)",
                file_types=[".pdf"],
                type="binary",
            )
            image_input = gr.Image(
                label="上传图片(可选)",
                type="pil",
            )
            run_button = gr.Button("开始识别")

        # Right column: recognized text.
        with gr.Column():
            output_text = gr.Textbox(label="识别结果", lines=25)

    # Pass both inputs to run_ocr and show its string result in the textbox.
    run_button.click(
        fn=run_ocr,
        inputs=[pdf_input, image_input],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()