support MPS / CPU inference
Files changed:
- .gitattributes +1 -0
- README.md +53 -90
- deepencoder.py +1 -1
- demo/image.png +3 -0
- demo/requirements.txt +12 -0
- demo/run_dpsk_ocr.py +51 -0
- modeling_deepseekocr.py +32 -20
.gitattributes
CHANGED
@@ -38,3 +38,4 @@ assets/show1.jpg filter=lfs diff=lfs merge=lfs -text
 assets/show2.jpg filter=lfs diff=lfs merge=lfs -text
 assets/show3.jpg filter=lfs diff=lfs merge=lfs -text
 assets/show4.jpg filter=lfs diff=lfs merge=lfs -text
+demo/image.png filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -10,96 +10,76 @@ tags:
 license: mit
 library_name: transformers
 ---
-
-
-
-
-
-
-
-
-  <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR" target="_blank">
-    <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
-  </a>
-
-</div>
-
-<div align="center">
-
-  <a href="https://discord.gg/Tc7c45Zzu5" target="_blank">
-    <img alt="Discord" src="https://img.shields.io/badge/Discord-DeepSeek%20AI-7289da?logo=discord&logoColor=white&color=7289da" />
-  </a>
-  <a href="https://twitter.com/deepseek_ai" target="_blank">
-    <img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-deepseek_ai-white?logo=x&logoColor=white" />
-  </a>
-
-</div>
-
-
-
-<p align="center">
-  <a href="https://github.com/deepseek-ai/DeepSeek-OCR"><b>🌟 Github</b></a> |
-  <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR"><b>📥 Model Download</b></a> |
-  <a href="https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek_OCR_paper.pdf"><b>📄 Paper Link</b></a> |
-  <a href="https://arxiv.org/abs/2510.18234"><b>📄 Arxiv Paper Link</b></a> |
-</p>
-<h2>
-<p align="center">
-  <a href="https://huggingface.co/papers/2510.18234">DeepSeek-OCR: Contexts Optical Compression</a>
-</p>
-</h2>
-<p align="center">
-  <img src="assets/fig1.png" style="width: 1000px" align=center>
-</p>
-<p align="center">
-  <a href="https://huggingface.co/papers/2510.18234">Explore the boundaries of visual-text compression.</a>
-</p>
+
+# DeepSeek-OCR – Apple Metal Performance Shaders (MPS) & CPU Support
+
+This repository uses the weights from the original DeepSeek-OCR and modifies the model to support MPS and CPU inference.
+
+- [Link to the original DeepSeek-OCR model](https://huggingface.co/deepseek-ai/DeepSeek-OCR)
+
+- [Link to the original DeepSeek-OCR repository](https://github.com/deepseek-ai/DeepSeek-OCR)
 
 ## Usage
-Inference using Huggingface transformers on NVIDIA GPUs. Requirements tested on python 3.12.9 + cuda11.8:
+Inference using Huggingface transformers on Metal Performance Shaders (MPS) and CPU. Requirements tested on python 3.12.9:
 
-```
-
-
-
-
-
-
-pip install
+```shell
+git clone git@hf.co:Dogacel/DeepSeek-OCR-Metal-MPS
+cd DeepSeek-OCR-Metal-MPS/demo
+
+# Use mamba or conda
+mamba create -n deepseek-ocr python=3.12.9 -y
+mamba activate deepseek-ocr
+pip install -r requirements.txt
+
+python run_dpsk_ocr.py
 ```
 
 ```python
 from transformers import AutoModel, AutoTokenizer
 import torch
-import os
-os.environ["CUDA_VISIBLE_DEVICES"] = '0'
-model_name = 'deepseek-ai/DeepSeek-OCR'
 
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-model = AutoModel.from_pretrained(model_name, _attn_implementation='flash_attention_2', trust_remote_code=True, use_safetensors=True)
-model = model.eval().cuda().to(torch.bfloat16)
+model_name = 'Dogacel/DeepSeek-OCR-Metal-MPS'
 
-# prompt = "<image>\nFree OCR. "
-prompt = "<image>\n<|grounding|>Convert the document to markdown. "
-image_file = 'your_image.jpg'
-output_path = 'your/output/dir'
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
 
-# infer(self, tokenizer, prompt='', image_file='', output_path = ' ', base_size = 1024, image_size = 640, crop_mode = True, test_compress = False, save_results = False):
+model = AutoModel.from_pretrained(
+    model_name,
+    _attn_implementation='eager',
+    trust_remote_code=True,
+    use_safetensors=True,
+)
 
-# Tiny: base_size = 512, image_size = 512, crop_mode = False
-# Small: base_size = 640, image_size = 640, crop_mode = False
-# Base: base_size = 1024, image_size = 1024, crop_mode = False
-# Large: base_size = 1280, image_size = 1280, crop_mode = False
+device = torch.device("mps")
+dtype = torch.float16
 
-# Gundam: base_size = 1024, image_size = 640, crop_mode = True
+model = model.eval().to(device).to(dtype)
 
-
+prompt = "<image>\n<|grounding|>Convert the document to markdown. "
+image_file = 'image.png'
+output_path = 'results4'
+
+res = model.infer(
+    tokenizer,
+    device=device,
+    dtype=dtype,
+    prompt=prompt,
+    image_file=image_file,
+    output_path = output_path,
+    base_size=1024,
+    image_size=640,
+    crop_mode=False,
+    save_results = True,
+    test_compress = True,
+)
 ```
 
 ## vLLM
+
+> vLLM integration hasn't been tested yet.
+
 Refer to [🌟GitHub](https://github.com/deepseek-ai/DeepSeek-OCR/) for guidance on model inference acceleration and PDF processing, etc.
 
-[2025/10/23] 🚀🚀🚀 DeepSeek-OCR is now officially supported in upstream [vLLM](https://docs.vllm.ai/projects/recipes/en/latest/DeepSeek/DeepSeek-OCR.html#installing-vllm).
 ```shell
 uv venv
 source .venv/bin/activate
@@ -114,7 +94,7 @@ from PIL import Image
 
 # Create model instance
 llm = LLM(
-    model="deepseek-ai/DeepSeek-OCR",
+    model="Dogacel/DeepSeek-OCR-Metal-MPS",
     enable_prefix_caching=False,
     mm_processor_cache_gb=0,
     logits_processors=[NGramPerReqLogitsProcessor]
@@ -166,21 +146,4 @@ for output in model_outputs:
     <td><img src="assets/show3.jpg" style="width: 500px"></td>
     <td><img src="assets/show4.jpg" style="width: 500px"></td>
   </tr>
-</table>
-
-
-## Acknowledgement
-
-We would like to thank [Vary](https://github.com/Ucas-HaoranWei/Vary/), [GOT-OCR2.0](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/), [MinerU](https://github.com/opendatalab/MinerU), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [OneChart](https://github.com/LingyvKong/OneChart), [Slow Perception](https://github.com/Ucas-HaoranWei/Slow-Perception) for their valuable models and ideas.
-
-We also appreciate the benchmarks: [Fox](https://github.com/ucaslcl/Fox), [OminiDocBench](https://github.com/opendatalab/OmniDocBench).
-
-
-## Citation
-```bibtex
-@article{wei2025deepseek,
-  title={DeepSeek-OCR: Contexts Optical Compression},
-  author={Wei, Haoran and Sun, Yaofeng and Li, Yukun},
-  journal={arXiv preprint arXiv:2510.18234},
-  year={2025}
-}
+</table>
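The usage example above pins `torch.device("mps")`. For the CPU half of "support MPS / CPU inference", a minimal device/dtype selection sketch is shown below; it is illustrative only (not part of this commit) and assumes the `model.infer(..., device=..., dtype=...)` signature from the README diff:

```python
import torch

# Prefer the MPS backend when it is available, otherwise fall back to CPU.
# float16 is the conservative half-precision choice on MPS; on CPU, float32 is
# assumed here because half-precision CPU inference is typically slow.
if torch.backends.mps.is_available():
    device, dtype = torch.device("mps"), torch.float16
else:
    device, dtype = torch.device("cpu"), torch.float32

# Then reuse the README snippet unchanged:
# model = model.eval().to(device).to(dtype)
# res = model.infer(tokenizer, device=device, dtype=dtype, ...)
```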
deepencoder.py
CHANGED
@@ -1011,7 +1011,7 @@ def build_sam_vit_b(checkpoint=None):
         checkpoint=checkpoint,
     )
 
-def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
+def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.float16):
     image_encoder = build_sam_vit_b(checkpoint).eval().to(dtype)
     # sam = _apply_eval_dtype_sam(sam, dtype)
     image_encoder = torch.compile(image_encoder, mode=compile_mode)
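The change above makes `torch.float16` the default dtype for the fast SAM encoder. If one code path must serve CUDA, MPS, and CPU, a hypothetical helper like the sketch below could choose the dtype per backend; `pick_dtype` and its policy are assumptions, not part of this repository:

```python
import torch

def pick_dtype(device: torch.device) -> torch.dtype:
    # Assumed policy: bfloat16 on CUDA GPUs that support it, float16 on MPS,
    # float32 on CPU. The commit itself simply defaults to float16.
    if device.type == "cuda":
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    if device.type == "mps":
        return torch.float16
    return torch.float32
```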
demo/image.png
ADDED
Binary image added via Git LFS.
demo/requirements.txt
ADDED
@@ -0,0 +1,12 @@
+transformers==4.46.3
+tokenizers==0.20.3
+torch==2.9.1
+torchvision==0.24.1
+torchaudio==2.9.1
+PyMuPDF
+img2pdf
+einops
+easydict
+addict
+Pillow
+numpy
demo/run_dpsk_ocr.py
ADDED
@@ -0,0 +1,51 @@
+from transformers import AutoModel, AutoTokenizer
+import torch
+
+model_name = '../'
+
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+model = AutoModel.from_pretrained(
+    model_name,
+    _attn_implementation='eager',
+    trust_remote_code=True,
+    use_safetensors=True,
+)
+
+device = torch.device("mps")
+dtype = torch.float16
+
+model = model.eval().to(device).to(dtype)
+
+
+
+# prompt = "<image>\nFree OCR. "
+prompt = "<image>\n<|grounding|>Convert the document to markdown. "
+image_file = 'image.png'
+output_path = 'results'
+
+
+
+# infer(self, tokenizer, prompt='', image_file='', output_path = ' ', base_size = 1024, image_size = 640, crop_mode = True, test_compress = False, save_results = False):
+
+# Tiny: base_size = 512, image_size = 512, crop_mode = False
+# Small: base_size = 640, image_size = 640, crop_mode = False
+# Base: base_size = 1024, image_size = 1024, crop_mode = False
+# Large: base_size = 1280, image_size = 1280, crop_mode = False
+
+# Gundam: base_size = 1024, image_size = 640, crop_mode = True
+
+res = model.infer(
+    tokenizer,
+    device=device,
+    dtype=dtype,
+    prompt=prompt,
+    image_file=image_file,
+    output_path = output_path,
+    base_size=1024,
+    image_size=640,
+    crop_mode=False,
+    save_results = True,
+    test_compress = True,
+)
modeling_deepseekocr.py
CHANGED
@@ -501,8 +501,11 @@ class DeepseekOCRModel(DeepseekV2Model):
 if images_in_this_batch:
     images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
     # exit()
-
-
+    mask_indices = images_seq_mask[idx].nonzero(as_tuple=True)[0]
+    if len(mask_indices) == images_in_this_batch.shape[0]:
+        inputs_embeds[idx, mask_indices] = images_in_this_batch
+    else:
+        print(f"Size mismatch: mask has {len(mask_indices)} positions, but got {images_in_this_batch.shape[0]} features")
 
 idx += 1
 
@@ -652,7 +655,12 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 if attention_mask is not None and position_ids is None:
     # create position_ids on the fly for batch generation
     position_ids = attention_mask.long().cumsum(-1) - 1
-    position_ids.masked_fill_(attention_mask == 0, 1)
+
+    # position_ids.masked_fill_(attention_mask == 0, 1)
+    position_ids = torch.where(attention_mask == 0,
+                               torch.ones_like(position_ids),
+                               position_ids)
+
     if past_key_values:
         position_ids = position_ids[:, -input_ids.shape[1] :]
 
@@ -700,9 +708,12 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
 
-def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
+def infer(self, tokenizer, device=torch.device("mps"), dtype=torch.float16, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
     self.disable_torch_init()
 
+    self.target_device = device
+    self.target_dtype = dtype
+
     os.makedirs(output_path, exist_ok=True)
     os.makedirs(f'{output_path}/images', exist_ok=True)
 
@@ -799,9 +810,9 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
 
-images_list.append(image_transform(global_view).to(torch.bfloat16))
+images_list.append(image_transform(global_view).to(self.target_dtype))
 
-# global_view_tensor = image_transform(global_view).to(torch.bfloat16)
+# global_view_tensor = image_transform(global_view).to(self.dtype)
 
 width_crop_num, height_crop_num = crop_ratio
 
@@ -812,7 +823,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 """process the local views"""
 
 for i in range(len(images_crop_raw)):
-    images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))
+    images_crop_list.append(image_transform(images_crop_raw[i]).to(self.target_dtype))
 
 if image_size == 640:
     valid_img_tokens += len(images_crop_list) * 100
@@ -846,7 +857,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 # else:
 global_view = ImageOps.pad(image, (image_size, image_size),
                            color=tuple(int(x * 255) for x in image_transform.mean))
-images_list.append(image_transform(global_view).to(torch.bfloat16))
+images_list.append(image_transform(global_view).to(self.target_dtype))
 
 if base_size == 1024:
     valid_img_tokens += int(256 * ratio)
@@ -905,18 +916,19 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 if images_crop_list:
     images_crop = torch.stack(images_crop_list, dim=0)
 else:
-
+    images_crop = torch.zeros((1, 3, base_size, base_size))
 
 
+input_ids_tensor = input_ids.unsqueeze(0).to(self.target_device)
 
 if not eval_mode:
     streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
-    with torch.autocast("cuda", dtype=torch.bfloat16):
+    with torch.autocast(self.target_device.type, dtype=self.target_dtype):
         with torch.no_grad():
             output_ids = self.generate(
-                input_ids.unsqueeze(0).cuda(),
-                images=[(images_crop.cuda(), images_ori.cuda())],
-                images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                input_ids_tensor,
+                images=[(images_crop.to(self.target_device), images_ori.to(self.target_device))],
+                images_seq_mask = images_seq_mask.unsqueeze(0).to(self.target_device),
                 images_spatial_crop = images_spatial_crop,
                 # do_sample=False,
                 # num_beams = 1,
@@ -929,12 +941,12 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
             )
 
 else:
-    with torch.autocast("cuda", dtype=torch.bfloat16):
+    with torch.autocast(self.target_device.type, dtype=self.target_dtype):
         with torch.no_grad():
             output_ids = self.generate(
-                input_ids.unsqueeze(0).cuda(),
-                images=[(images_crop.cuda(), images_ori.cuda())],
-                images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                input_ids_tensor,
+                images=[(images_crop.to(self.target_device), images_ori.to(self.target_device))],
+                images_seq_mask = images_seq_mask.unsqueeze(0).to(self.target_device),
                 images_spatial_crop = images_spatial_crop,
                 # do_sample=False,
                 # num_beams = 1,
@@ -947,7 +959,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
 if '<image>' in conversation[0]['content'] and eval_mode:
-    outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+    outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
     stop_str = '<|end▁of▁sentence|>'
     if outputs.endswith(stop_str):
         outputs = outputs[:-len(stop_str)]
@@ -957,7 +969,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
     return outputs
 
 if '<image>' in conversation[0]['content'] and test_compress:
-    outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+    outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
     pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
     print('='*50)
     print('image size: ', (w, h))
@@ -968,7 +980,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
 if '<image>' in conversation[0]['content'] and save_results:
-    outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+    outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
     stop_str = '<|end▁of▁sentence|>'
 
     print('='*15 + 'save results:' + '='*15)
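The diff above replaces two in-place masked operations (a masked scatter into `inputs_embeds` and `masked_fill_` on `position_ids`) with an explicit index assignment and an out-of-place `torch.where`, presumably because the in-place masked ops are less reliable on the MPS backend. A small self-contained illustration of why the substitutions are equivalent, using toy shapes rather than the model's real tensors:

```python
import torch

# Toy stand-ins for inputs_embeds, images_seq_mask and the projected image features.
inputs_embeds = torch.zeros(4, 8)
images_seq_mask = torch.tensor([False, True, True, False])
image_features = torch.ones(2, 8)  # one row per True position in the mask

# Index assignment, as in the DeepseekOCRModel change above:
mask_indices = images_seq_mask.nonzero(as_tuple=True)[0]
if len(mask_indices) == image_features.shape[0]:
    inputs_embeds[mask_indices] = image_features  # same result as a masked scatter

# Out-of-place torch.where, as in the position_ids change above:
attention_mask = torch.tensor([[1, 1, 1, 0]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids = torch.where(attention_mask == 0,
                           torch.ones_like(position_ids),
                           position_ids)  # padding positions pinned to 1
```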
|