Dogacel committed
Commit 6b272b4 · 1 Parent(s): 9f30c71

support MPS / CPU inference
.gitattributes CHANGED
@@ -38,3 +38,4 @@ assets/show1.jpg filter=lfs diff=lfs merge=lfs -text
  assets/show2.jpg filter=lfs diff=lfs merge=lfs -text
  assets/show3.jpg filter=lfs diff=lfs merge=lfs -text
  assets/show4.jpg filter=lfs diff=lfs merge=lfs -text
+ demo/image.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -10,96 +10,76 @@ tags:
  license: mit
  library_name: transformers
  ---
- <div align="center">
-   <img src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/logo.svg?raw=true" width="60%" alt="DeepSeek AI" />
- </div>
- <hr>
- <div align="center">
-   <a href="https://www.deepseek.com/" target="_blank">
-     <img alt="Homepage" src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/badge.svg?raw=true" />
-   </a>
-   <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR" target="_blank">
-     <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
-   </a>
- </div>
- <div align="center">
-   <a href="https://discord.gg/Tc7c45Zzu5" target="_blank">
-     <img alt="Discord" src="https://img.shields.io/badge/Discord-DeepSeek%20AI-7289da?logo=discord&logoColor=white&color=7289da" />
-   </a>
-   <a href="https://twitter.com/deepseek_ai" target="_blank">
-     <img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-deepseek_ai-white?logo=x&logoColor=white" />
-   </a>
- </div>
- <p align="center">
-   <a href="https://github.com/deepseek-ai/DeepSeek-OCR"><b>🌟 Github</b></a> |
-   <a href="https://huggingface.co/deepseek-ai/DeepSeek-OCR"><b>📥 Model Download</b></a> |
-   <a href="https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek_OCR_paper.pdf"><b>📄 Paper Link</b></a> |
-   <a href="https://arxiv.org/abs/2510.18234"><b>📄 Arxiv Paper Link</b></a> |
- </p>
- <h2>
- <p align="center">
-   <a href="https://huggingface.co/papers/2510.18234">DeepSeek-OCR: Contexts Optical Compression</a>
- </p>
- </h2>
- <p align="center">
-   <img src="assets/fig1.png" style="width: 1000px" align=center>
- </p>
- <p align="center">
-   <a href="https://huggingface.co/papers/2510.18234">Explore the boundaries of visual-text compression.</a>
- </p>
+
+ # DeepSeek-OCR Apple Metal Performance Shaders (MPS) & CPU Support
+
+ This repository uses the weights from the original DeepSeek-OCR and modifies the model to support MPS and CPU inference.
+
+ - [Link to the original DeepSeek-OCR model](https://huggingface.co/deepseek-ai/DeepSeek-OCR)
+ - [Link to the original DeepSeek-OCR repository](https://github.com/deepseek-ai/DeepSeek-OCR)
 
  ## Usage
- Inference using Huggingface transformers on NVIDIA GPUs. Requirements tested on python 3.12.9 + CUDA11.8
+ Inference using Hugging Face Transformers on Metal Performance Shaders (MPS) and CPU. Requirements tested on Python 3.12.9:
 
- ```
- torch==2.6.0
- transformers==4.46.3
- tokenizers==0.20.3
- einops
- addict
- easydict
- pip install flash-attn==2.7.3 --no-build-isolation
+ ```shell
+ git clone git@hf.co:Dogacel/DeepSeek-OCR-Metal-MPS
+ cd DeepSeek-OCR-Metal-MPS/demo
+
+ # Use mamba or conda
+ mamba create -n deepseek-ocr python=3.12.9 -y
+ mamba activate deepseek-ocr
+ pip install -r requirements.txt
+
+ python run_dpsk_ocr.py
  ```
 
  ```python
  from transformers import AutoModel, AutoTokenizer
  import torch
- import os
- os.environ["CUDA_VISIBLE_DEVICES"] = '0'
- model_name = 'deepseek-ai/DeepSeek-OCR'
-
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
- model = AutoModel.from_pretrained(model_name, _attn_implementation='flash_attention_2', trust_remote_code=True, use_safetensors=True)
- model = model.eval().cuda().to(torch.bfloat16)
-
- # prompt = "<image>\nFree OCR. "
- prompt = "<image>\n<|grounding|>Convert the document to markdown. "
- image_file = 'your_image.jpg'
- output_path = 'your/output/dir'
-
- # infer(self, tokenizer, prompt='', image_file='', output_path = ' ', base_size = 1024, image_size = 640, crop_mode = True, test_compress = False, save_results = False):
-
- # Tiny: base_size = 512, image_size = 512, crop_mode = False
- # Small: base_size = 640, image_size = 640, crop_mode = False
- # Base: base_size = 1024, image_size = 1024, crop_mode = False
- # Large: base_size = 1280, image_size = 1280, crop_mode = False
-
- # Gundam: base_size = 1024, image_size = 640, crop_mode = True
-
- res = model.infer(tokenizer, prompt=prompt, image_file=image_file, output_path = output_path, base_size = 1024, image_size = 640, crop_mode=True, save_results = True, test_compress = True)
+
+ model_name = 'Dogacel/DeepSeek-OCR-Metal-MPS'
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ model = AutoModel.from_pretrained(
+     model_name,
+     _attn_implementation='eager',
+     trust_remote_code=True,
+     use_safetensors=True,
+ )
+
+ device = torch.device("mps")
+ dtype = torch.float16
+
+ model = model.eval().to(device).to(dtype)
+
+ prompt = "<image>\n<|grounding|>Convert the document to markdown. "
+ image_file = 'image.png'
+ output_path = 'results4'
+
+ res = model.infer(
+     tokenizer,
+     device=device,
+     dtype=dtype,
+     prompt=prompt,
+     image_file=image_file,
+     output_path=output_path,
+     base_size=1024,
+     image_size=640,
+     crop_mode=False,
+     save_results=True,
+     test_compress=True,
+ )
  ```
 
  ## vLLM
+
+ > vLLM integration hasn't been tested yet.
+
  Refer to [GitHub](https://github.com/deepseek-ai/DeepSeek-OCR/) for guidance on model inference acceleration and PDF processing, etc.
 
- [2025/10/23] 🚀🚀🚀 DeepSeek-OCR is now officially supported in upstream [vLLM](https://docs.vllm.ai/projects/recipes/en/latest/DeepSeek/DeepSeek-OCR.html#installing-vllm).
  ```shell
  uv venv
  source .venv/bin/activate
@@ -114,7 +94,7 @@ from PIL import Image
 
  # Create model instance
  llm = LLM(
-     model="deepseek-ai/DeepSeek-OCR",
+     model="Dogacel/DeepSeek-OCR-Metal-MPS",
      enable_prefix_caching=False,
      mm_processor_cache_gb=0,
      logits_processors=[NGramPerReqLogitsProcessor]
@@ -166,21 +146,4 @@ for output in model_outputs:
    <td><img src="assets/show3.jpg" style="width: 500px"></td>
    <td><img src="assets/show4.jpg" style="width: 500px"></td>
  </tr>
- </table>
-
- ## Acknowledgement
-
- We would like to thank [Vary](https://github.com/Ucas-HaoranWei/Vary/), [GOT-OCR2.0](https://github.com/Ucas-HaoranWei/GOT-OCR2.0/), [MinerU](https://github.com/opendatalab/MinerU), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [OneChart](https://github.com/LingyvKong/OneChart), [Slow Perception](https://github.com/Ucas-HaoranWei/Slow-Perception) for their valuable models and ideas.
-
- We also appreciate the benchmarks: [Fox](https://github.com/ucaslcl/Fox), [OminiDocBench](https://github.com/opendatalab/OmniDocBench).
-
- ## Citation
- ```bibtex
- @article{wei2025deepseek,
-   title={DeepSeek-OCR: Contexts Optical Compression},
-   author={Wei, Haoran and Sun, Yaofeng and Li, Yukun},
-   journal={arXiv preprint arXiv:2510.18234},
-   year={2025}
- }
+ </table>
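
Note on the new README snippet: it hard-codes `device = torch.device("mps")`, while the commit title also promises CPU inference. A minimal sketch (not part of the commit) of choosing the device at runtime with the standard `torch.backends.mps` check:

```python
import torch

# Minimal sketch, not shipped in this commit: use MPS when the backend is
# available in the installed wheel, otherwise fall back to CPU.
if torch.backends.mps.is_available():
    device, dtype = torch.device("mps"), torch.float16
else:
    # CPU autocast (used inside infer) accepts bfloat16, so use it as the fallback.
    device, dtype = torch.device("cpu"), torch.bfloat16

# Then, exactly as in the README snippet above:
# model = model.eval().to(device).to(dtype)
# res = model.infer(tokenizer, device=device, dtype=dtype, ...)
```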
deepencoder.py CHANGED
@@ -1011,7 +1011,7 @@ def build_sam_vit_b(checkpoint=None):
          checkpoint=checkpoint,
      )
 
- def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
+ def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.float16):
      image_encoder = build_sam_vit_b(checkpoint).eval().to(dtype)
      # sam = _apply_eval_dtype_sam(sam, dtype)
      image_encoder = torch.compile(image_encoder, mode=compile_mode)
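
The one-line change above moves the compiled SAM encoder's default dtype from `bfloat16` to `float16`, the half-precision type the MPS demo uses elsewhere in this commit. As an illustration only (the helper below is made up, not part of the commit), the same default could be derived from the target device:

```python
import torch

# Illustrative helper, not part of this commit: pick a half-precision dtype
# per backend, keeping float32 as the conservative CPU default.
def default_dtype(device: torch.device) -> torch.dtype:
    if device.type == "cuda":
        return torch.bfloat16  # the original upstream default
    if device.type == "mps":
        return torch.float16   # what this commit standardizes on
    return torch.float32
```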
demo/image.png ADDED

Git LFS Details

  • SHA256: e07688758d8bb7a9ca0c3c53d33ec5a06a6a9dec2edd7ffb8e3a32d4609de80d
  • Pointer size: 131 Bytes
  • Size of remote file: 365 kB
demo/requirements.txt ADDED
@@ -0,0 +1,12 @@
+ transformers==4.46.3
+ tokenizers==0.20.3
+ torch==2.9.1
+ torchvision==0.24.1
+ torchaudio==2.9.1
+ PyMuPDF
+ img2pdf
+ einops
+ easydict
+ addict
+ Pillow
+ numpy
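
Two things stand out in this list: the `flash-attn` dependency from the original README (which requires CUDA) is gone, since the demo loads the model with `_attn_implementation='eager'`, and the pinned `torch` build needs MPS support for the demo to use Metal. A quick sanity check before running the demo (a suggestion, not part of the commit):

```python
import torch

# Both should print True on an Apple-silicon build of PyTorch; if MPS is not
# available, run the demo on CPU instead.
print(torch.backends.mps.is_built())      # the wheel was compiled with MPS support
print(torch.backends.mps.is_available())  # the backend is usable on this machine
```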
demo/run_dpsk_ocr.py ADDED
@@ -0,0 +1,51 @@
+ from transformers import AutoModel, AutoTokenizer
+ import torch
+
+ model_name = '../'
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ model = AutoModel.from_pretrained(
+     model_name,
+     _attn_implementation='eager',
+     trust_remote_code=True,
+     use_safetensors=True,
+ )
+
+ device = torch.device("mps")
+ dtype = torch.float16
+
+ model = model.eval().to(device).to(dtype)
+
+ # prompt = "<image>\nFree OCR. "
+ prompt = "<image>\n<|grounding|>Convert the document to markdown. "
+ image_file = 'image.png'
+ output_path = 'results'
+
+ # infer(self, tokenizer, prompt='', image_file='', output_path = ' ', base_size = 1024, image_size = 640, crop_mode = True, test_compress = False, save_results = False):
+
+ # Tiny: base_size = 512, image_size = 512, crop_mode = False
+ # Small: base_size = 640, image_size = 640, crop_mode = False
+ # Base: base_size = 1024, image_size = 1024, crop_mode = False
+ # Large: base_size = 1280, image_size = 1280, crop_mode = False
+
+ # Gundam: base_size = 1024, image_size = 640, crop_mode = True
+
+ res = model.infer(
+     tokenizer,
+     device=device,
+     dtype=dtype,
+     prompt=prompt,
+     image_file=image_file,
+     output_path=output_path,
+     base_size=1024,
+     image_size=640,
+     crop_mode=False,
+     save_results=True,
+     test_compress=True,
+ )
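
The script above pins the device to MPS. For a CPU-only run (an assumption based on the commit title; the demo itself ships no CPU variant), only the device and dtype lines would need to change:

```python
import torch

# Hypothetical CPU variant of the demo's device/dtype selection. bfloat16 is
# chosen because infer() wraps generation in torch.autocast, and CPU autocast
# supports bfloat16; expect noticeably slower inference than on MPS.
device = torch.device("cpu")
dtype = torch.bfloat16

# model = model.eval().to(device).to(dtype)
# res = model.infer(tokenizer, device=device, dtype=dtype, prompt=prompt, ...)
```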
modeling_deepseekocr.py CHANGED
@@ -501,8 +501,11 @@ class DeepseekOCRModel(DeepseekV2Model):
          if images_in_this_batch:
              images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
              # exit()
-
-             inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
+             mask_indices = images_seq_mask[idx].nonzero(as_tuple=True)[0]
+             if len(mask_indices) == images_in_this_batch.shape[0]:
+                 inputs_embeds[idx, mask_indices] = images_in_this_batch
+             else:
+                 print(f"Size mismatch: mask has {len(mask_indices)} positions, but got {images_in_this_batch.shape[0]} features")
 
          idx += 1
 
@@ -652,7 +655,12 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
      if attention_mask is not None and position_ids is None:
          # create position_ids on the fly for batch generation
          position_ids = attention_mask.long().cumsum(-1) - 1
-         position_ids.masked_fill_(attention_mask == 0, 1)
+
+         # position_ids.masked_fill_(attention_mask == 0, 1)
+         position_ids = torch.where(attention_mask == 0,
+                                    torch.ones_like(position_ids),
+                                    position_ids)
+
          if past_key_values:
              position_ids = position_ids[:, -input_ids.shape[1] :]
 
@@ -700,9 +708,12 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
- def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
+ def infer(self, tokenizer, device=torch.device("mps"), dtype=torch.float16, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
      self.disable_torch_init()
 
+     self.target_device = device
+     self.target_dtype = dtype
+
      os.makedirs(output_path, exist_ok=True)
      os.makedirs(f'{output_path}/images', exist_ok=True)
 
@@ -799,9 +810,9 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
-         images_list.append(image_transform(global_view).to(torch.bfloat16))
+         images_list.append(image_transform(global_view).to(self.target_dtype))
 
-         # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
+         # global_view_tensor = image_transform(global_view).to(self.dtype)
 
          width_crop_num, height_crop_num = crop_ratio
 
@@ -812,7 +823,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
          """process the local views"""
 
          for i in range(len(images_crop_raw)):
-             images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))
+             images_crop_list.append(image_transform(images_crop_raw[i]).to(self.target_dtype))
 
          if image_size == 640:
              valid_img_tokens += len(images_crop_list) * 100
@@ -846,7 +857,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
          # else:
          global_view = ImageOps.pad(image, (image_size, image_size),
                                     color=tuple(int(x * 255) for x in image_transform.mean))
-         images_list.append(image_transform(global_view).to(torch.bfloat16))
+         images_list.append(image_transform(global_view).to(self.target_dtype))
 
          if base_size == 1024:
              valid_img_tokens += int(256 * ratio)
@@ -905,18 +916,19 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
      if images_crop_list:
          images_crop = torch.stack(images_crop_list, dim=0)
      else:
-         images_crop = torch.zeros((1, 3, base_size, base_size))
+         images_crop = torch.zeros((1, 3, base_size, base_size))
 
+     input_ids_tensor = input_ids.unsqueeze(0).to(self.target_device)
 
      if not eval_mode:
          streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
-         with torch.autocast("cuda", dtype=torch.bfloat16):
+         with torch.autocast(self.target_device.type, dtype=self.target_dtype):
              with torch.no_grad():
                  output_ids = self.generate(
-                     input_ids.unsqueeze(0).cuda(),
-                     images=[(images_crop.cuda(), images_ori.cuda())],
-                     images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                     input_ids_tensor,
+                     images=[(images_crop.to(self.target_device), images_ori.to(self.target_device))],
+                     images_seq_mask = images_seq_mask.unsqueeze(0).to(self.target_device),
                      images_spatial_crop = images_spatial_crop,
                      # do_sample=False,
                      # num_beams = 1,
@@ -929,12 +941,12 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
          )
 
      else:
-         with torch.autocast("cuda", dtype=torch.bfloat16):
+         with torch.autocast(self.target_device.type, dtype=self.target_dtype):
             with torch.no_grad():
                 output_ids = self.generate(
-                    input_ids.unsqueeze(0).cuda(),
-                    images=[(images_crop.cuda(), images_ori.cuda())],
-                    images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
+                    input_ids_tensor,
+                    images=[(images_crop.to(self.target_device), images_ori.to(self.target_device))],
+                    images_seq_mask = images_seq_mask.unsqueeze(0).to(self.target_device),
                     images_spatial_crop = images_spatial_crop,
                     # do_sample=False,
                     # num_beams = 1,
@@ -947,7 +959,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
      if '<image>' in conversation[0]['content'] and eval_mode:
-         outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+         outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
          stop_str = '<|end▁of▁sentence|>'
          if outputs.endswith(stop_str):
              outputs = outputs[:-len(stop_str)]
@@ -957,7 +969,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
          return outputs
 
      if '<image>' in conversation[0]['content'] and test_compress:
-         outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+         outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
          pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
          print('='*50)
          print('image size: ', (w, h))
@@ -968,7 +980,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
      if '<image>' in conversation[0]['content'] and save_results:
-         outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
+         outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(self.device).shape[1]:])
          stop_str = '<|end▁of▁sentence|>'
 
          print('='*15 + 'save results:' + '='*15)
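
Both rewrites in this file follow one pattern: the in-place masked ops used upstream (`masked_scatter_` on the input embeddings, `masked_fill_` on the position ids) are replaced with explicit index assignment and `torch.where`, presumably to sidestep MPS quirks with those in-place kernels. Either way, the replacements compute the same values; a toy check with made-up tensors:

```python
import torch

# The torch.where form yields the same result as the in-place masked_fill_
# it replaces in prepare_inputs_for_generation.
attention_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])
position_ids = attention_mask.long().cumsum(-1) - 1

filled = position_ids.clone()
filled.masked_fill_(attention_mask == 0, 1)

where_version = torch.where(attention_mask == 0,
                            torch.ones_like(position_ids),
                            position_ids)
assert torch.equal(filled, where_version)

# Same idea for the embedding scatter: writing image features through integer
# indices matches masked_scatter_ when the mask selects exactly that many rows.
embeds = torch.zeros(5, 4)
mask = torch.tensor([False, True, False, True, False])
features = torch.arange(8, dtype=torch.float32).reshape(2, 4)

scattered = embeds.clone()
scattered.masked_scatter_(mask.unsqueeze(-1), features)

indexed = embeds.clone()
indexed[mask.nonzero(as_tuple=True)[0]] = features
assert torch.equal(scattered, indexed)
```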