from typing import Dict, List, Any

import torch
from transformers import AutoModel, AutoTokenizer
import base64
from io import BytesIO
from PIL import Image
import os


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler with the model and tokenizer.

        Args:
            path: Path to the model directory (not used - loads from original repo)
        """
        model_id = "deepseek-ai/DeepSeek-OCR"
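
        # trust_remote_code is required here: DeepSeek-OCR ships its own
        # modeling and tokenization code on the Hub (including the custom
        # infer() method used below).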
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True
        )
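
        # Prefer FlashAttention 2 when it is installed and the GPU supports
        # it; otherwise fall back to the default (eager) attention.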
        try:
            self.model = AutoModel.from_pretrained(
                model_id,
                _attn_implementation='flash_attention_2',
                trust_remote_code=True,
                use_safetensors=True
            )
            print("✓ Loaded with Flash Attention 2")
        except Exception as e:
            print(f"⚠ Flash Attention not available: {e}")
            print("  Loading with eager attention (slower but compatible)")
            self.model = AutoModel.from_pretrained(
                model_id,
                trust_remote_code=True,
                use_safetensors=True
            )

        self.model = self.model.eval().cuda()
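
        # Pick the widest half-precision dtype the GPU supports: bfloat16 on
        # Ampere-or-newer GPUs, float16 everywhere else.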
        if torch.cuda.is_bf16_supported():
            self.dtype = torch.bfloat16
            self.model = self.model.to(torch.bfloat16)
            print("✓ Using bfloat16 precision")
        else:
            self.dtype = torch.float16
            self.model = self.model.to(torch.float16)
            print("✓ Using float16 precision (bfloat16 not supported on this GPU)")
        if self.dtype == torch.float16:
            original_infer = self.model.infer

            def patched_infer(*args, **kwargs):
                original_bfloat16 = torch.bfloat16
                torch.bfloat16 = torch.float16
                try:
                    result = original_infer(*args, **kwargs)
                finally:
                    torch.bfloat16 = original_bfloat16
                return result

            self.model.infer = patched_infer
            print("✓ Patched model.infer to use float16")
        self.default_base_size = 1024
        self.default_image_size = 640
        self.default_crop_mode = True

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the input data and return OCR results.

        Args:
            data: Dictionary containing:
                - inputs: Either a base64-encoded image string or an image URL
                - parameters (optional): Dict with:
                    - prompt: Custom prompt (default: "<image>\n<|grounding|>Convert the document to markdown. ")
                    - base_size: Base image size (default: 1024)
                    - image_size: Crop image size (default: 640)
                    - crop_mode: Whether to use crop mode (default: True)
                    - save_results: Whether to save detailed results (default: False)
                    - test_compress: Whether to test compression (default: False)

        Returns:
            List containing a result dictionary with the OCR text
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        prompt = parameters.get("prompt", "<image>\n<|grounding|>Convert the document to markdown. ")
        base_size = parameters.get("base_size", self.default_base_size)
        image_size = parameters.get("image_size", self.default_image_size)
        crop_mode = parameters.get("crop_mode", self.default_crop_mode)
        save_results = parameters.get("save_results", False)
        test_compress = parameters.get("test_compress", False)

        image = self._process_image_input(inputs)
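
        # infer() takes a file path rather than an in-memory image, so write
        # the decoded image to a temporary file first.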
        temp_image_path = "/tmp/temp_ocr_image.jpg"
        output_path = "/tmp/ocr_output"
        image.save(temp_image_path)

        try:
            result = self.model.infer(
                self.tokenizer,
                prompt=prompt,
                image_file=temp_image_path,
                output_path=output_path,
                base_size=base_size,
                image_size=image_size,
                crop_mode=crop_mode,
                save_results=save_results,
                test_compress=test_compress,
                eval_mode=True
            )

            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)

            return [{"text": result}]

        except Exception:
            # Clean up the temporary image even when inference fails, then
            # re-raise with the original traceback intact.
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)
            raise

    def _process_image_input(self, inputs: str) -> Image.Image:
        """
        Process image input from a base64 string or URL.

        Args:
            inputs: Base64-encoded image string or image URL

        Returns:
            PIL Image object
        """
        if inputs.startswith("data:image"):
            # Data URL, e.g. "data:image/png;base64,<payload>"
            image_data = inputs.split(",")[1]
            image_bytes = base64.b64decode(image_data)
            image = Image.open(BytesIO(image_bytes))
        elif inputs.startswith("http://") or inputs.startswith("https://"):
            # Remote image URL; fail fast on HTTP errors instead of handing
            # an error page to PIL.
            import requests
            response = requests.get(inputs, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            # Assume a raw base64 string without a data-URL prefix
            try:
                image_bytes = base64.b64decode(inputs)
                image = Image.open(BytesIO(image_bytes))
            except Exception:
                raise ValueError(
                    "Invalid input format. Please provide either a base64 encoded image "
                    "or an image URL."
                )

        # Normalize to RGB so downstream processing sees a consistent mode.
        return image.convert("RGB")
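

if __name__ == "__main__":
    # Local smoke test (a minimal sketch: Inference Endpoints call __call__
    # directly and never run this block). Assumes a CUDA GPU, access to the
    # Hugging Face Hub, and an image file path passed as the first argument;
    # "sample.jpg" is a placeholder fallback.
    import sys

    handler = EndpointHandler()
    image_path = sys.argv[1] if len(sys.argv) > 1 else "sample.jpg"
    with open(image_path, "rb") as f:
        payload = base64.b64encode(f.read()).decode("utf-8")

    outputs = handler({"inputs": payload, "parameters": {"crop_mode": True}})
    print(outputs[0]["text"])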