from typing import Any, Dict, List
import base64
import os
from io import BytesIO

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler with the model and tokenizer.

        Args:
            path: Path to the model directory (unused - the model is loaded
                from the original repository instead).
        """
        # Load from the original DeepSeek-OCR repository. This avoids bundling
        # a copy of the multi-gigabyte model weights with the endpoint.
        model_id = "deepseek-ai/DeepSeek-OCR"

        # Load the tokenizer and model with trust_remote_code=True, since the
        # repository ships custom modeling code.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
        )

        # Try to load with Flash Attention 2; fall back to eager attention
        # if the kernel is not available on this machine.
        try:
            self.model = AutoModel.from_pretrained(
                model_id,
                _attn_implementation="flash_attention_2",
                trust_remote_code=True,
                use_safetensors=True,
            )
            print("✓ Loaded with Flash Attention 2")
        except Exception as e:
            print(f"⚠ Flash Attention not available: {e}")
            print("  Loading with eager attention (slower but compatible)")
            self.model = AutoModel.from_pretrained(
                model_id,
                trust_remote_code=True,
                use_safetensors=True,
            )

        # Set the model to evaluation mode and move it to the GPU.
        self.model = self.model.eval().cuda()

        # Pick the best precision supported by this GPU.
        if torch.cuda.is_bf16_supported():
            self.dtype = torch.bfloat16
            self.model = self.model.to(torch.bfloat16)
            print("✓ Using bfloat16 precision")
        else:
            self.dtype = torch.float16
            self.model = self.model.to(torch.float16)
            print("✓ Using float16 precision (bfloat16 not supported on this GPU)")

            # DeepSeek-OCR's infer() hardcodes .to(torch.bfloat16), which fails
            # on GPUs without bfloat16 support. Monkey-patch infer() so that,
            # for its duration, torch.bfloat16 resolves to torch.float16. The
            # patch is applied only on this float16 path; patching a bfloat16
            # model would cause a dtype mismatch.
            original_infer = self.model.infer

            def patched_infer(*args, **kwargs):
                original_bfloat16 = torch.bfloat16
                torch.bfloat16 = torch.float16
                try:
                    result = original_infer(*args, **kwargs)
                finally:
                    torch.bfloat16 = original_bfloat16
                return result

            self.model.infer = patched_infer
            print("✓ Patched model.infer to use float16")

        # Default inference parameters.
        self.default_base_size = 1024
        self.default_image_size = 640
        self.default_crop_mode = True

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the input data and return OCR results.

        Args:
            data: Dictionary containing:
                - inputs: Either a base64-encoded image string or an image URL
                - parameters (optional): Dict with:
                    - prompt: Custom prompt (default:
                      "<image>\n<|grounding|>Convert the document to markdown. ")
                    - base_size: Base image size (default: 1024)
                    - image_size: Crop image size (default: 640)
                    - crop_mode: Whether to use crop mode (default: True)
                    - save_results: Whether to save detailed results (default: False)
                    - test_compress: Whether to test compression (default: False)

        Returns:
            List containing a result dictionary with the OCR text.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Resolve parameters, falling back to the defaults set in __init__.
        prompt = parameters.get(
            "prompt", "<image>\n<|grounding|>Convert the document to markdown. "
        )
        base_size = parameters.get("base_size", self.default_base_size)
        image_size = parameters.get("image_size", self.default_image_size)
        crop_mode = parameters.get("crop_mode", self.default_crop_mode)
        save_results = parameters.get("save_results", False)
        test_compress = parameters.get("test_compress", False)

        # Decode the image and write it to a temporary file, since
        # model.infer expects a file path rather than a PIL image.
        image = self._process_image_input(inputs)
        temp_image_path = "/tmp/temp_ocr_image.jpg"
        output_path = "/tmp/ocr_output"
        image.save(temp_image_path)

        # Run inference; always remove the temporary file afterwards.
        try:
            result = self.model.infer(
                self.tokenizer,
                prompt=prompt,
                image_file=temp_image_path,
                output_path=output_path,
                base_size=base_size,
                image_size=image_size,
                crop_mode=crop_mode,
                save_results=save_results,
                test_compress=test_compress,
                eval_mode=True,  # Return text directly without streaming
            )
            return [{"text": result}]
        finally:
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)

    def _process_image_input(self, inputs: str) -> Image.Image:
        """
        Process image input from a base64 string or URL.

        Args:
            inputs: Base64-encoded image string or image URL

        Returns:
            PIL Image object, converted to RGB.
        """
        if inputs.startswith("data:image"):
            # Data URI, e.g. "data:image/jpeg;base64,/9j/4AAQ..."
            image_data = inputs.split(",")[1]
            image_bytes = base64.b64decode(image_data)
            image = Image.open(BytesIO(image_bytes))
        elif inputs.startswith("http://") or inputs.startswith("https://"):
            # URL input
            import requests

            response = requests.get(inputs, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            # Assume raw base64 without a data-URI prefix.
            try:
                image_bytes = base64.b64decode(inputs)
                image = Image.open(BytesIO(image_bytes))
            except Exception:
                raise ValueError(
                    "Invalid input format. Please provide either a base64-encoded "
                    "image or an image URL."
                )
        return image.convert("RGB")