import base64
import os
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler with the model and tokenizer.

        Args:
            path: Path to the model directory (not used - loads from the original repo)
        """
        # Load from the original DeepSeek-OCR repository.
        # This avoids needing to copy 40GB+ of model weights.
        model_id = "deepseek-ai/DeepSeek-OCR"

        # Load tokenizer and model with trust_remote_code=True for custom code
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True
        )

        # Try to load with flash attention, fall back to eager if not available
        try:
            self.model = AutoModel.from_pretrained(
                model_id,
                _attn_implementation='flash_attention_2',
                trust_remote_code=True,
                use_safetensors=True
            )
            print("✓ Loaded with Flash Attention 2")
        except Exception as e:
            print(f"⚠ Flash Attention not available: {e}")
            print(" Loading with eager attention (slower but compatible)")
            self.model = AutoModel.from_pretrained(
                model_id,
                trust_remote_code=True,
                use_safetensors=True
            )

        # Set model to evaluation mode and move to GPU
        self.model = self.model.eval().cuda()

        # Determine the supported dtype and configure the model accordingly
        if torch.cuda.is_bf16_supported():
            self.dtype = torch.bfloat16
            self.model = self.model.to(torch.bfloat16)
            print("✓ Using bfloat16 precision")
        else:
            self.dtype = torch.float16
            self.model = self.model.to(torch.float16)
            print("✓ Using float16 precision (bfloat16 not supported on this GPU)")

            # Monkey-patch model.infer to use float16 instead of hardcoded bfloat16.
            # The DeepSeek-OCR infer method hardcodes .to(torch.bfloat16), so swap
            # the alias for the duration of the call.
            original_infer = self.model.infer

            def patched_infer(*args, **kwargs):
                # Temporarily replace torch.bfloat16 with float16
                original_bfloat16 = torch.bfloat16
                torch.bfloat16 = torch.float16
                try:
                    result = original_infer(*args, **kwargs)
                finally:
                    torch.bfloat16 = original_bfloat16
                return result

            self.model.infer = patched_infer
            print("✓ Patched model.infer to use float16")

        # Default parameters
        self.default_base_size = 1024
        self.default_image_size = 640
        self.default_crop_mode = True
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the input data and return OCR results.

        Args:
            data: Dictionary containing:
                - inputs: Either a base64 encoded image string or an image URL
                - parameters (optional): Dict with:
                    - prompt: Custom prompt (default: "<image>\n<|grounding|>Convert the document to markdown. ")
                    - base_size: Base image size (default: 1024)
                    - image_size: Crop image size (default: 640)
                    - crop_mode: Whether to use crop mode (default: True)
                    - save_results: Whether to save detailed results (default: False)
                    - test_compress: Whether to test compression (default: False)

        Returns:
            List containing a result dictionary with the OCR text
        """
        # Extract inputs
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Get parameters with defaults
        prompt = parameters.get("prompt", "<image>\n<|grounding|>Convert the document to markdown. ")
        base_size = parameters.get("base_size", self.default_base_size)
        image_size = parameters.get("image_size", self.default_image_size)
        crop_mode = parameters.get("crop_mode", self.default_crop_mode)
        save_results = parameters.get("save_results", False)
        test_compress = parameters.get("test_compress", False)

        # Process image input
        image = self._process_image_input(inputs)

        # Save image temporarily (model.infer expects a file path)
        temp_image_path = "/tmp/temp_ocr_image.jpg"
        output_path = "/tmp/ocr_output"
        image.save(temp_image_path)
        # Run inference
        try:
            result = self.model.infer(
                self.tokenizer,
                prompt=prompt,
                image_file=temp_image_path,
                output_path=output_path,
                base_size=base_size,
                image_size=image_size,
                crop_mode=crop_mode,
                save_results=save_results,
                test_compress=test_compress,
                eval_mode=True  # Return text directly without streaming
            )
            return [{"text": result}]
        finally:
            # Clean up the temporary file whether or not inference succeeded
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)
    def _process_image_input(self, inputs: str) -> Image.Image:
        """
        Process image input from a base64 string or URL.

        Args:
            inputs: Base64 encoded image string or image URL

        Returns:
            PIL Image object
        """
        if inputs.startswith("data:image"):
            # Data URI, e.g. data:image/jpeg;base64,/9j/4AAQ...
            image_data = inputs.split(",")[1]
            image_bytes = base64.b64decode(image_data)
            image = Image.open(BytesIO(image_bytes))
        elif inputs.startswith("http://") or inputs.startswith("https://"):
            # URL input
            import requests
            response = requests.get(inputs, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            # Assume it's raw base64
            try:
                image_bytes = base64.b64decode(inputs)
                image = Image.open(BytesIO(image_bytes))
            except Exception:
                raise ValueError(
                    "Invalid input format. Please provide either a base64 encoded image "
                    "or an image URL."
                )
        return image.convert("RGB")
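

# Minimal local smoke test (a sketch, not part of the endpoint contract): it
# assumes a CUDA GPU, network access to download the model, and a sample image
# at the hypothetical path "sample.jpg".
if __name__ == "__main__":
    handler = EndpointHandler()

    # Build a request body matching the format documented in __call__
    with open("sample.jpg", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")

    request = {
        "inputs": f"data:image/jpeg;base64,{encoded}",
        "parameters": {"base_size": 1024, "image_size": 640, "crop_mode": True},
    }

    result = handler(request)
    print(result[0]["text"])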