Upload 19 files
Browse files- NV_LICENSE +35 -0
- README.md +184 -0
- auto_processor.py +495 -0
- base_projector.py +228 -0
- builder.py +247 -0
- config.json +295 -0
- configuration_vila.py +92 -0
- constants.py +83 -0
- conversation.py +191 -0
- distributed.py +73 -0
- loss.py +48 -0
- media.py +130 -0
- media_encoder.py +158 -0
- mm_utils.py +575 -0
- model_utils_packing.py +35 -0
- modeling_vila.py +1256 -0
- siglip_encoder.py +286 -0
- tokenizer_utils.py +181 -0
- utils.py +211 -0
NV_LICENSE
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
NVIDIA License
|
| 2 |
+
|
| 3 |
+
1. Definitions
|
| 4 |
+
|
| 5 |
+
“Licensor” means any person or entity that distributes its Work.
|
| 6 |
+
“Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
|
| 7 |
+
The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
|
| 8 |
+
Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
|
| 9 |
+
|
| 10 |
+
2. License Grant
|
| 11 |
+
|
| 12 |
+
2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
|
| 13 |
+
|
| 14 |
+
3. Limitations
|
| 15 |
+
|
| 16 |
+
3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
|
| 17 |
+
|
| 18 |
+
3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
|
| 19 |
+
|
| 20 |
+
3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or educational purposes only.
|
| 21 |
+
|
| 22 |
+
3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
|
| 23 |
+
|
| 24 |
+
3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
|
| 25 |
+
|
| 26 |
+
3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
|
| 27 |
+
|
| 28 |
+
4. Disclaimer of Warranty.
|
| 29 |
+
|
| 30 |
+
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
|
| 31 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
|
| 32 |
+
|
| 33 |
+
5. Limitation of Liability.
|
| 34 |
+
|
| 35 |
+
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
README.md
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LongVILA-R1-7B
|
| 2 |
+
[](https://arxiv.org/abs/2507.07966)
|
| 3 |
+
[](https://github.com/NVlabs/Long-RL)
|
| 4 |
+
[](https://huggingface.co/Efficient-Large-Model/LongVILA-R1-7B)
|
| 5 |
+
[](https://www.youtube.com/watch?v=ykbblK2jiEg)
|
| 6 |
+
[](https://6d8b5579459b555d59.gradio.live)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
## Introduction:
|
| 10 |
+
<p>
|
| 11 |
+
<strong>LongVILA-R1-7B</strong> supports both <u>multiple-choice</u> questions and <u>open-ended</u> questions. It can switch between thinking and non-thinking modes.<br>
|
| 12 |
+
<strong>LongVILA-R1-7B</strong> demonstrates strong performance in long video reasoning, achieving <strong>70.7%</strong> on VideoMME (w/ sub.) and surpassing Gemini-1.5-Pro across diverse reasoning tasks.<br>
|
| 13 |
+
<strong>Long-RL</strong> is a codebase that accelerates long video RL training by up to <strong>2.1×</strong> through its MR-SP system. It supports RL training on image, video, and omni inputs across VILA, Qwen/Qwen-VL, and diffusion models.
|
| 14 |
+
</p>
|
| 15 |
+
|
| 16 |
+
## Evaluation:
|
| 17 |
+
### Video QA Benchmarks
|
| 18 |
+
| Models | VideoMME (w/o sub) | VideoMME (w sub) | ActivityNet-QA (test) | LongVideoBench (val) | PerceptionTest (val) | NExT-QA | VNBench (val) |
|
| 19 |
+
|:-------------------|:------------------:|:----------------:|:---------------------:|:--------------------:|:--------------------:|:--------:|:-------------:|
|
| 20 |
+
| **LongVILA-7B** | **60.1** | **65.1** | **59.5** | **57.1** | **58.1** | **80.7** | **63.0** |
|
| 21 |
+
| **LongVILA-R1-7B** | **65.0** | **70.7** | **64.8** | **58.0** | **68.9** | **81.5** | **75.5** |
|
| 22 |
+
|
| 23 |
+
### LongVideo-Reason-eval
|
| 24 |
+
| Models | Temporal | Goal | Plot | Spatial | Overall|
|
| 25 |
+
| :--- | :---: | :---: | :---: | :---: | :---: |
|
| 26 |
+
| | | |
|
| 27 |
+
| **LongVILA-R1-7B** | **68.1** | **85.7** | **70.6** | **53.3** | **72.0** |
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
## Usage
|
| 31 |
+
|
| 32 |
+
### Generation
|
| 33 |
+
```python
|
| 34 |
+
from transformers import AutoModel
|
| 35 |
+
|
| 36 |
+
model_path = "Efficient-Large-Model/LongVILA-R1-7B"
|
| 37 |
+
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
|
| 38 |
+
|
| 39 |
+
use_thinking = True # Switching between thinking and non-thinking modes
|
| 40 |
+
system_prompt_thinking = "You are a helpful assistant. The user asks a question, and then you solves it.\n\nPlease first think deeply about the question based on the given video, and then provide the final answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.\n\n Question: {question}"
|
| 41 |
+
|
| 42 |
+
prompt = "What is the main purpose of the video?"
|
| 43 |
+
video_path = "video.mp4"
|
| 44 |
+
|
| 45 |
+
if use_thinking:
|
| 46 |
+
prompt = system_prompt_thinking.format(question=prompt)
|
| 47 |
+
|
| 48 |
+
response = model.generate_content([prompt, {"path": video_path}])
|
| 49 |
+
print("Response: ", response)
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
### with vLLM engine
|
| 53 |
+
Tested on `vllm==0.9.1`. We need to get the remote code first.
|
| 54 |
+
```bash
|
| 55 |
+
mkdir remote_code
|
| 56 |
+
cp path_to/Efficient-Large-Model/LongVILA-R1-7B/*.py remote_code
|
| 57 |
+
```
|
| 58 |
+
Then, you can use the following code for model generation.
|
| 59 |
+
```python
|
| 60 |
+
import os
|
| 61 |
+
from transformers import AutoModel
|
| 62 |
+
from vllm import LLM, SamplingParams
|
| 63 |
+
from remote_code.media import extract_media
|
| 64 |
+
from remote_code.mm_utils import process_images
|
| 65 |
+
from remote_code.tokenizer_utils import tokenize_conversation
|
| 66 |
+
|
| 67 |
+
model_path = "path_to/Efficient-Large-Model/LongVILA-R1-7B"
|
| 68 |
+
|
| 69 |
+
model_encoder = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto", llm_only_need_embed=True)
|
| 70 |
+
# you can change gpu_memory_utilization according to GPU memory
|
| 71 |
+
llm = LLM(model=os.path.join(model_path, "llm"), enable_prompt_embeds=True, gpu_memory_utilization=0.5)
|
| 72 |
+
|
| 73 |
+
use_thinking = True # Switching between thinking and non-thinking modes
|
| 74 |
+
system_prompt_thinking = "You are a helpful assistant. The user asks a question, and then you solves it.\n\nPlease first think deeply about the question based on the given video, and then provide the final answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.\n\n Question: {question}"
|
| 75 |
+
|
| 76 |
+
prompt = "What is the main purpose of the video?"
|
| 77 |
+
video_path = "video.mp4"
|
| 78 |
+
|
| 79 |
+
if use_thinking:
|
| 80 |
+
prompt = system_prompt_thinking.format(question=prompt)
|
| 81 |
+
|
| 82 |
+
conversation = [{"from": "human", "value": [prompt, {"path": video_path}]}]
|
| 83 |
+
media = extract_media(conversation, model_encoder.config)
|
| 84 |
+
input_ids = tokenize_conversation(conversation, model_encoder.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()
|
| 85 |
+
media["video"] = [
|
| 86 |
+
process_images(images, model_encoder.vision_tower.image_processor, model_encoder.config).half()
|
| 87 |
+
for images in media["video"]
|
| 88 |
+
]
|
| 89 |
+
|
| 90 |
+
inputs_embeds, _, _ = model_encoder._embed(input_ids, media, {"video": {}}, None, None)
|
| 91 |
+
|
| 92 |
+
completions = llm.generate(prompts=[{"prompt_embeds": inputs_embeds.squeeze(0)}], sampling_params=SamplingParams(max_tokens=1024))
|
| 93 |
+
response = completions[0].outputs[0].text
|
| 94 |
+
print("Response: ", response)
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# LongVILA-R1 Model Card
|
| 99 |
+
|
| 100 |
+
## Model details
|
| 101 |
+
|
| 102 |
+
**Model type:**
|
| 103 |
+
LongVILA-R1 addresses the unique challenges of long video reasoning by integrating three critical components: (1) a large-scale dataset, LongVideo-Reason, comprising 104K long video QA pairs with high-quality reasoning annotations across diverse domains such as sports, games, and vlogs; (2) a two-stage training pipeline that extends VLMs with chain-of-thought supervised fine-tuning (CoT-SFT) and reinforcement learning (RL); and (3) a training infrastructure for long video RL, named Multi-modal Reinforcement Sequence Parallelism (MR-SP), which incorporates sequence parallelism and a vLLM-based engine tailored for long video, using cached video embeddings for efficient rollout and prefilling. In our experiments, LongVILA-R1-7B achieves strong performance on video benchmarks, reaching 65.0% and 70.7% accuracy on VideoMME without and with subtitles, respectively, and consistently outperforming LongVILA-R1 across multiple benchmarks. Moreover, LongVILA-R1 shows steady performance improvements as the number of input video frames increases.
|
| 104 |
+
**Model date:**
|
| 105 |
+
LongVILA-R1-7B was trained in July 2025.
|
| 106 |
+
|
| 107 |
+
**Paper or resources for more information:**
|
| 108 |
+
- Paper https://arxiv.org/abs/2507.07966
|
| 109 |
+
- Code https://github.com/NVLabs/Long-RL
|
| 110 |
+
- Model https://huggingface.co/Efficient-Large-Model/LongVILA-R1-7B
|
| 111 |
+
- Video https://www.youtube.com/watch?v=ykbblK2jiEg
|
| 112 |
+
- Demo https://6d8b5579459b555d59.gradio.live
|
| 113 |
+
|
| 114 |
+
```bibtex
|
| 115 |
+
@misc{long-rl,
|
| 116 |
+
title = {Long-RL: Scaling RL to Long Sequences},
|
| 117 |
+
author = {Yukang Chen, Wei Huang, Shuai Yang, Qinghao Hu, Baifeng Shi, Hanrong Ye, Ligeng Zhu, Zhijian Liu, Pavlo Molchanov, Jan Kautz, Xiaojuan Qi, Sifei Liu,Hongxu Yin, Yao Lu, Song Han},
|
| 118 |
+
year = {2025},
|
| 119 |
+
publisher = {GitHub},
|
| 120 |
+
journal = {GitHub repository},
|
| 121 |
+
howpublished = {\url{https://github.com/NVlabs/Long-RL}},
|
| 122 |
+
}
|
| 123 |
+
```
|
| 124 |
+
```bibtex
|
| 125 |
+
@article{chen2025longvila-r1,
|
| 126 |
+
title={Scaling RL to Long Videos},
|
| 127 |
+
author={Yukang Chen and Wei Huang and Baifeng Shi and Qinghao Hu and Hanrong Ye and Ligeng Zhu and Zhijian Liu and Pavlo Molchanov and Jan Kautz and Xiaojuan Qi and Sifei Liu and Hongxu Yin and Yao Lu and Song Han},
|
| 128 |
+
year={2025},
|
| 129 |
+
eprint={2507.07966},
|
| 130 |
+
archivePrefix={arXiv},
|
| 131 |
+
primaryClass={cs.CV}
|
| 132 |
+
}
|
| 133 |
+
```
|
| 134 |
+
```bibtex
|
| 135 |
+
@inproceedings{chen2024longvila,
|
| 136 |
+
title={LongVILA: Scaling Long-Context Visual Language Models for Long Videos},
|
| 137 |
+
author={Yukang Chen and Fuzhao Xue and Dacheng Li and Qinghao Hu and Ligeng Zhu and Xiuyu Li and Yunhao Fang and Haotian Tang and Shang Yang and Zhijian Liu and Ethan He and Hongxu Yin and Pavlo Molchanov and Jan Kautz and Linxi Fan and Yuke Zhu and Yao Lu and Song Han},
|
| 138 |
+
booktitle={The International Conference on Learning Representations (ICLR)},
|
| 139 |
+
year={2025},
|
| 140 |
+
}
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
## License
|
| 144 |
+
- The weights are released under the [CC-BY-NC-SA-4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).
|
| 145 |
+
- The service is a research preview intended for non-commercial use only, and is subject to the following licenses and terms:
|
| 146 |
+
- [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI
|
| 147 |
+
- [Dataset Licenses](https://github.com/Efficient-Large-Model/VILA/blob/main/data_prepare/LICENSE) for each one used during training.
|
| 148 |
+
- [NVIDIA Licenses](https://huggingface.co/Efficient-Large-Model/LongVILA-R1-7B/blob/main/NV_LICENSE)
|
| 149 |
+
|
| 150 |
+
**Where to send questions or comments about the model:**
|
| 151 |
+
https://github.com/NVLabs/Long-RL/issues
|
| 152 |
+
|
| 153 |
+
## Intended use
|
| 154 |
+
**Primary intended uses:**
|
| 155 |
+
The primary use of LongVILA-R1 is research on large multimodal models and chatbots.
|
| 156 |
+
|
| 157 |
+
**Primary intended users:**
|
| 158 |
+
The primary intended users of the model are researchers and hobbyists in computer vision, natural language processing, machine learning, and artificial intelligence.
|
| 159 |
+
|
| 160 |
+
## Input:
|
| 161 |
+
**Input Type:** Video and Text
|
| 162 |
+
**Input Format:** MP4 and other video fromats
|
| 163 |
+
|
| 164 |
+
## Output:
|
| 165 |
+
**Output Type:** Text
|
| 166 |
+
**Output Format:** String
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
**[Preferred/Supported] Operating System(s):** <br>
|
| 170 |
+
Linux
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
## Inference:
|
| 174 |
+
**Engine:** [Tensor(RT), Triton, Or List Other Here]
|
| 175 |
+
* PyTorch
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
**Test Hardware:**
|
| 179 |
+
* A100
|
| 180 |
+
* H100
|
| 181 |
+
* A6000
|
| 182 |
+
|
| 183 |
+
## Ethical Considerations
|
| 184 |
+
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse.
|
auto_processor.py
ADDED
|
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import os
|
| 3 |
+
import os.path as osp
|
| 4 |
+
import warnings
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
from typing import List, Optional, Union
|
| 8 |
+
|
| 9 |
+
import PIL.Image
|
| 10 |
+
import requests
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoProcessor, AutoTokenizer
|
| 13 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 14 |
+
from transformers.image_utils import ImageInput
|
| 15 |
+
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
| 16 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 17 |
+
from transformers.utils import logging
|
| 18 |
+
|
| 19 |
+
from .constants import DEFAULT_IMAGE_TOKEN, MEDIA_TOKENS
|
| 20 |
+
from .media import Image, Video, extract_media
|
| 21 |
+
from .mm_utils import process_image, process_images
|
| 22 |
+
from .tokenizer_utils import tokenize_conversation
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def to_rgb(pil_image: PIL.Image.Image) -> PIL.Image.Image:
|
| 26 |
+
if pil_image.mode == "RGBA":
|
| 27 |
+
white_background = PIL.Image.new("RGB", pil_image.size, (255, 255, 255))
|
| 28 |
+
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
|
| 29 |
+
return white_background
|
| 30 |
+
else:
|
| 31 |
+
return pil_image.convert("RGB")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def fetch_image(ele: dict[str, str | PIL.Image.Image], size_factor=None) -> PIL.Image.Image:
|
| 35 |
+
if "image" in ele:
|
| 36 |
+
image = ele["image"]
|
| 37 |
+
else:
|
| 38 |
+
image = ele["image_url"]
|
| 39 |
+
image_obj = None
|
| 40 |
+
if isinstance(image, PIL.Image.Image):
|
| 41 |
+
image_obj = image
|
| 42 |
+
elif image.startswith("http://") or image.startswith("https://"):
|
| 43 |
+
response = requests.get(image, stream=True)
|
| 44 |
+
image_obj = PIL.Image.open(BytesIO(response.content))
|
| 45 |
+
elif image.startswith("file://"):
|
| 46 |
+
image_obj = PIL.Image.open(image[7:])
|
| 47 |
+
elif image.startswith("data:image"):
|
| 48 |
+
if "base64," in image:
|
| 49 |
+
_, base64_data = image.split("base64,", 1)
|
| 50 |
+
data = base64.b64decode(base64_data)
|
| 51 |
+
image_obj = PIL.Image.open(BytesIO(data))
|
| 52 |
+
else:
|
| 53 |
+
image_obj = PIL.Image.open(image)
|
| 54 |
+
if image_obj is None:
|
| 55 |
+
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
|
| 56 |
+
image = to_rgb(image_obj)
|
| 57 |
+
|
| 58 |
+
return image
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def fetch_image_url_or_fpath(url_or_fpath):
|
| 62 |
+
if url_or_fpath.startswith("http") or url_or_fpath.startswith("https"):
|
| 63 |
+
import tempfile
|
| 64 |
+
|
| 65 |
+
import requests
|
| 66 |
+
|
| 67 |
+
# Download the image to a temporary file
|
| 68 |
+
temp_dir = tempfile.mkdtemp()
|
| 69 |
+
temp_file = os.path.join(temp_dir, os.path.basename(url_or_fpath))
|
| 70 |
+
|
| 71 |
+
response = requests.get(url_or_fpath, stream=True)
|
| 72 |
+
response.raise_for_status()
|
| 73 |
+
|
| 74 |
+
with open(temp_file, "wb") as f:
|
| 75 |
+
for chunk in response.iter_content(chunk_size=8192):
|
| 76 |
+
f.write(chunk)
|
| 77 |
+
|
| 78 |
+
return temp_file
|
| 79 |
+
elif url_or_fpath.startswith("file://"):
|
| 80 |
+
fpath = url_or_fpath.replace("file://", "")
|
| 81 |
+
assert osp.exists(fpath), f"File {fpath} does not exist"
|
| 82 |
+
return fpath
|
| 83 |
+
elif osp.exists(url_or_fpath):
|
| 84 |
+
assert osp.isfile(url_or_fpath), f"File {url_or_fpath} does not exist"
|
| 85 |
+
return url_or_fpath
|
| 86 |
+
else:
|
| 87 |
+
raise ValueError(f"Unsupported image path: {url_or_fpath}")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def pad_fn(input_ids_list: List[torch.Tensor], padding_value=0, target_len=None, padding_side="left") -> torch.Tensor:
|
| 91 |
+
# tensor shape is (batch_size, seq_len)
|
| 92 |
+
max_len = max([ids.shape[1] for ids in input_ids_list])
|
| 93 |
+
if target_len is not None:
|
| 94 |
+
assert target_len >= max_len, "target_len must be greater than or equal to max_len"
|
| 95 |
+
max_len = target_len
|
| 96 |
+
|
| 97 |
+
new_input_ids_list = []
|
| 98 |
+
for i, input_ids in enumerate(input_ids_list):
|
| 99 |
+
pad_tensor = torch.ones_like(input_ids) * padding_value
|
| 100 |
+
curr_len = input_ids.shape[1]
|
| 101 |
+
pad_tensor = pad_tensor[:, : max_len - curr_len]
|
| 102 |
+
if padding_side == "right":
|
| 103 |
+
input_ids = torch.cat((input_ids, pad_tensor), dim=1)
|
| 104 |
+
else:
|
| 105 |
+
input_ids = torch.cat((pad_tensor, input_ids), dim=1)
|
| 106 |
+
new_input_ids_list.append(input_ids)
|
| 107 |
+
return torch.cat(new_input_ids_list, dim=0)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def extract_value_from_conv(chat):
|
| 111 |
+
value = []
|
| 112 |
+
if isinstance(chat["content"], str):
|
| 113 |
+
# vila_chat["value"].append(chat["content"])
|
| 114 |
+
value.append(chat["content"])
|
| 115 |
+
return value
|
| 116 |
+
|
| 117 |
+
# otherwise, it's a list of content
|
| 118 |
+
for content in chat["content"]:
|
| 119 |
+
if content["type"] == "image":
|
| 120 |
+
if "path" in content:
|
| 121 |
+
# VILA style, can be either filepath or http url
|
| 122 |
+
value.append(Image(fetch_image_url_or_fpath(content["path"])))
|
| 123 |
+
elif "image" in content:
|
| 124 |
+
# Qwen style
|
| 125 |
+
value.append(Image(fetch_image_url_or_fpath(content["image"])))
|
| 126 |
+
elif "image_pil" in content:
|
| 127 |
+
# Qwen style
|
| 128 |
+
assert isinstance(content["image_pil"], PIL.Image.Image), f"Type of {media_key} must be PIL.Image.Image"
|
| 129 |
+
value.append(content["image_pil"])
|
| 130 |
+
else:
|
| 131 |
+
raise ValueError(f"Type = `image` , but no `path` or `image` in | {content=}, {conversation=}")
|
| 132 |
+
elif content["type"] == "video":
|
| 133 |
+
if "video" in content:
|
| 134 |
+
# Qwen style
|
| 135 |
+
value.append(Video(fetch_image_url_or_fpath(content["video"])))
|
| 136 |
+
else:
|
| 137 |
+
raise ValueError(f"Type = `video` , but no `video` in | {content=}, {conversation=}")
|
| 138 |
+
elif content["type"] == "text":
|
| 139 |
+
value.append(content["text"])
|
| 140 |
+
else:
|
| 141 |
+
raise ValueError(f"Unsupported content type: {content['type']}")
|
| 142 |
+
return value
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class VILAProcessorKwargs(ProcessingKwargs, total=False):
|
| 146 |
+
_defaults = {
|
| 147 |
+
"text_kwargs": {
|
| 148 |
+
"padding": False,
|
| 149 |
+
},
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class VILAProcessor(ProcessorMixin):
|
| 154 |
+
# attributes = ["image_processor", "tokenizer"]
|
| 155 |
+
attributes = []
|
| 156 |
+
# valid_kwargs = ["chat_template"]
|
| 157 |
+
valid_kwargs = []
|
| 158 |
+
# image_processor_class = "VILAImageProcessor"
|
| 159 |
+
# tokenizer_class = ("VILATokenizer", "VILATokenizerFast")
|
| 160 |
+
|
| 161 |
+
def __init__(
|
| 162 |
+
self, image_processor=None, tokenizer=None, chat_template=None, config=None, padding_side="left", **kwargs
|
| 163 |
+
):
|
| 164 |
+
self.image_token = MEDIA_TOKENS["image"]
|
| 165 |
+
self.video_token = MEDIA_TOKENS["video"]
|
| 166 |
+
self.config = config
|
| 167 |
+
self.image_processor = image_processor
|
| 168 |
+
self.tokenizer = tokenizer
|
| 169 |
+
self.padding_side = padding_side
|
| 170 |
+
|
| 171 |
+
# This is a special setting for Qwen.
|
| 172 |
+
# self.pad_token_id = tokenizer.pad_token_id
|
| 173 |
+
self.pad_token_id = self.tokenizer("<|endoftext|>").input_ids[0] # 151643
|
| 174 |
+
self.eos_token_id = self.tokenizer.eos_token_id
|
| 175 |
+
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
| 176 |
+
|
| 177 |
+
@staticmethod
|
| 178 |
+
def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
|
| 179 |
+
"""
|
| 180 |
+
referernce from qwen_vl_utils
|
| 181 |
+
"""
|
| 182 |
+
vision_infos = []
|
| 183 |
+
if isinstance(conversations[0], dict):
|
| 184 |
+
conversations = [conversations]
|
| 185 |
+
for conversation in conversations:
|
| 186 |
+
for message in conversation:
|
| 187 |
+
if isinstance(message["content"], list):
|
| 188 |
+
for ele in message["content"]:
|
| 189 |
+
if (
|
| 190 |
+
"image" in ele
|
| 191 |
+
or "image_url" in ele
|
| 192 |
+
or "video" in ele
|
| 193 |
+
or ele["type"] in ("image", "image_url", "video")
|
| 194 |
+
):
|
| 195 |
+
vision_infos.append(ele)
|
| 196 |
+
return vision_infos
|
| 197 |
+
|
| 198 |
+
@staticmethod
|
| 199 |
+
def process_vision_info(
|
| 200 |
+
conversations: list[dict] | list[list[dict]],
|
| 201 |
+
return_video_kwargs: bool = False,
|
| 202 |
+
) -> tuple[list[PIL.Image.Image] | None, list[torch.Tensor | list[PIL.Image.Image]] | None, Optional[dict]]:
|
| 203 |
+
"""
|
| 204 |
+
referernce from qwen_vl_utils
|
| 205 |
+
NVILA does not depend on the function, but the interface is the same.
|
| 206 |
+
"""
|
| 207 |
+
vision_infos = extract_vision_info(conversations)
|
| 208 |
+
## Read images or videos
|
| 209 |
+
image_inputs = []
|
| 210 |
+
video_inputs = []
|
| 211 |
+
video_sample_fps_list = []
|
| 212 |
+
for vision_info in vision_infos:
|
| 213 |
+
if "image" in vision_info or "image_url" in vision_info:
|
| 214 |
+
image_inputs.append(fetch_image(vision_info))
|
| 215 |
+
elif "video" in vision_info:
|
| 216 |
+
video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
|
| 217 |
+
video_sample_fps_list.append(video_sample_fps)
|
| 218 |
+
video_inputs.append(video_input)
|
| 219 |
+
else:
|
| 220 |
+
raise ValueError("image, image_url or video should in content.")
|
| 221 |
+
if len(image_inputs) == 0:
|
| 222 |
+
image_inputs = None
|
| 223 |
+
if len(video_inputs) == 0:
|
| 224 |
+
video_inputs = None
|
| 225 |
+
if return_video_kwargs:
|
| 226 |
+
return image_inputs, video_inputs, {"fps": video_sample_fps_list}
|
| 227 |
+
return image_inputs, video_inputs
|
| 228 |
+
|
| 229 |
+
@staticmethod
|
| 230 |
+
def move_data_to_device(cls, prompt_inputs):
|
| 231 |
+
def _move_data_to_device(item):
|
| 232 |
+
# wrap function grpo trainer _prepare_input
|
| 233 |
+
kwargs = {"device": cls.args.device}
|
| 234 |
+
if cls.is_deepspeed_enabled and (torch.is_floating_point(item) or torch.is_complex(item)):
|
| 235 |
+
kwargs.update({"dtype": cls.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
|
| 236 |
+
return item.to(**kwargs)
|
| 237 |
+
|
| 238 |
+
prompt_inputs.input_ids = _move_data_to_device(prompt_inputs.input_ids)
|
| 239 |
+
prompt_inputs.attention_mask = _move_data_to_device(prompt_inputs.attention_mask)
|
| 240 |
+
if "image" in prompt_inputs.media:
|
| 241 |
+
prompt_inputs.media["image"] = [_move_data_to_device(img) for img in prompt_inputs.media["image"]]
|
| 242 |
+
return prompt_inputs
|
| 243 |
+
|
| 244 |
+
@classmethod
|
| 245 |
+
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
| 246 |
+
padding_side = kwargs.get("padding_side", "left")
|
| 247 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
| 248 |
+
pretrained_model_name_or_path = pretrained_model_name_or_path
|
| 249 |
+
else:
|
| 250 |
+
print(f"pretrained_model_name_or_path {pretrained_model_name_or_path} is not a directory, downloading")
|
| 251 |
+
from huggingface_hub import snapshot_download
|
| 252 |
+
|
| 253 |
+
pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)
|
| 254 |
+
|
| 255 |
+
image_processor = AutoImageProcessor.from_pretrained(
|
| 256 |
+
osp.join(pretrained_model_name_or_path, "vision_tower"), trust_remote_code=True
|
| 257 |
+
)
|
| 258 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 259 |
+
osp.join(pretrained_model_name_or_path, "llm"), trust_remote_code=True
|
| 260 |
+
)
|
| 261 |
+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
|
| 262 |
+
return cls(image_processor=image_processor, tokenizer=tokenizer, config=config, padding_side=padding_side)
|
| 263 |
+
|
| 264 |
+
def __repr__(self):
|
| 265 |
+
return f"VILAProcessor(image_processor=SigLip, tokenizer={self.tokenizer}, config={self.config})"
|
| 266 |
+
|
| 267 |
+
def __call__(
|
| 268 |
+
self,
|
| 269 |
+
conversation=None,
|
| 270 |
+
**kwargs: Unpack[VILAProcessorKwargs],
|
| 271 |
+
) -> BatchFeature:
|
| 272 |
+
"""
|
| 273 |
+
The `conv` will be look like
|
| 274 |
+
[
|
| 275 |
+
{
|
| 276 |
+
'from': 'human',
|
| 277 |
+
'value': [
|
| 278 |
+
<transformers_modules.NVILA-Lite-2B-hf-preview.media.Image object at 0x154e68e4c460>,
|
| 279 |
+
'What are the common elements in these pictures?'
|
| 280 |
+
]
|
| 281 |
+
}
|
| 282 |
+
]
|
| 283 |
+
and `conversation` will be a list of such `conv`s
|
| 284 |
+
"""
|
| 285 |
+
if kwargs.get("text", None) is not None:
|
| 286 |
+
conversation = kwargs.get("text")
|
| 287 |
+
assert conversation is not None, "`conversation` or `text` is required"
|
| 288 |
+
padding_side = kwargs.get("padding_side", self.padding_side)
|
| 289 |
+
|
| 290 |
+
input_ids_list = []
|
| 291 |
+
attention_mask = []
|
| 292 |
+
media = defaultdict(list)
|
| 293 |
+
media_config = defaultdict(dict)
|
| 294 |
+
for conv in conversation:
|
| 295 |
+
feat = self.__single_call__(conv, **kwargs)
|
| 296 |
+
input_ids_list.append(feat.input_ids)
|
| 297 |
+
attention_mask.append(feat.attention_mask)
|
| 298 |
+
for name in feat.media:
|
| 299 |
+
media[name] += feat.media[name]
|
| 300 |
+
for name in feat.media_config:
|
| 301 |
+
media_config[name].update(feat.media_config[name])
|
| 302 |
+
|
| 303 |
+
# pad the input_ids to batchfy
|
| 304 |
+
input_ids = pad_fn(
|
| 305 |
+
input_ids_list,
|
| 306 |
+
padding_value=self.pad_token_id,
|
| 307 |
+
padding_side=padding_side,
|
| 308 |
+
)
|
| 309 |
+
# ignore the pad token in the attention mask
|
| 310 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
| 311 |
+
attention_mask[input_ids == self.pad_token_id] = False
|
| 312 |
+
input_texts = self.tokenizer.batch_decode(input_ids)
|
| 313 |
+
bdata = BatchFeature(
|
| 314 |
+
data={
|
| 315 |
+
# "input_texts": input_texts,
|
| 316 |
+
"input_ids": input_ids,
|
| 317 |
+
"attention_mask": attention_mask,
|
| 318 |
+
"media": media,
|
| 319 |
+
"media_config": media_config,
|
| 320 |
+
}
|
| 321 |
+
)
|
| 322 |
+
# NOTE: hard coded to cuda
|
| 323 |
+
# bdata.input_ids = bdata.input_ids.cuda()
|
| 324 |
+
# bdata.attention_mask = bdata.attention_mask.cuda()
|
| 325 |
+
# bdata.media["image"] = [img.cuda() for img in bdata.media["image"]]
|
| 326 |
+
return bdata
|
| 327 |
+
|
| 328 |
+
def __single_call__(
|
| 329 |
+
self,
|
| 330 |
+
conversation,
|
| 331 |
+
images: ImageInput = None,
|
| 332 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
| 333 |
+
videos = None,
|
| 334 |
+
**kwargs: Unpack[VILAProcessorKwargs],
|
| 335 |
+
) -> BatchFeature:
|
| 336 |
+
conversation = copy.deepcopy(conversation)
|
| 337 |
+
media = extract_media(conversation, self.config)
|
| 338 |
+
# Process media
|
| 339 |
+
media_config = defaultdict(dict)
|
| 340 |
+
for name in media:
|
| 341 |
+
if name == "image":
|
| 342 |
+
if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
|
| 343 |
+
self.config.image_processor = self.image_processor
|
| 344 |
+
if self.config.image_aspect_ratio == "dynamic":
|
| 345 |
+
images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
|
| 346 |
+
# NOTE: this only works for images appears at the first conversation
|
| 347 |
+
conversation[0]["value"] = conversation[0]["value"].replace(
|
| 348 |
+
DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
|
| 349 |
+
)
|
| 350 |
+
else:
|
| 351 |
+
if type(self.config.s2_scales) is str:
|
| 352 |
+
self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
|
| 353 |
+
images, block_sizes = process_image(
|
| 354 |
+
media["image"][0], self.config, None, enable_dynamic_s2=True
|
| 355 |
+
)
|
| 356 |
+
images = images.half()
|
| 357 |
+
media_config[name]["block_sizes"] = [block_sizes]
|
| 358 |
+
else:
|
| 359 |
+
images = process_images(media["image"], self.image_processor, self.config).half()
|
| 360 |
+
media[name] = [image for image in images]
|
| 361 |
+
elif name == "video":
|
| 362 |
+
media[name] = [
|
| 363 |
+
process_images(images, self.image_processor, self.config).half() for images in media[name]
|
| 364 |
+
]
|
| 365 |
+
else:
|
| 366 |
+
raise ValueError(f"Unsupported media type: {name}")
|
| 367 |
+
|
| 368 |
+
inputs = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True, return_ids_only=False)
|
| 369 |
+
input_ids = inputs.input_ids[0].unsqueeze(0).cuda()
|
| 370 |
+
|
| 371 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
| 372 |
+
return BatchFeature(
|
| 373 |
+
data={
|
| 374 |
+
"input_ids": input_ids,
|
| 375 |
+
"attention_mask": attention_mask,
|
| 376 |
+
"media": media,
|
| 377 |
+
"media_config": media_config,
|
| 378 |
+
}
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
def batch_decode(self, *args, **kwargs):
|
| 382 |
+
"""
|
| 383 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 384 |
+
refer to the docstring of this method for more information.
|
| 385 |
+
"""
|
| 386 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 387 |
+
|
| 388 |
+
def decode(self, *args, **kwargs):
|
| 389 |
+
"""
|
| 390 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
| 391 |
+
the docstring of this method for more information.
|
| 392 |
+
"""
|
| 393 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 394 |
+
|
| 395 |
+
def post_process_image_text_to_text(self, generated_outputs):
|
| 396 |
+
"""
|
| 397 |
+
Post-process the output of the model to decode the text.
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
| 401 |
+
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
| 402 |
+
or `(sequence_length,)`.
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
`List[str]`: The decoded text.
|
| 406 |
+
"""
|
| 407 |
+
return self.tokenizer.batch_decode(
|
| 408 |
+
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
@property
|
| 412 |
+
def model_input_names(self):
|
| 413 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 414 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 415 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
| 416 |
+
|
| 417 |
+
def convert_gpt_conv_to_vila_conv(self, conversation):
|
| 418 |
+
vila_conv = []
|
| 419 |
+
for chat in conversation:
|
| 420 |
+
vila_chat = {"from": "", "value": []}
|
| 421 |
+
if chat["role"] in ("user", "system"):
|
| 422 |
+
# user allows to input image and text
|
| 423 |
+
vila_chat["from"] = "human" if chat["role"] == "user" else "system"
|
| 424 |
+
vila_chat["value"] = extract_value_from_conv(chat)
|
| 425 |
+
elif chat["role"] == "assistant":
|
| 426 |
+
vila_chat["from"] = "gpt"
|
| 427 |
+
vila_chat["value"] = extract_value_from_conv(chat)
|
| 428 |
+
else:
|
| 429 |
+
raise ValueError(f"Unsupported role: {chat['role']} in chat {chat}")
|
| 430 |
+
vila_conv.append(vila_chat)
|
| 431 |
+
|
| 432 |
+
return vila_conv
|
| 433 |
+
|
| 434 |
+
def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
|
| 435 |
+
return self.convert_gpt_conv_to_vila_conv(conversation)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
if __name__ == "__main__":
|
| 439 |
+
# gpt style: user, assistant
|
| 440 |
+
# vila style: human, gpt
|
| 441 |
+
gpt_conv = [
|
| 442 |
+
{
|
| 443 |
+
"role": "user",
|
| 444 |
+
"content": [
|
| 445 |
+
{"type": "image", "path": "demo_images/demo_img_1.png"},
|
| 446 |
+
{"type": "text", "text": "Describe this image."},
|
| 447 |
+
],
|
| 448 |
+
}
|
| 449 |
+
]
|
| 450 |
+
|
| 451 |
+
llavaconv = [
|
| 452 |
+
{
|
| 453 |
+
"from": "human",
|
| 454 |
+
"value": [
|
| 455 |
+
PIL.Image.open("demo_images/demo_img_1.png"),
|
| 456 |
+
"Describe this image.",
|
| 457 |
+
],
|
| 458 |
+
}
|
| 459 |
+
]
|
| 460 |
+
|
| 461 |
+
processor = AutoProcessor.from_pretrained(output_dir, trust_remote_code=True)
|
| 462 |
+
inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")
|
| 463 |
+
# model = llava.load("Efficient-Large-Model/qwen25_2B_3x3-sft").cuda()
|
| 464 |
+
# print(model)
|
| 465 |
+
model_path = "NVILA-Lite-2B-hf-preview"
|
| 466 |
+
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")
|
| 467 |
+
# res = model.generate_content(["how are you today?"])
|
| 468 |
+
# print(model.config)
|
| 469 |
+
# print(model.tokenizer)
|
| 470 |
+
# print(res)
|
| 471 |
+
|
| 472 |
+
processor = VILAProcessor(
|
| 473 |
+
config=model.config,
|
| 474 |
+
image_processor=model.vision_tower.image_processor,
|
| 475 |
+
tokenizer=model.tokenizer,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
inputs = processor(conversation=llavaconv, padding=True, return_tensors="pt")
|
| 479 |
+
print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.image])
|
| 480 |
+
print("vila conv pass")
|
| 481 |
+
|
| 482 |
+
inputs = processor.apply_chat_template(conversation=gpt_conv, padding=True, return_tensors="pt")
|
| 483 |
+
print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.image])
|
| 484 |
+
print("gpt conv pass")
|
| 485 |
+
|
| 486 |
+
output_ids = model.generate(
|
| 487 |
+
input_ids=inputs.input_ids,
|
| 488 |
+
media={
|
| 489 |
+
"image": inputs.image,
|
| 490 |
+
},
|
| 491 |
+
media_config={"image": {}},
|
| 492 |
+
generation_config=model.generation_config,
|
| 493 |
+
max_new_tokens=100,
|
| 494 |
+
)
|
| 495 |
+
print(output_ids)
|
base_projector.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import re
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class IdentityMap(nn.Module):
|
| 25 |
+
def __init__(self):
|
| 26 |
+
super().__init__()
|
| 27 |
+
|
| 28 |
+
def forward(self, x, *args, **kwargs):
|
| 29 |
+
return x
|
| 30 |
+
|
| 31 |
+
@property
|
| 32 |
+
def config(self):
|
| 33 |
+
return {"mm_projector_type": "identity"}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SimpleResBlock(nn.Module):
|
| 37 |
+
def __init__(self, channels):
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.pre_norm = nn.LayerNorm(channels)
|
| 40 |
+
|
| 41 |
+
self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
|
| 42 |
+
|
| 43 |
+
def forward(self, x):
|
| 44 |
+
x = self.pre_norm(x)
|
| 45 |
+
return x + self.proj(x)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class DownSampleBlock(nn.Module):
|
| 49 |
+
def forward(self, x):
|
| 50 |
+
vit_embeds = x
|
| 51 |
+
h = w = int(vit_embeds.shape[1] ** 0.5)
|
| 52 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
| 53 |
+
vit_embeds = self.flat_square(vit_embeds)
|
| 54 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
|
| 55 |
+
return vit_embeds
|
| 56 |
+
|
| 57 |
+
def flat_square(self, x):
|
| 58 |
+
n, w, h, c = x.size()
|
| 59 |
+
if w % 2 == 1:
|
| 60 |
+
x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
|
| 61 |
+
n, w, h, c = x.size()
|
| 62 |
+
if h % 2 == 1:
|
| 63 |
+
x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
|
| 64 |
+
n, w, h, c = x.size()
|
| 65 |
+
x = x.contiguous()
|
| 66 |
+
x = x.view(n, w, int(h / 2), int(c * 2))
|
| 67 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 68 |
+
x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
|
| 69 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 70 |
+
return x
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class DownSample2x2BlockFix(nn.Module):
|
| 74 |
+
def forward(self, x):
|
| 75 |
+
vit_embeds = x
|
| 76 |
+
h = w = int(vit_embeds.shape[1] ** 0.5)
|
| 77 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
| 78 |
+
vit_embeds = flat_square_2x2(vit_embeds)
|
| 79 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
|
| 80 |
+
return vit_embeds
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def flat_square_2x2(x):
|
| 84 |
+
n, w, h, c = x.size()
|
| 85 |
+
if w % 2 == 1:
|
| 86 |
+
x = torch.concat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
|
| 87 |
+
n, w, h, c = x.size()
|
| 88 |
+
x = x.contiguous()
|
| 89 |
+
if h % 2 == 1:
|
| 90 |
+
x = torch.concat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
|
| 91 |
+
n, w, h, c = x.size()
|
| 92 |
+
x = x.view(n, w, int(h / 2), int(c * 2))
|
| 93 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 94 |
+
x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
|
| 95 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 96 |
+
return x
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class DownSample3x3BlockFix(nn.Module):
|
| 100 |
+
def forward(self, x):
|
| 101 |
+
vit_embeds = x
|
| 102 |
+
h = w = int(vit_embeds.shape[1] ** 0.5)
|
| 103 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
| 104 |
+
vit_embeds = flat_square_3x3(vit_embeds)
|
| 105 |
+
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
|
| 106 |
+
return vit_embeds
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def flat_square_3x3(x):
|
| 110 |
+
n, w, h, c = x.size()
|
| 111 |
+
if w % 3 != 0:
|
| 112 |
+
x = torch.concat([x, torch.zeros((n, 3 - (w % 3), h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
|
| 113 |
+
n, w, h, c = x.size()
|
| 114 |
+
x = x.contiguous()
|
| 115 |
+
if h % 3 != 0:
|
| 116 |
+
x = torch.concat([x, torch.zeros((n, w, 3 - (h % 3), c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
|
| 117 |
+
n, w, h, c = x.size()
|
| 118 |
+
x = x.view(n, w, int(h / 3), int(c * 3))
|
| 119 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 120 |
+
x = x.view(n, int(h / 3), int(w / 3), int(c * 9))
|
| 121 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 122 |
+
return x
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class MultimodalProjectorConfig(PretrainedConfig):
|
| 126 |
+
model_type = "v2l_projector"
|
| 127 |
+
|
| 128 |
+
def __init__(self, mm_projector_type: str = None, **kwargs):
|
| 129 |
+
super().__init__()
|
| 130 |
+
self.mm_projector_type = mm_projector_type
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class MultimodalProjector(PreTrainedModel):
|
| 134 |
+
config_class = MultimodalProjectorConfig
|
| 135 |
+
|
| 136 |
+
def __init__(self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig):
|
| 137 |
+
super().__init__(mm_projector_cfg)
|
| 138 |
+
mm_projector_type = mm_projector_cfg.mm_projector_type
|
| 139 |
+
self.downsample_rate = 1
|
| 140 |
+
if mm_projector_type == "identity":
|
| 141 |
+
self.layers = IdentityMap()
|
| 142 |
+
elif mm_projector_type == "linear":
|
| 143 |
+
self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
|
| 144 |
+
elif mm_projector_type == "mlp_downsample":
|
| 145 |
+
self.layers = nn.Sequential(
|
| 146 |
+
DownSampleBlock(),
|
| 147 |
+
nn.LayerNorm(config.mm_hidden_size * 4),
|
| 148 |
+
nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
|
| 149 |
+
nn.GELU(),
|
| 150 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 151 |
+
)
|
| 152 |
+
self.downsample_rate = 2
|
| 153 |
+
elif mm_projector_type == "mlp_downsample_2x2_fix":
|
| 154 |
+
self.layers = nn.Sequential(
|
| 155 |
+
DownSample2x2BlockFix(),
|
| 156 |
+
nn.LayerNorm(config.mm_hidden_size * 4),
|
| 157 |
+
nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
|
| 158 |
+
nn.GELU(),
|
| 159 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 160 |
+
)
|
| 161 |
+
self.downsample_rate = 2
|
| 162 |
+
elif mm_projector_type == "mlp_downsample_3x3_fix":
|
| 163 |
+
self.layers = nn.Sequential(
|
| 164 |
+
DownSample3x3BlockFix(),
|
| 165 |
+
nn.LayerNorm(config.mm_hidden_size * 9),
|
| 166 |
+
nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
|
| 167 |
+
nn.GELU(),
|
| 168 |
+
nn.LayerNorm(config.mm_hidden_size * 3),
|
| 169 |
+
nn.Linear(config.mm_hidden_size * 3, config.hidden_size),
|
| 170 |
+
nn.GELU(),
|
| 171 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 172 |
+
)
|
| 173 |
+
self.downsample_rate = 3
|
| 174 |
+
elif mm_projector_type == "mlp_downsample_3x3_s2":
|
| 175 |
+
self.layers = nn.Sequential(
|
| 176 |
+
DownSample3x3BlockFix(),
|
| 177 |
+
nn.LayerNorm(config.mm_hidden_size * 9),
|
| 178 |
+
nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 3),
|
| 179 |
+
nn.GELU(),
|
| 180 |
+
nn.LayerNorm(config.mm_hidden_size * 3),
|
| 181 |
+
nn.Linear(config.mm_hidden_size * 3, config.mm_hidden_size),
|
| 182 |
+
nn.GELU(),
|
| 183 |
+
nn.LayerNorm(config.mm_hidden_size),
|
| 184 |
+
nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
|
| 185 |
+
nn.GELU(),
|
| 186 |
+
nn.LayerNorm(config.mm_hidden_size // 3),
|
| 187 |
+
nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
|
| 188 |
+
nn.GELU(),
|
| 189 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 190 |
+
)
|
| 191 |
+
elif mm_projector_type == "mlp_downsample_3x3_s2_new":
|
| 192 |
+
self.layers = nn.Sequential(
|
| 193 |
+
DownSample3x3BlockFix(),
|
| 194 |
+
nn.LayerNorm(config.mm_hidden_size * 9),
|
| 195 |
+
nn.Linear(config.mm_hidden_size * 9, config.mm_hidden_size * 4),
|
| 196 |
+
nn.GELU(),
|
| 197 |
+
nn.LayerNorm(config.mm_hidden_size * 4),
|
| 198 |
+
nn.Linear(config.mm_hidden_size * 4, config.mm_hidden_size * 2),
|
| 199 |
+
nn.GELU(),
|
| 200 |
+
nn.LayerNorm(config.mm_hidden_size * 2),
|
| 201 |
+
nn.Linear(config.mm_hidden_size * 2, config.mm_hidden_size),
|
| 202 |
+
nn.GELU(),
|
| 203 |
+
nn.LayerNorm(config.mm_hidden_size),
|
| 204 |
+
nn.Linear(config.mm_hidden_size, config.mm_hidden_size // 3),
|
| 205 |
+
nn.GELU(),
|
| 206 |
+
nn.LayerNorm(config.mm_hidden_size // 3),
|
| 207 |
+
nn.Linear(config.mm_hidden_size // 3, config.hidden_size),
|
| 208 |
+
nn.GELU(),
|
| 209 |
+
nn.Linear(config.hidden_size, config.hidden_size),
|
| 210 |
+
)
|
| 211 |
+
else:
|
| 212 |
+
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type)
|
| 213 |
+
if mlp_gelu_match:
|
| 214 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
| 215 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
| 216 |
+
for _ in range(1, mlp_depth):
|
| 217 |
+
modules.append(nn.GELU())
|
| 218 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
| 219 |
+
self.layers = nn.Sequential(*modules)
|
| 220 |
+
else:
|
| 221 |
+
raise ValueError(f"Unknown projector type: {mm_projector_type}")
|
| 222 |
+
|
| 223 |
+
def forward(self, x, *args, **kwargs):
|
| 224 |
+
return self.layers(x)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# AutoConfig.register("v2l_projector", MultimodalProjectorConfig)
|
| 228 |
+
# AutoModel.register(MultimodalProjectorConfig, MultimodalProjector)
|
builder.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
import os
|
| 19 |
+
import os.path as osp
|
| 20 |
+
import warnings
|
| 21 |
+
from dataclasses import asdict
|
| 22 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import transformers
|
| 26 |
+
from huggingface_hub import file_exists, repo_exists
|
| 27 |
+
from huggingface_hub.utils import HFValidationError
|
| 28 |
+
from transformers import (
|
| 29 |
+
AutoConfig,
|
| 30 |
+
AutoModelForCausalLM,
|
| 31 |
+
AutoTokenizer,
|
| 32 |
+
PretrainedConfig,
|
| 33 |
+
PreTrainedModel,
|
| 34 |
+
PreTrainedTokenizer,
|
| 35 |
+
)
|
| 36 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 37 |
+
|
| 38 |
+
# from .conversation import *
|
| 39 |
+
from .conversation import SeparatorStyle, default_conversation
|
| 40 |
+
|
| 41 |
+
SENTINEL_TOKEN = "<vila/sentinel>"
|
| 42 |
+
MEDIA_TOKENS = {
|
| 43 |
+
"image": "<image>",
|
| 44 |
+
"video": "<vila/video>",
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
# from llava.model.utils import packing
|
| 48 |
+
# from llava.utils.logging import logger
|
| 49 |
+
# from llava.utils.tokenizer import infer_stop_tokens
|
| 50 |
+
|
| 51 |
+
DUMMY_CONVERSATION = [
|
| 52 |
+
{"from": "human", "value": "question"},
|
| 53 |
+
{"from": "gpt", "value": "answer"},
|
| 54 |
+
] * 10
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
|
| 58 |
+
return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def has_tokenizer(repo_id_or_path: str) -> bool:
|
| 62 |
+
# Check if the tokenizer is in a local directory
|
| 63 |
+
if osp.exists(osp.join(repo_id_or_path, "tokenizer_config.json")):
|
| 64 |
+
return True
|
| 65 |
+
|
| 66 |
+
# Check if the tokenizer is in a Hugging Face Hub repo
|
| 67 |
+
try:
|
| 68 |
+
return repo_exists(repo_id_or_path) and file_exists(repo_id_or_path, "tokenizer_config.json")
|
| 69 |
+
except HFValidationError:
|
| 70 |
+
return False
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
|
| 74 |
+
if not hasattr(tokenizer, "sentinel_token"):
|
| 75 |
+
tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
|
| 76 |
+
tokenizer.sentinel_token = SENTINEL_TOKEN
|
| 77 |
+
tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def tokenize_conversation_legacy(
|
| 81 |
+
messages: Sequence[Dict[str, str]],
|
| 82 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 83 |
+
add_generation_prompt: bool = False,
|
| 84 |
+
overrides: Optional[Dict[str, str]] = None,
|
| 85 |
+
no_system_prompt: bool = False,
|
| 86 |
+
) -> torch.Tensor:
|
| 87 |
+
conv = default_conversation.copy()
|
| 88 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 89 |
+
|
| 90 |
+
if no_system_prompt:
|
| 91 |
+
conv.system = ""
|
| 92 |
+
|
| 93 |
+
# Skip the first message if it is not from human
|
| 94 |
+
if messages[0]["from"] != "human":
|
| 95 |
+
messages = messages[1:]
|
| 96 |
+
|
| 97 |
+
# Add a generation prompt if needed
|
| 98 |
+
if add_generation_prompt:
|
| 99 |
+
messages.append({"from": "gpt", "value": None})
|
| 100 |
+
|
| 101 |
+
conv.messages = []
|
| 102 |
+
for turn, message in enumerate(messages):
|
| 103 |
+
role = roles[message["from"]]
|
| 104 |
+
assert role == conv.roles[turn % 2]
|
| 105 |
+
if overrides is not None and message["from"] in overrides:
|
| 106 |
+
conv.append_message(role, overrides[message["from"]])
|
| 107 |
+
else:
|
| 108 |
+
conv.append_message(role, message["value"])
|
| 109 |
+
|
| 110 |
+
return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def tokenize_conversation(
|
| 114 |
+
messages: Sequence[Dict[str, str]],
|
| 115 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 116 |
+
add_generation_prompt: bool = False,
|
| 117 |
+
overrides: Optional[Dict[str, str]] = None,
|
| 118 |
+
no_system_prompt: bool = False,
|
| 119 |
+
) -> torch.Tensor:
|
| 120 |
+
# Normalize the conversation before tokenization
|
| 121 |
+
for message in messages:
|
| 122 |
+
message["value"] = message["value"].strip()
|
| 123 |
+
|
| 124 |
+
if default_conversation.sep_style != SeparatorStyle.AUTO:
|
| 125 |
+
return tokenize_conversation_legacy(
|
| 126 |
+
messages,
|
| 127 |
+
tokenizer,
|
| 128 |
+
add_generation_prompt=add_generation_prompt,
|
| 129 |
+
overrides=overrides,
|
| 130 |
+
no_system_prompt=no_system_prompt,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
conversation = []
|
| 134 |
+
for m in messages:
|
| 135 |
+
message = {}
|
| 136 |
+
if m["from"] == "human":
|
| 137 |
+
message["role"] = "user"
|
| 138 |
+
elif m["from"] == "gpt":
|
| 139 |
+
message["role"] = "assistant"
|
| 140 |
+
else:
|
| 141 |
+
raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
|
| 142 |
+
|
| 143 |
+
message["content"] = m["value"]
|
| 144 |
+
if overrides is not None and m["from"] in overrides:
|
| 145 |
+
message["content"] = overrides[m["from"]]
|
| 146 |
+
conversation.append(message)
|
| 147 |
+
|
| 148 |
+
if no_system_prompt:
|
| 149 |
+
conversation = [{"role": "system", "content": ""}] + conversation
|
| 150 |
+
|
| 151 |
+
text = tokenizer.apply_chat_template(
|
| 152 |
+
conversation,
|
| 153 |
+
add_generation_prompt=add_generation_prompt,
|
| 154 |
+
tokenize=False,
|
| 155 |
+
)
|
| 156 |
+
return tokenizer_image_token(text, tokenizer, return_tensors="pt")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
|
| 160 |
+
_maybe_add_sentinel_token(tokenizer)
|
| 161 |
+
template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
|
| 162 |
+
|
| 163 |
+
stop_tokens = {tokenizer.eos_token}
|
| 164 |
+
for k in range(template.size(0) - 1):
|
| 165 |
+
if template[k] == tokenizer.sentinel_token_id:
|
| 166 |
+
stop_token = tokenizer.decode(template[k + 1])
|
| 167 |
+
stop_tokens.add(stop_token)
|
| 168 |
+
return list(stop_tokens)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def context_length_extension(config):
|
| 172 |
+
orig_ctx_len = getattr(config, "max_position_embeddings", None)
|
| 173 |
+
model_max_length = getattr(config, "model_max_length", None)
|
| 174 |
+
if orig_ctx_len and model_max_length > orig_ctx_len:
|
| 175 |
+
print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
|
| 176 |
+
scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
|
| 177 |
+
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
| 178 |
+
return config
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def build_llm_and_tokenizer(
|
| 182 |
+
model_name_or_path: str,
|
| 183 |
+
config: PretrainedConfig,
|
| 184 |
+
attn_implementation=None,
|
| 185 |
+
model_max_length=None,
|
| 186 |
+
*args,
|
| 187 |
+
**kwargs,
|
| 188 |
+
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
| 189 |
+
# print(model_name_or_path)
|
| 190 |
+
llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
|
| 191 |
+
llm_cfg._attn_implementation = attn_implementation
|
| 192 |
+
llm_cfg.model_max_length = model_max_length
|
| 193 |
+
if model_max_length is not None:
|
| 194 |
+
context_length_extension(llm_cfg)
|
| 195 |
+
|
| 196 |
+
# Quantization related
|
| 197 |
+
quantization_restore_from_checkpoint = False
|
| 198 |
+
|
| 199 |
+
if quantization_restore_from_checkpoint:
|
| 200 |
+
fp8_model_name_or_path = kwargs.pop("fp8_llm_cfg", None)
|
| 201 |
+
|
| 202 |
+
llm = AutoModelForCausalLM.from_pretrained(
|
| 203 |
+
fp8_model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
|
| 204 |
+
)
|
| 205 |
+
else:
|
| 206 |
+
if is_deepspeed_zero3_enabled():
|
| 207 |
+
# NOTE: found by wei, need to pop out device_map when using zero3
|
| 208 |
+
kwargs.pop("device_map")
|
| 209 |
+
llm = AutoModelForCausalLM.from_pretrained(
|
| 210 |
+
model_name_or_path, config=llm_cfg, torch_dtype=eval(config.model_dtype), *args, **kwargs
|
| 211 |
+
)
|
| 212 |
+
# packing.patch(llm)
|
| 213 |
+
|
| 214 |
+
# Locate the tokenizer.
|
| 215 |
+
llm_path = model_name_or_path
|
| 216 |
+
if not has_tokenizer(llm_path):
|
| 217 |
+
llm_path = osp.join(llm_path, "llm")
|
| 218 |
+
if not has_tokenizer(llm_path):
|
| 219 |
+
raise ValueError(f"Cannot find tokenizer in {llm_path}.")
|
| 220 |
+
|
| 221 |
+
tokenizer = AutoTokenizer.from_pretrained(llm_path, padding_side="right", use_fast=True, legacy=False)
|
| 222 |
+
if model_max_length is not None:
|
| 223 |
+
tokenizer.model_max_length = model_max_length
|
| 224 |
+
|
| 225 |
+
# Load chat template if specified.
|
| 226 |
+
if getattr(config, "chat_template", None) is not None:
|
| 227 |
+
print(f"Using chat template: {config.chat_template}")
|
| 228 |
+
fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
|
| 229 |
+
if not os.path.exists(fpath):
|
| 230 |
+
fpath = os.path.join(os.path.dirname(model_name_or_path), f"{config.chat_template}.jinja")
|
| 231 |
+
with open(fpath) as fd:
|
| 232 |
+
chat_template = fd.read()
|
| 233 |
+
tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
|
| 234 |
+
|
| 235 |
+
# Set stop tokens for the tokenizer
|
| 236 |
+
tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
|
| 237 |
+
tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)
|
| 238 |
+
|
| 239 |
+
# Add media tokens to the tokenizer
|
| 240 |
+
tokenizer.media_tokens = MEDIA_TOKENS
|
| 241 |
+
tokenizer.media_token_ids = {}
|
| 242 |
+
for name, token in MEDIA_TOKENS.items():
|
| 243 |
+
tokenizer.add_tokens([token], special_tokens=True)
|
| 244 |
+
tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)
|
| 245 |
+
|
| 246 |
+
config.hidden_size = llm.config.hidden_size
|
| 247 |
+
return llm, tokenizer
|
config.json
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_attn_implementation_autoset": true,
|
| 3 |
+
"_name_or_path": "./LongVILA-R1-7B",
|
| 4 |
+
"architectures": [
|
| 5 |
+
"VILAForCausalLM"
|
| 6 |
+
],
|
| 7 |
+
"chat_template": null,
|
| 8 |
+
"drop_path_rate": 0.0,
|
| 9 |
+
"fps": 0.0,
|
| 10 |
+
"hidden_size": 3584,
|
| 11 |
+
"image_aspect_ratio": "resize",
|
| 12 |
+
"image_encoder": {
|
| 13 |
+
"_target_": "llava.model.encoders.BasicImageEncoder"
|
| 14 |
+
},
|
| 15 |
+
"interpolate_mode": "linear",
|
| 16 |
+
"llm_cfg": {
|
| 17 |
+
"_attn_implementation_autoset": false,
|
| 18 |
+
"_name_or_path": "./LongVILA-R1-7B/llm",
|
| 19 |
+
"add_cross_attention": false,
|
| 20 |
+
"architectures": [
|
| 21 |
+
"Qwen2ForCausalLM"
|
| 22 |
+
],
|
| 23 |
+
"attention_dropout": 0.0,
|
| 24 |
+
"bad_words_ids": null,
|
| 25 |
+
"begin_suppress_tokens": null,
|
| 26 |
+
"bos_token_id": 151643,
|
| 27 |
+
"chunk_size_feed_forward": 0,
|
| 28 |
+
"cross_attention_hidden_size": null,
|
| 29 |
+
"decoder_start_token_id": null,
|
| 30 |
+
"diversity_penalty": 0.0,
|
| 31 |
+
"do_sample": false,
|
| 32 |
+
"early_stopping": false,
|
| 33 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 34 |
+
"eos_token_id": 151645,
|
| 35 |
+
"exponential_decay_length_penalty": null,
|
| 36 |
+
"finetuning_task": null,
|
| 37 |
+
"forced_bos_token_id": null,
|
| 38 |
+
"forced_eos_token_id": null,
|
| 39 |
+
"hidden_act": "silu",
|
| 40 |
+
"hidden_size": 3584,
|
| 41 |
+
"id2label": {
|
| 42 |
+
"0": "LABEL_0",
|
| 43 |
+
"1": "LABEL_1"
|
| 44 |
+
},
|
| 45 |
+
"initializer_range": 0.02,
|
| 46 |
+
"intermediate_size": 18944,
|
| 47 |
+
"is_decoder": false,
|
| 48 |
+
"is_encoder_decoder": false,
|
| 49 |
+
"label2id": {
|
| 50 |
+
"LABEL_0": 0,
|
| 51 |
+
"LABEL_1": 1
|
| 52 |
+
},
|
| 53 |
+
"length_penalty": 1.0,
|
| 54 |
+
"max_length": 20,
|
| 55 |
+
"max_position_embeddings": 32768,
|
| 56 |
+
"max_window_layers": 28,
|
| 57 |
+
"min_length": 0,
|
| 58 |
+
"model_max_length": 32768,
|
| 59 |
+
"model_type": "qwen2",
|
| 60 |
+
"no_repeat_ngram_size": 0,
|
| 61 |
+
"num_attention_heads": 28,
|
| 62 |
+
"num_beam_groups": 1,
|
| 63 |
+
"num_beams": 1,
|
| 64 |
+
"num_hidden_layers": 28,
|
| 65 |
+
"num_key_value_heads": 4,
|
| 66 |
+
"num_return_sequences": 1,
|
| 67 |
+
"output_attentions": false,
|
| 68 |
+
"output_hidden_states": false,
|
| 69 |
+
"output_scores": false,
|
| 70 |
+
"pad_token_id": null,
|
| 71 |
+
"prefix": null,
|
| 72 |
+
"problem_type": null,
|
| 73 |
+
"pruned_heads": {},
|
| 74 |
+
"remove_invalid_values": false,
|
| 75 |
+
"repetition_penalty": 1.0,
|
| 76 |
+
"return_dict": true,
|
| 77 |
+
"return_dict_in_generate": false,
|
| 78 |
+
"rms_norm_eps": 1e-06,
|
| 79 |
+
"rope_scaling": null,
|
| 80 |
+
"rope_theta": 1000000.0,
|
| 81 |
+
"sep_token_id": null,
|
| 82 |
+
"sliding_window": null,
|
| 83 |
+
"suppress_tokens": null,
|
| 84 |
+
"task_specific_params": null,
|
| 85 |
+
"temperature": 1.0,
|
| 86 |
+
"tf_legacy_loss": false,
|
| 87 |
+
"tie_encoder_decoder": false,
|
| 88 |
+
"tie_word_embeddings": false,
|
| 89 |
+
"tokenizer_class": null,
|
| 90 |
+
"tokenizer_model_max_length": 4096,
|
| 91 |
+
"tokenizer_padding_side": "right",
|
| 92 |
+
"top_k": 50,
|
| 93 |
+
"top_p": 1.0,
|
| 94 |
+
"torch_dtype": "bfloat16",
|
| 95 |
+
"torchscript": false,
|
| 96 |
+
"typical_p": 1.0,
|
| 97 |
+
"use_bfloat16": false,
|
| 98 |
+
"use_cache": false,
|
| 99 |
+
"use_sliding_window": false,
|
| 100 |
+
"vocab_size": 151651
|
| 101 |
+
},
|
| 102 |
+
"mm_hidden_size": 1152,
|
| 103 |
+
"mm_projector": "mlp_downsample_2x2_fix",
|
| 104 |
+
"mm_projector_cfg": {
|
| 105 |
+
"_attn_implementation_autoset": false,
|
| 106 |
+
"_name_or_path": "./LongVILA-R1-7B/mm_projector",
|
| 107 |
+
"add_cross_attention": false,
|
| 108 |
+
"architectures": [
|
| 109 |
+
"MultimodalProjector"
|
| 110 |
+
],
|
| 111 |
+
"bad_words_ids": null,
|
| 112 |
+
"begin_suppress_tokens": null,
|
| 113 |
+
"bos_token_id": null,
|
| 114 |
+
"chunk_size_feed_forward": 0,
|
| 115 |
+
"cross_attention_hidden_size": null,
|
| 116 |
+
"decoder_start_token_id": null,
|
| 117 |
+
"diversity_penalty": 0.0,
|
| 118 |
+
"do_sample": false,
|
| 119 |
+
"early_stopping": false,
|
| 120 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 121 |
+
"eos_token_id": null,
|
| 122 |
+
"exponential_decay_length_penalty": null,
|
| 123 |
+
"finetuning_task": null,
|
| 124 |
+
"forced_bos_token_id": null,
|
| 125 |
+
"forced_eos_token_id": null,
|
| 126 |
+
"id2label": {
|
| 127 |
+
"0": "LABEL_0",
|
| 128 |
+
"1": "LABEL_1"
|
| 129 |
+
},
|
| 130 |
+
"is_decoder": false,
|
| 131 |
+
"is_encoder_decoder": false,
|
| 132 |
+
"label2id": {
|
| 133 |
+
"LABEL_0": 0,
|
| 134 |
+
"LABEL_1": 1
|
| 135 |
+
},
|
| 136 |
+
"length_penalty": 1.0,
|
| 137 |
+
"max_length": 20,
|
| 138 |
+
"min_length": 0,
|
| 139 |
+
"mm_projector_type": "mlp_downsample_2x2_fix",
|
| 140 |
+
"model_type": "v2l_projector",
|
| 141 |
+
"no_repeat_ngram_size": 0,
|
| 142 |
+
"num_beam_groups": 1,
|
| 143 |
+
"num_beams": 1,
|
| 144 |
+
"num_return_sequences": 1,
|
| 145 |
+
"output_attentions": false,
|
| 146 |
+
"output_hidden_states": false,
|
| 147 |
+
"output_scores": false,
|
| 148 |
+
"pad_token_id": null,
|
| 149 |
+
"prefix": null,
|
| 150 |
+
"problem_type": null,
|
| 151 |
+
"pruned_heads": {},
|
| 152 |
+
"remove_invalid_values": false,
|
| 153 |
+
"repetition_penalty": 1.0,
|
| 154 |
+
"return_dict": true,
|
| 155 |
+
"return_dict_in_generate": false,
|
| 156 |
+
"sep_token_id": null,
|
| 157 |
+
"suppress_tokens": null,
|
| 158 |
+
"task_specific_params": null,
|
| 159 |
+
"temperature": 1.0,
|
| 160 |
+
"tf_legacy_loss": false,
|
| 161 |
+
"tie_encoder_decoder": false,
|
| 162 |
+
"tie_word_embeddings": true,
|
| 163 |
+
"tokenizer_class": null,
|
| 164 |
+
"top_k": 50,
|
| 165 |
+
"top_p": 1.0,
|
| 166 |
+
"torch_dtype": "bfloat16",
|
| 167 |
+
"torchscript": false,
|
| 168 |
+
"typical_p": 1.0,
|
| 169 |
+
"use_bfloat16": false
|
| 170 |
+
},
|
| 171 |
+
"mm_projector_lr": null,
|
| 172 |
+
"mm_use_im_patch_token": false,
|
| 173 |
+
"mm_use_im_start_end": false,
|
| 174 |
+
"mm_vision_select_feature": "cls_patch",
|
| 175 |
+
"mm_vision_select_layer": -2,
|
| 176 |
+
"model_dtype": "torch.bfloat16",
|
| 177 |
+
"model_name_or_path": "./LongVILA-R1-7B",
|
| 178 |
+
"model_type": "vila",
|
| 179 |
+
"num_time_tokens": 0,
|
| 180 |
+
"num_video_frames": 256,
|
| 181 |
+
"resume_path": "./LongVILA-R1-7B",
|
| 182 |
+
"s2": false,
|
| 183 |
+
"s2_max_split_size": 336,
|
| 184 |
+
"s2_scales": "336,672,1008",
|
| 185 |
+
"soft_ce_std": 1.0,
|
| 186 |
+
"time_token_format": "<t{t}>",
|
| 187 |
+
"time_token_ids": [],
|
| 188 |
+
"transformers_version": "4.46.2",
|
| 189 |
+
"tune_language_model": true,
|
| 190 |
+
"tune_mm_projector": true,
|
| 191 |
+
"tune_vision_tower": true,
|
| 192 |
+
"version": "2.0",
|
| 193 |
+
"video_encoder": {
|
| 194 |
+
"_target_": "llava.model.encoders.TSPVideoEncoder",
|
| 195 |
+
"pool_sizes": [
|
| 196 |
+
[
|
| 197 |
+
8,
|
| 198 |
+
1,
|
| 199 |
+
1
|
| 200 |
+
]
|
| 201 |
+
]
|
| 202 |
+
},
|
| 203 |
+
"video_max_tiles": 1,
|
| 204 |
+
"vision_resolution": -1,
|
| 205 |
+
"vision_tower": "Efficient-Large-Model/paligemma-siglip-so400m-patch14-448",
|
| 206 |
+
"vision_tower_cfg": {
|
| 207 |
+
"_attn_implementation_autoset": false,
|
| 208 |
+
"_name_or_path": "./LongVILA-R1-7B/vision_tower",
|
| 209 |
+
"add_cross_attention": false,
|
| 210 |
+
"architectures": [
|
| 211 |
+
"SiglipVisionModel"
|
| 212 |
+
],
|
| 213 |
+
"attention_dropout": 0.0,
|
| 214 |
+
"bad_words_ids": null,
|
| 215 |
+
"begin_suppress_tokens": null,
|
| 216 |
+
"bos_token_id": null,
|
| 217 |
+
"chunk_size_feed_forward": 0,
|
| 218 |
+
"cross_attention_hidden_size": null,
|
| 219 |
+
"decoder_start_token_id": null,
|
| 220 |
+
"diversity_penalty": 0.0,
|
| 221 |
+
"do_sample": false,
|
| 222 |
+
"early_stopping": false,
|
| 223 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 224 |
+
"eos_token_id": null,
|
| 225 |
+
"exponential_decay_length_penalty": null,
|
| 226 |
+
"finetuning_task": null,
|
| 227 |
+
"forced_bos_token_id": null,
|
| 228 |
+
"forced_eos_token_id": null,
|
| 229 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 230 |
+
"hidden_size": 1152,
|
| 231 |
+
"id2label": {
|
| 232 |
+
"0": "LABEL_0",
|
| 233 |
+
"1": "LABEL_1"
|
| 234 |
+
},
|
| 235 |
+
"image_size": 448,
|
| 236 |
+
"intermediate_size": 4304,
|
| 237 |
+
"is_decoder": false,
|
| 238 |
+
"is_encoder_decoder": false,
|
| 239 |
+
"label2id": {
|
| 240 |
+
"LABEL_0": 0,
|
| 241 |
+
"LABEL_1": 1
|
| 242 |
+
},
|
| 243 |
+
"layer_norm_eps": 1e-06,
|
| 244 |
+
"length_penalty": 1.0,
|
| 245 |
+
"max_length": 20,
|
| 246 |
+
"min_length": 0,
|
| 247 |
+
"model_type": "siglip_vision_model",
|
| 248 |
+
"no_repeat_ngram_size": 0,
|
| 249 |
+
"num_attention_heads": 16,
|
| 250 |
+
"num_beam_groups": 1,
|
| 251 |
+
"num_beams": 1,
|
| 252 |
+
"num_channels": 3,
|
| 253 |
+
"num_hidden_layers": 27,
|
| 254 |
+
"num_image_tokens": 256,
|
| 255 |
+
"num_return_sequences": 1,
|
| 256 |
+
"output_attentions": false,
|
| 257 |
+
"output_hidden_states": false,
|
| 258 |
+
"output_scores": false,
|
| 259 |
+
"pad_token_id": null,
|
| 260 |
+
"patch_size": 14,
|
| 261 |
+
"prefix": null,
|
| 262 |
+
"problem_type": null,
|
| 263 |
+
"projection_dim": 2048,
|
| 264 |
+
"projector_hidden_act": "gelu_fast",
|
| 265 |
+
"pruned_heads": {},
|
| 266 |
+
"remove_invalid_values": false,
|
| 267 |
+
"repetition_penalty": 1.0,
|
| 268 |
+
"return_dict": true,
|
| 269 |
+
"return_dict_in_generate": false,
|
| 270 |
+
"sep_token_id": null,
|
| 271 |
+
"suppress_tokens": null,
|
| 272 |
+
"task_specific_params": null,
|
| 273 |
+
"temperature": 1.0,
|
| 274 |
+
"tf_legacy_loss": false,
|
| 275 |
+
"tie_encoder_decoder": false,
|
| 276 |
+
"tie_word_embeddings": true,
|
| 277 |
+
"tokenizer_class": null,
|
| 278 |
+
"top_k": 50,
|
| 279 |
+
"top_p": 1.0,
|
| 280 |
+
"torch_dtype": "bfloat16",
|
| 281 |
+
"torchscript": false,
|
| 282 |
+
"typical_p": 1.0,
|
| 283 |
+
"use_bfloat16": false,
|
| 284 |
+
"vision_use_head": false
|
| 285 |
+
},
|
| 286 |
+
"vision_tower_lr": null,
|
| 287 |
+
"weight_memory_efficient": true,
|
| 288 |
+
"xvila_mode": false,
|
| 289 |
+
"auto_map": {
|
| 290 |
+
"AutoProcessor": "auto_processor.VILAProcessor",
|
| 291 |
+
"AutoConfig": "modeling_vila.VILAConfig",
|
| 292 |
+
"AutoModel": "modeling_vila.VILAForCausalLM",
|
| 293 |
+
"AutoModelForCausalLM": "modeling_vila.VILAForCausalLM"
|
| 294 |
+
}
|
| 295 |
+
}
|
configuration_vila.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import os.path as osp
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
from threading import Thread
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torchvision
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from transformers import (
|
| 13 |
+
AutoProcessor,
|
| 14 |
+
PretrainedConfig,
|
| 15 |
+
PreTrainedModel,
|
| 16 |
+
Qwen2Config,
|
| 17 |
+
Qwen2ForCausalLM,
|
| 18 |
+
Qwen2PreTrainedModel,
|
| 19 |
+
TextIteratorStreamer,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class VILAConfig(PretrainedConfig):
|
| 24 |
+
model_type = "vila"
|
| 25 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
llm_cfg=None,
|
| 30 |
+
vision_tower_cfg=None,
|
| 31 |
+
mm_projector_cfg=None,
|
| 32 |
+
architectures=None,
|
| 33 |
+
resume_path=None,
|
| 34 |
+
hidden_size=None,
|
| 35 |
+
mm_hidden_size=None,
|
| 36 |
+
image_aspect_ratio=None,
|
| 37 |
+
num_video_frames=None,
|
| 38 |
+
fps=None,
|
| 39 |
+
mm_vision_select_layer=None,
|
| 40 |
+
mm_vision_select_feature=None,
|
| 41 |
+
mm_use_im_start_end=False,
|
| 42 |
+
mm_use_im_patch_token=False,
|
| 43 |
+
mm_projector_lr=None,
|
| 44 |
+
vision_tower_lr=None,
|
| 45 |
+
vision_resolution=None,
|
| 46 |
+
interpolate_mode=None,
|
| 47 |
+
s2=None,
|
| 48 |
+
dynamic_s2=None,
|
| 49 |
+
s2_scales=None,
|
| 50 |
+
s2_max_split_size=None,
|
| 51 |
+
s2_resize_output_to_scale_idx=0,
|
| 52 |
+
min_tiles: Optional[int] = 1,
|
| 53 |
+
max_tiles: Optional[int] = 12,
|
| 54 |
+
num_time_tokens=None,
|
| 55 |
+
time_token_format=None,
|
| 56 |
+
image_encoder: str = '{"_target_": "llava.model.encoders.BasicImageEncoder"}',
|
| 57 |
+
video_encoder: str = '{"_target_": "llava.model.encoders.BasicVideoEncoder"}',
|
| 58 |
+
**kwargs,
|
| 59 |
+
):
|
| 60 |
+
super().__init__(**kwargs)
|
| 61 |
+
|
| 62 |
+
self.architectures = architectures
|
| 63 |
+
self.llm_cfg = llm_cfg
|
| 64 |
+
self.vision_tower_cfg = vision_tower_cfg
|
| 65 |
+
self.mm_projector_cfg = mm_projector_cfg
|
| 66 |
+
self.resume_path = resume_path
|
| 67 |
+
|
| 68 |
+
self.hidden_size = hidden_size
|
| 69 |
+
self.mm_hidden_size = mm_hidden_size
|
| 70 |
+
self.image_aspect_ratio = image_aspect_ratio
|
| 71 |
+
self.num_video_frames = num_video_frames
|
| 72 |
+
self.fps = fps
|
| 73 |
+
self.mm_vision_select_layer = mm_vision_select_layer
|
| 74 |
+
self.mm_vision_select_feature = mm_vision_select_feature
|
| 75 |
+
self.mm_use_im_start_end = mm_use_im_start_end
|
| 76 |
+
self.mm_use_im_patch_token = mm_use_im_patch_token
|
| 77 |
+
self.mm_projector_lr = mm_projector_lr
|
| 78 |
+
self.vision_tower_lr = vision_tower_lr
|
| 79 |
+
self.vision_resolution = vision_resolution
|
| 80 |
+
self.interpolate_mode = interpolate_mode
|
| 81 |
+
self.s2 = s2
|
| 82 |
+
self.dynamic_s2 = dynamic_s2
|
| 83 |
+
self.s2_scales = s2_scales
|
| 84 |
+
self.s2_max_split_size = s2_max_split_size
|
| 85 |
+
self.s2_resize_output_to_scale_idx = s2_resize_output_to_scale_idx
|
| 86 |
+
self.min_tiles = min_tiles
|
| 87 |
+
self.max_tiles = max_tiles
|
| 88 |
+
self.num_time_tokens = num_time_tokens
|
| 89 |
+
self.time_token_format = time_token_format
|
| 90 |
+
|
| 91 |
+
self.image_encoder = image_encoder
|
| 92 |
+
self.video_encoder = video_encoder
|
constants.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 18 |
+
|
| 19 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
| 20 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
| 21 |
+
|
| 22 |
+
LOGDIR = "."
|
| 23 |
+
|
| 24 |
+
# Model Constants
|
| 25 |
+
IGNORE_INDEX = -100
|
| 26 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
| 27 |
+
DEFAULT_SOUND_TOKEN = "<sound>"
|
| 28 |
+
DEFAULT_SPEECH_TOKEN = "<speech>"
|
| 29 |
+
SENTINEL_TOKEN = "<vila/sentinel>"
|
| 30 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
| 31 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
SENTINEL_TOKEN = "<vila/sentinel>"
|
| 35 |
+
|
| 36 |
+
MEDIA_TOKENS = {
|
| 37 |
+
"image": "<image>",
|
| 38 |
+
"video": "<vila/video>",
|
| 39 |
+
"speech": "<speech>",
|
| 40 |
+
"sound": "<sound>",
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
# <image> <vila/video> <vila/sentinel>
|
| 44 |
+
"""
|
| 45 |
+
vila:
|
| 46 |
+
151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 47 |
+
151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 48 |
+
151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 49 |
+
151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 50 |
+
151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 51 |
+
151648: AddedToken("<vila/sentinel>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 52 |
+
151649: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 53 |
+
151650: AddedToken("<vila/video>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 54 |
+
|
| 55 |
+
xvila:
|
| 56 |
+
151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 57 |
+
151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 58 |
+
151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 59 |
+
151646: AddedToken("[BOS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 60 |
+
151647: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 61 |
+
151648: AddedToken("<vila/sentinel>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 62 |
+
151649: AddedToken("<image>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 63 |
+
151650: AddedToken("<vila/video>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 64 |
+
151651: AddedToken("<speech>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 65 |
+
151652: AddedToken("<sound>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 66 |
+
151653: AddedToken("<|image_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 67 |
+
151654: AddedToken("<|image_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 68 |
+
151655: AddedToken("<|video_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 69 |
+
151656: AddedToken("<|video_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 70 |
+
151657: AddedToken("<|speech_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 71 |
+
151658: AddedToken("<|speech_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 72 |
+
151659: AddedToken("<|sound_bos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 73 |
+
151660: AddedToken("<|sound_eos|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
|
| 74 |
+
"""
|
| 75 |
+
MM_BOS_EOS_TOKENS = {
|
| 76 |
+
"image": ["<|image_bos|>", "<|image_eos|>"],
|
| 77 |
+
"video": ["<|video_bos|>", "<|video_eos|>"],
|
| 78 |
+
"speech": ["<|speech_bos|>", "<|speech_eos|>"],
|
| 79 |
+
"sound": ["<|sound_bos|>", "<|sound_eos|>"],
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
NUM_EXTRA_TOKENS_VILA = 8
|
| 83 |
+
NUM_EXTRA_TOKENS_XVILA = 10
|
conversation.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 17 |
+
|
| 18 |
+
import dataclasses
|
| 19 |
+
from enum import Enum, auto
|
| 20 |
+
from typing import List
|
| 21 |
+
|
| 22 |
+
# from llava.utils.logging import logger
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class SeparatorStyle(Enum):
|
| 26 |
+
"""Different separator style."""
|
| 27 |
+
|
| 28 |
+
AUTO = auto()
|
| 29 |
+
TWO = auto()
|
| 30 |
+
MPT = auto()
|
| 31 |
+
PLAIN = auto()
|
| 32 |
+
LLAMA_3 = auto()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclasses.dataclass
|
| 36 |
+
class Conversation:
|
| 37 |
+
"""A class that keeps all conversation history."""
|
| 38 |
+
|
| 39 |
+
system: str
|
| 40 |
+
roles: List[str]
|
| 41 |
+
messages: List[List[str]]
|
| 42 |
+
sep_style: SeparatorStyle = SeparatorStyle.AUTO
|
| 43 |
+
sep: str = "###"
|
| 44 |
+
sep2: str = None
|
| 45 |
+
version: str = "Unknown"
|
| 46 |
+
|
| 47 |
+
def get_prompt(self):
|
| 48 |
+
messages = self.messages
|
| 49 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
| 50 |
+
messages = self.messages.copy()
|
| 51 |
+
init_role, init_msg = messages[0].copy()
|
| 52 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
| 53 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
| 54 |
+
|
| 55 |
+
if self.sep_style == SeparatorStyle.TWO:
|
| 56 |
+
seps = [self.sep, self.sep2]
|
| 57 |
+
ret = self.system + seps[0]
|
| 58 |
+
for i, (role, message) in enumerate(messages):
|
| 59 |
+
if message:
|
| 60 |
+
if type(message) is tuple:
|
| 61 |
+
message, _, _ = message
|
| 62 |
+
ret += role + ": " + message + seps[i % 2]
|
| 63 |
+
else:
|
| 64 |
+
ret += role + ":"
|
| 65 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
| 66 |
+
ret = self.system + self.sep
|
| 67 |
+
for rid, (role, message) in enumerate(messages):
|
| 68 |
+
if message:
|
| 69 |
+
if type(message) is tuple:
|
| 70 |
+
message = message[0]
|
| 71 |
+
sep = self.sep if rid < len(messages) - 1 else self.sep2
|
| 72 |
+
ret += role + message + sep
|
| 73 |
+
else:
|
| 74 |
+
ret += role
|
| 75 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
| 76 |
+
ret = self.system + self.sep
|
| 77 |
+
for role, message in messages:
|
| 78 |
+
if message:
|
| 79 |
+
if type(message) is tuple:
|
| 80 |
+
message, _, _ = message
|
| 81 |
+
ret += role + message + self.sep
|
| 82 |
+
else:
|
| 83 |
+
ret += role
|
| 84 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
| 85 |
+
seps = [self.sep, self.sep2]
|
| 86 |
+
ret = self.system
|
| 87 |
+
for i, (role, message) in enumerate(messages):
|
| 88 |
+
if message:
|
| 89 |
+
if type(message) is tuple:
|
| 90 |
+
message, _, _ = message
|
| 91 |
+
ret += message + seps[i % 2]
|
| 92 |
+
else:
|
| 93 |
+
ret += ""
|
| 94 |
+
else:
|
| 95 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
| 96 |
+
|
| 97 |
+
return ret
|
| 98 |
+
|
| 99 |
+
def append_message(self, role, message):
|
| 100 |
+
self.messages.append([role, message])
|
| 101 |
+
|
| 102 |
+
def copy(self):
|
| 103 |
+
return Conversation(
|
| 104 |
+
system=self.system,
|
| 105 |
+
roles=self.roles,
|
| 106 |
+
messages=[[x, y] for x, y in self.messages],
|
| 107 |
+
sep_style=self.sep_style,
|
| 108 |
+
sep=self.sep,
|
| 109 |
+
sep2=self.sep2,
|
| 110 |
+
version=self.version,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
conv_auto = Conversation(
|
| 115 |
+
system="",
|
| 116 |
+
roles=("", ""),
|
| 117 |
+
messages=(),
|
| 118 |
+
sep_style=SeparatorStyle.AUTO,
|
| 119 |
+
sep="\n",
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
conv_vicuna_v1 = Conversation(
|
| 123 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
| 124 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
| 125 |
+
roles=("USER", "ASSISTANT"),
|
| 126 |
+
version="v1",
|
| 127 |
+
messages=(),
|
| 128 |
+
sep_style=SeparatorStyle.TWO,
|
| 129 |
+
sep=" ",
|
| 130 |
+
sep2="</s>",
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
conv_llava_plain = Conversation(
|
| 134 |
+
system="",
|
| 135 |
+
roles=("", ""),
|
| 136 |
+
messages=(),
|
| 137 |
+
sep_style=SeparatorStyle.PLAIN,
|
| 138 |
+
sep="\n",
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
hermes_2 = Conversation(
|
| 142 |
+
system="<|im_start|>system\nAnswer the questions.",
|
| 143 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
| 144 |
+
sep_style=SeparatorStyle.MPT,
|
| 145 |
+
sep="<|im_end|>",
|
| 146 |
+
messages=(),
|
| 147 |
+
version="hermes-2",
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
|
| 151 |
+
llama_3_chat = Conversation(
|
| 152 |
+
system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
|
| 153 |
+
"You are able to understand the visual content that the user provides, "
|
| 154 |
+
"and assist the user with a variety of tasks using natural language.",
|
| 155 |
+
roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
|
| 156 |
+
version="llama_v3",
|
| 157 |
+
messages=(),
|
| 158 |
+
sep_style=SeparatorStyle.LLAMA_3,
|
| 159 |
+
sep="<|eot_id|>",
|
| 160 |
+
sep2="<|end_of_text|>",
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
default_conversation = conv_auto
|
| 165 |
+
conv_templates = {
|
| 166 |
+
"auto": conv_auto,
|
| 167 |
+
"hermes-2": hermes_2,
|
| 168 |
+
"llama_3": llama_3_chat,
|
| 169 |
+
"v1": conv_vicuna_v1,
|
| 170 |
+
"vicuna_v1": conv_vicuna_v1,
|
| 171 |
+
"plain": conv_llava_plain,
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
CONVERSATION_MODE_MAPPING = {
|
| 176 |
+
"vila1.5-3b": "vicuna_v1",
|
| 177 |
+
"vila1.5-8b": "llama_3",
|
| 178 |
+
"vila1.5-13b": "vicuna_v1",
|
| 179 |
+
"vila1.5-40b": "hermes-2",
|
| 180 |
+
"llama-3": "llama_3",
|
| 181 |
+
"llama3": "llama_3",
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def auto_set_conversation_mode(model_name_or_path: str) -> str:
|
| 186 |
+
global default_conversation
|
| 187 |
+
for k, v in CONVERSATION_MODE_MAPPING.items():
|
| 188 |
+
if k in model_name_or_path.lower():
|
| 189 |
+
print(f"Setting conversation mode to `{v}` based on model name/path `{model_name_or_path}`.")
|
| 190 |
+
default_conversation = conv_templates[v]
|
| 191 |
+
return
|
distributed.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import warnings
|
| 3 |
+
from typing import Any, List, Optional
|
| 4 |
+
|
| 5 |
+
from torch import distributed as dist
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"init",
|
| 9 |
+
"is_initialized",
|
| 10 |
+
"size",
|
| 11 |
+
"rank",
|
| 12 |
+
"local_size",
|
| 13 |
+
"local_rank",
|
| 14 |
+
"is_main",
|
| 15 |
+
"barrier",
|
| 16 |
+
"gather",
|
| 17 |
+
"all_gather",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def init() -> None:
|
| 22 |
+
if "RANK" not in os.environ:
|
| 23 |
+
warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.")
|
| 24 |
+
return
|
| 25 |
+
dist.init_process_group(backend="nccl", init_method="env://")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def is_initialized() -> bool:
|
| 29 |
+
return dist.is_initialized()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def size() -> int:
|
| 33 |
+
return int(os.environ.get("WORLD_SIZE", 1))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def rank() -> int:
|
| 37 |
+
return int(os.environ.get("RANK", 0))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def local_size() -> int:
|
| 41 |
+
return int(os.environ.get("LOCAL_WORLD_SIZE", 1))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def local_rank() -> int:
|
| 45 |
+
return int(os.environ.get("LOCAL_RANK", 0))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def is_main() -> bool:
|
| 49 |
+
return rank() == 0
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def barrier() -> None:
|
| 53 |
+
dist.barrier()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
|
| 57 |
+
if not is_initialized():
|
| 58 |
+
return [obj]
|
| 59 |
+
if is_main():
|
| 60 |
+
objs = [None for _ in range(size())]
|
| 61 |
+
dist.gather_object(obj, objs, dst=dst)
|
| 62 |
+
return objs
|
| 63 |
+
else:
|
| 64 |
+
dist.gather_object(obj, dst=dst)
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def all_gather(obj: Any) -> List[Any]:
|
| 69 |
+
if not is_initialized():
|
| 70 |
+
return [obj]
|
| 71 |
+
objs = [None for _ in range(size())]
|
| 72 |
+
dist.all_gather_object(objs, obj)
|
| 73 |
+
return objs
|
loss.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.nn.functional import cross_entropy
|
| 5 |
+
|
| 6 |
+
from .constants import IGNORE_INDEX
|
| 7 |
+
|
| 8 |
+
__all__ = ["soft_cross_entropy"]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def soft_cross_entropy(
|
| 12 |
+
outputs: torch.Tensor,
|
| 13 |
+
targets: torch.Tensor,
|
| 14 |
+
soft_tokens: Union[torch.Tensor, List[int]],
|
| 15 |
+
std: float = 1,
|
| 16 |
+
ignore_index: int = IGNORE_INDEX,
|
| 17 |
+
) -> torch.Tensor:
|
| 18 |
+
# Remove last token from outputs and first token from targets
|
| 19 |
+
outputs = outputs[..., :-1, :].contiguous()
|
| 20 |
+
targets = targets[..., 1:].contiguous()
|
| 21 |
+
|
| 22 |
+
# Flatten outputs and targets
|
| 23 |
+
targets = targets.view(-1)
|
| 24 |
+
outputs = outputs.view(targets.size(0), -1)
|
| 25 |
+
|
| 26 |
+
# Remove outputs and targets with ignore_index
|
| 27 |
+
indices = targets != ignore_index
|
| 28 |
+
outputs = outputs[indices]
|
| 29 |
+
targets = targets[indices]
|
| 30 |
+
|
| 31 |
+
# Convert soft token IDs to tensor
|
| 32 |
+
if isinstance(soft_tokens, list):
|
| 33 |
+
soft_tokens = torch.tensor(soft_tokens).to(targets)
|
| 34 |
+
|
| 35 |
+
# Calculate loss for non-soft tokens
|
| 36 |
+
indices = torch.isin(targets, soft_tokens, invert=True)
|
| 37 |
+
loss = cross_entropy(outputs[indices], targets[indices], reduction="sum")
|
| 38 |
+
|
| 39 |
+
# Calculate loss for soft tokens
|
| 40 |
+
indices = torch.isin(targets, soft_tokens)
|
| 41 |
+
targets_indices = torch.zeros_like(outputs[indices])
|
| 42 |
+
for k, target in enumerate(targets[indices]):
|
| 43 |
+
dist = torch.exp(-((target - soft_tokens) ** 2) / (2 * std**2))
|
| 44 |
+
targets_indices[k][soft_tokens] = dist / dist.sum()
|
| 45 |
+
loss += cross_entropy(outputs[indices], targets_indices, reduction="sum")
|
| 46 |
+
|
| 47 |
+
# Return average loss
|
| 48 |
+
return loss / targets.size(0)
|
media.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from typing import Any, Dict, List, Optional, Union
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import PIL
|
| 9 |
+
import PIL.Image
|
| 10 |
+
import requests
|
| 11 |
+
from transformers import PretrainedConfig
|
| 12 |
+
|
| 13 |
+
# from llava.constants import MEDIA_TOKENS
|
| 14 |
+
# from llava.media import Image, Video
|
| 15 |
+
# from llava.utils import make_list
|
| 16 |
+
# from llava.utils.logging import logger
|
| 17 |
+
|
| 18 |
+
MEDIA_TOKENS = {
|
| 19 |
+
"image": "<image>",
|
| 20 |
+
"video": "<vila/video>",
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Media:
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class File(Media):
|
| 29 |
+
def __init__(self, path: str) -> None:
|
| 30 |
+
self.path = path
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Image(File):
|
| 34 |
+
pass
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Video(File):
|
| 38 |
+
pass
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def make_list(obj: Any) -> List:
|
| 42 |
+
return obj if isinstance(obj, list) else [obj]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
|
| 46 |
+
if isinstance(image, Image):
|
| 47 |
+
if image.path.startswith("http://") or image.path.startswith("https://"):
|
| 48 |
+
image = PIL.Image.open(requests.get(image.path, stream=True).raw)
|
| 49 |
+
else:
|
| 50 |
+
image = PIL.Image.open(image.path)
|
| 51 |
+
return image
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
|
| 55 |
+
# Load video frames from a directory
|
| 56 |
+
if os.path.isdir(video_path):
|
| 57 |
+
frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
|
| 58 |
+
indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
|
| 59 |
+
return [PIL.Image.open(frame_paths[index]) for index in indices]
|
| 60 |
+
|
| 61 |
+
# Load video frames from a video file
|
| 62 |
+
vidcap = cv2.VideoCapture(video_path)
|
| 63 |
+
|
| 64 |
+
# Find the last frame as frame count might not be accurate
|
| 65 |
+
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 66 |
+
while frame_count > 0:
|
| 67 |
+
vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
|
| 68 |
+
if vidcap.grab():
|
| 69 |
+
break
|
| 70 |
+
frame_count -= 1
|
| 71 |
+
else:
|
| 72 |
+
raise ValueError(f"Video '{video_path}' has no frames.")
|
| 73 |
+
|
| 74 |
+
# Extract frames uniformly
|
| 75 |
+
indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
|
| 76 |
+
frames = {}
|
| 77 |
+
for index in indices:
|
| 78 |
+
if index in frames:
|
| 79 |
+
continue
|
| 80 |
+
vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
|
| 81 |
+
success, frame = vidcap.read()
|
| 82 |
+
if not success:
|
| 83 |
+
print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
|
| 84 |
+
continue
|
| 85 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 86 |
+
frames[index] = PIL.Image.fromarray(frame)
|
| 87 |
+
return [frames[index] for index in indices if index in frames]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _extract_video(video, config: PretrainedConfig) -> List[PIL.Image.Image]:
|
| 91 |
+
num_frames = config.num_video_frames
|
| 92 |
+
video_path = video.path if isinstance(video, Video) else video["path"]
|
| 93 |
+
frames = _load_video(video_path, num_frames=num_frames)
|
| 94 |
+
return frames
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def extract_media(
|
| 98 |
+
messages: List[Dict[str, Any]],
|
| 99 |
+
config: Optional[PretrainedConfig] = None,
|
| 100 |
+
draft: bool = False,
|
| 101 |
+
) -> Dict[str, List[Any]]:
|
| 102 |
+
media = defaultdict(list)
|
| 103 |
+
for message in messages:
|
| 104 |
+
text = ""
|
| 105 |
+
for part in make_list(message["value"]):
|
| 106 |
+
if isinstance(part, str):
|
| 107 |
+
for token in MEDIA_TOKENS.values():
|
| 108 |
+
if token in part:
|
| 109 |
+
print(f"Media token '{token}' found in text: '{part}'. Removed.")
|
| 110 |
+
part = part.replace(token, "").strip()
|
| 111 |
+
text += part
|
| 112 |
+
elif isinstance(part, (Image, PIL.Image.Image)):
|
| 113 |
+
if draft:
|
| 114 |
+
media["image"].append(part)
|
| 115 |
+
else:
|
| 116 |
+
media["image"].append(_extract_image(part))
|
| 117 |
+
text += MEDIA_TOKENS["image"]
|
| 118 |
+
elif isinstance(part, dict) or isinstance(part, Video):
|
| 119 |
+
if draft:
|
| 120 |
+
media["video"].append(part)
|
| 121 |
+
else:
|
| 122 |
+
media["video"].append(_extract_video(part, config))
|
| 123 |
+
text += MEDIA_TOKENS["video"]
|
| 124 |
+
else:
|
| 125 |
+
raise ValueError(f"Unsupported prompt part type: {type(part)}")
|
| 126 |
+
message["value"] = text
|
| 127 |
+
|
| 128 |
+
if MEDIA_TOKENS["video"] in messages[0]["value"]:
|
| 129 |
+
messages[0]["value"] = "<vila/video>" + messages[0]["value"].replace("<vila/video>", "")
|
| 130 |
+
return media
|
media_encoder.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
from typing import Any, Dict, List, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BaseEncoder(nn.Module):
|
| 9 |
+
def __init__(self, parent: nn.Module) -> None:
|
| 10 |
+
super().__init__()
|
| 11 |
+
self._parent = [parent]
|
| 12 |
+
|
| 13 |
+
@property
|
| 14 |
+
def parent(self) -> nn.Module:
|
| 15 |
+
return self._parent[0]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class BasicImageEncoder(BaseEncoder):
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
parent: torch.nn.Module,
|
| 22 |
+
start_tokens: Optional[str] = None,
|
| 23 |
+
end_tokens: Optional[str] = "\n",
|
| 24 |
+
) -> None:
|
| 25 |
+
super().__init__(parent)
|
| 26 |
+
self.start_tokens = start_tokens
|
| 27 |
+
self.end_tokens = end_tokens
|
| 28 |
+
|
| 29 |
+
def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
|
| 30 |
+
if tokens is None:
|
| 31 |
+
return None
|
| 32 |
+
token_ids = self.parent.tokenizer(tokens).input_ids
|
| 33 |
+
token_ids = torch.tensor(token_ids, device=self.parent.device)
|
| 34 |
+
return self.parent.llm_model_embed_tokens(token_ids)
|
| 35 |
+
|
| 36 |
+
def _process_features(
|
| 37 |
+
self,
|
| 38 |
+
features: torch.Tensor,
|
| 39 |
+
start_token_embeds: Optional[torch.Tensor],
|
| 40 |
+
end_token_embeds: Optional[torch.Tensor],
|
| 41 |
+
) -> torch.Tensor:
|
| 42 |
+
if start_token_embeds is not None:
|
| 43 |
+
features = torch.cat([start_token_embeds, features], dim=0)
|
| 44 |
+
if end_token_embeds is not None:
|
| 45 |
+
features = torch.cat([features, end_token_embeds], dim=0)
|
| 46 |
+
return features
|
| 47 |
+
|
| 48 |
+
def forward(self, images: List[torch.Tensor], config: Dict[str, Any], device: torch.device) -> List[torch.Tensor]:
|
| 49 |
+
images = torch.stack(images, dim=0)
|
| 50 |
+
features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
|
| 51 |
+
process_features = partial(
|
| 52 |
+
self._process_features,
|
| 53 |
+
start_token_embeds=self.embed_tokens(self.start_tokens),
|
| 54 |
+
end_token_embeds=self.embed_tokens(self.end_tokens),
|
| 55 |
+
)
|
| 56 |
+
return [process_features(f).to(device) for f in features]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class BasicVideoEncoder(BaseEncoder):
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
parent: torch.nn.Module,
|
| 63 |
+
start_tokens: Optional[str] = None,
|
| 64 |
+
end_tokens: Optional[str] = "\n",
|
| 65 |
+
) -> None:
|
| 66 |
+
super().__init__(parent)
|
| 67 |
+
self.start_tokens = start_tokens
|
| 68 |
+
self.end_tokens = end_tokens
|
| 69 |
+
|
| 70 |
+
def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
|
| 71 |
+
if tokens is None:
|
| 72 |
+
return None
|
| 73 |
+
token_ids = self.parent.tokenizer(tokens).input_ids
|
| 74 |
+
token_ids = torch.tensor(token_ids, device=self.parent.device)
|
| 75 |
+
return self.parent.llm_model_embed_tokens(token_ids)
|
| 76 |
+
|
| 77 |
+
def _process_features(
|
| 78 |
+
self,
|
| 79 |
+
features: torch.Tensor,
|
| 80 |
+
start_token_embeds: Optional[torch.Tensor],
|
| 81 |
+
end_token_embeds: Optional[torch.Tensor],
|
| 82 |
+
) -> torch.Tensor:
|
| 83 |
+
if start_token_embeds is not None:
|
| 84 |
+
start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
|
| 85 |
+
features = torch.cat([start_embeds, features], dim=1)
|
| 86 |
+
if end_token_embeds is not None:
|
| 87 |
+
end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
|
| 88 |
+
features = torch.cat([features, end_embeds], dim=1)
|
| 89 |
+
return features.flatten(0, 1)
|
| 90 |
+
|
| 91 |
+
def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
|
| 92 |
+
num_frames = [video.shape[0] for video in videos]
|
| 93 |
+
images = torch.cat(videos, dim=0)
|
| 94 |
+
features = self.parent.encode_images(images)
|
| 95 |
+
features = torch.split(features, num_frames)
|
| 96 |
+
process_features = partial(
|
| 97 |
+
self._process_features,
|
| 98 |
+
start_token_embeds=self.embed_tokens(self.start_tokens),
|
| 99 |
+
end_token_embeds=self.embed_tokens(self.end_tokens),
|
| 100 |
+
)
|
| 101 |
+
return [process_features(f) for f in features]
|
| 102 |
+
|
| 103 |
+
def pool(x: torch.Tensor, size: int, dim: int) -> torch.Tensor:
|
| 104 |
+
if x.shape[dim] % size == 0:
|
| 105 |
+
return x.view(x.shape[:dim] + (-1, size) + x.shape[dim + 1 :]).mean(dim + 1)
|
| 106 |
+
else:
|
| 107 |
+
return x.narrow(dim, start=0, length=1)
|
| 108 |
+
|
| 109 |
+
class TSPVideoEncoder(BasicVideoEncoder):
|
| 110 |
+
def __init__(
|
| 111 |
+
self,
|
| 112 |
+
parent: torch.nn.Module,
|
| 113 |
+
#pool_sizes: List[Tuple[int, int, int]],
|
| 114 |
+
start_tokens: Optional[str] = None,
|
| 115 |
+
end_tokens: Optional[str] = "\n",
|
| 116 |
+
sep_tokens: Optional[str] = None,
|
| 117 |
+
) -> None:
|
| 118 |
+
super().__init__(parent, start_tokens=start_tokens, end_tokens=end_tokens)
|
| 119 |
+
self.pool_sizes = [[8, 1, 1]] #pool_sizes
|
| 120 |
+
self.sep_tokens = sep_tokens
|
| 121 |
+
|
| 122 |
+
def _process_features(
|
| 123 |
+
self,
|
| 124 |
+
inputs: torch.Tensor,
|
| 125 |
+
start_token_embeds: Optional[torch.Tensor],
|
| 126 |
+
end_token_embeds: Optional[torch.Tensor],
|
| 127 |
+
sep_token_embeds: Optional[torch.Tensor],
|
| 128 |
+
) -> torch.Tensor:
|
| 129 |
+
nt, ns = inputs.shape[:2]
|
| 130 |
+
nl = int(ns**0.5)
|
| 131 |
+
outputs = []
|
| 132 |
+
for pool_size in self.pool_sizes:
|
| 133 |
+
features = inputs.view(nt, nl, nl, -1)
|
| 134 |
+
for dim, p in enumerate(pool_size):
|
| 135 |
+
features = pool(features, p, dim=dim)
|
| 136 |
+
features = features.flatten(1, 2)
|
| 137 |
+
features = super()._process_features(
|
| 138 |
+
features,
|
| 139 |
+
start_token_embeds=start_token_embeds,
|
| 140 |
+
end_token_embeds=end_token_embeds,
|
| 141 |
+
)
|
| 142 |
+
if sep_token_embeds is not None:
|
| 143 |
+
features = torch.cat([features, sep_token_embeds], dim=0)
|
| 144 |
+
outputs.append(features)
|
| 145 |
+
return torch.cat(outputs, dim=0)
|
| 146 |
+
|
| 147 |
+
def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
|
| 148 |
+
num_frames = [video.shape[0] for video in videos]
|
| 149 |
+
images = torch.cat(videos, dim=0)
|
| 150 |
+
features = self.parent.encode_images(images)
|
| 151 |
+
features = torch.split(features, num_frames)
|
| 152 |
+
process_features = partial(
|
| 153 |
+
self._process_features,
|
| 154 |
+
start_token_embeds=self.embed_tokens(self.start_tokens),
|
| 155 |
+
end_token_embeds=self.embed_tokens(self.end_tokens),
|
| 156 |
+
sep_token_embeds=self.embed_tokens(self.sep_tokens),
|
| 157 |
+
)
|
| 158 |
+
return [process_features(f) for f in features]
|
mm_utils.py
ADDED
|
@@ -0,0 +1,575 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
# dynamic_preprocess and find_closest_aspect_ratio are referenced from https://github.com/OpenGVLab/InternVL
|
| 18 |
+
|
| 19 |
+
import base64
|
| 20 |
+
import os
|
| 21 |
+
import tempfile
|
| 22 |
+
from io import BytesIO
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
from PIL import Image
|
| 27 |
+
from transformers import StoppingCriteria
|
| 28 |
+
|
| 29 |
+
from .constants import DEFAULT_IMAGE_TOKEN
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
|
| 33 |
+
import cv2
|
| 34 |
+
|
| 35 |
+
if fps == None or frame_count == None:
|
| 36 |
+
# if one of fps or frame_count is None, still recompute
|
| 37 |
+
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
| 38 |
+
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 39 |
+
if fps == 0 or frame_count == 0:
|
| 40 |
+
print(f"Video file not found. return empty images. {video_file_name}")
|
| 41 |
+
return [
|
| 42 |
+
Image.new("RGB", (720, 720)),
|
| 43 |
+
] * num_frames, 0
|
| 44 |
+
|
| 45 |
+
duration = frame_count / fps
|
| 46 |
+
frame_interval = frame_count // num_frames
|
| 47 |
+
if frame_interval == 0 and frame_count <= 1:
|
| 48 |
+
print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
|
| 49 |
+
return [
|
| 50 |
+
Image.new("RGB", (720, 720)),
|
| 51 |
+
] * num_frames, 0
|
| 52 |
+
# print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
|
| 53 |
+
|
| 54 |
+
images = []
|
| 55 |
+
count = 0
|
| 56 |
+
success = True
|
| 57 |
+
frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
|
| 58 |
+
while success:
|
| 59 |
+
# print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
|
| 60 |
+
if frame_count >= num_frames:
|
| 61 |
+
success, frame = vidcap.read()
|
| 62 |
+
if count in frame_indices:
|
| 63 |
+
try:
|
| 64 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 65 |
+
im_pil = Image.fromarray(img)
|
| 66 |
+
images.append(im_pil)
|
| 67 |
+
except BaseException:
|
| 68 |
+
continue
|
| 69 |
+
if len(images) >= num_frames:
|
| 70 |
+
return images, num_frames
|
| 71 |
+
count += 1
|
| 72 |
+
else:
|
| 73 |
+
# Left padding frames if the video is not long enough
|
| 74 |
+
success, frame = vidcap.read()
|
| 75 |
+
if success:
|
| 76 |
+
try:
|
| 77 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 78 |
+
im_pil = Image.fromarray(img)
|
| 79 |
+
images.append(im_pil)
|
| 80 |
+
except BaseException:
|
| 81 |
+
continue
|
| 82 |
+
count += 1
|
| 83 |
+
else:
|
| 84 |
+
break
|
| 85 |
+
if len(images) == 0:
|
| 86 |
+
raise ValueError("Did not find enough frames in the video. return empty image.")
|
| 87 |
+
|
| 88 |
+
return images, len(images)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
|
| 92 |
+
"""
|
| 93 |
+
num_frames is the max number of frames the model can support.
|
| 94 |
+
frame_count is the number of frames in the input video.
|
| 95 |
+
max_fps is the max FPS of the model can support.
|
| 96 |
+
fps is the fps of the input video.
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
import random
|
| 100 |
+
|
| 101 |
+
import cv2
|
| 102 |
+
|
| 103 |
+
if fps == None or frame_count == None:
|
| 104 |
+
# if one of fps or frame_count is None, still recompute
|
| 105 |
+
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
| 106 |
+
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 107 |
+
|
| 108 |
+
if fps == 0 or frame_count == 0:
|
| 109 |
+
print(f"Video file not found. return empty images. {video_file_name}")
|
| 110 |
+
empty_video_frames = int(random.uniform(2, 8 * max_fps))
|
| 111 |
+
return [
|
| 112 |
+
Image.new("RGB", (720, 720)),
|
| 113 |
+
] * empty_video_frames, 0
|
| 114 |
+
|
| 115 |
+
duration = frame_count / fps
|
| 116 |
+
# print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps)
|
| 117 |
+
# If the video is too long (longer than max_fps and num_frames can support),
|
| 118 |
+
# we will use lower fps to sample frames.
|
| 119 |
+
if duration >= num_frames / max_fps:
|
| 120 |
+
frame_interval = frame_count // num_frames
|
| 121 |
+
|
| 122 |
+
# If the video is too short, we will skip the video if there is only one frame.
|
| 123 |
+
if frame_interval == 0 and frame_count <= 1:
|
| 124 |
+
print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
|
| 125 |
+
empty_video_frames = int(random.uniform(2, 8 * max_fps))
|
| 126 |
+
return [
|
| 127 |
+
Image.new("RGB", (720, 720)),
|
| 128 |
+
] * empty_video_frames, 0
|
| 129 |
+
|
| 130 |
+
images = []
|
| 131 |
+
count = 0
|
| 132 |
+
success = True
|
| 133 |
+
frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
|
| 134 |
+
|
| 135 |
+
while success:
|
| 136 |
+
if frame_count >= num_frames:
|
| 137 |
+
# success, frame = vidcap.read()
|
| 138 |
+
if count in frame_indices:
|
| 139 |
+
success, frame = vidcap.read()
|
| 140 |
+
try:
|
| 141 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 142 |
+
im_pil = Image.fromarray(img)
|
| 143 |
+
images.append(im_pil)
|
| 144 |
+
except:
|
| 145 |
+
# print("Failed to read frame:", count)
|
| 146 |
+
continue
|
| 147 |
+
if len(images) >= num_frames:
|
| 148 |
+
return images, num_frames
|
| 149 |
+
else:
|
| 150 |
+
success = vidcap.grab()
|
| 151 |
+
count += 1
|
| 152 |
+
else:
|
| 153 |
+
# Left padding frames if the video is not long enough
|
| 154 |
+
success, frame = vidcap.read()
|
| 155 |
+
if success:
|
| 156 |
+
try:
|
| 157 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 158 |
+
im_pil = Image.fromarray(img)
|
| 159 |
+
images.append(im_pil)
|
| 160 |
+
except:
|
| 161 |
+
# print("Failed to read frame:", count)
|
| 162 |
+
continue
|
| 163 |
+
count += 1
|
| 164 |
+
else:
|
| 165 |
+
break
|
| 166 |
+
else:
|
| 167 |
+
frames_required = int(duration * max_fps)
|
| 168 |
+
frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int)
|
| 169 |
+
if frames_required == 0:
|
| 170 |
+
print(f"frames_required is fewer than 2. Duration {duration}, return empty image.")
|
| 171 |
+
empty_video_frames = int(random.uniform(2, 8 * max_fps))
|
| 172 |
+
return [
|
| 173 |
+
Image.new("RGB", (720, 720)),
|
| 174 |
+
] * empty_video_frames, 0
|
| 175 |
+
elif frames_required == 1:
|
| 176 |
+
frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int)
|
| 177 |
+
images = []
|
| 178 |
+
count = 0
|
| 179 |
+
looked = 0
|
| 180 |
+
success = True
|
| 181 |
+
|
| 182 |
+
while success:
|
| 183 |
+
success, frame = vidcap.read()
|
| 184 |
+
if success and (looked in frame_indices):
|
| 185 |
+
try:
|
| 186 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 187 |
+
im_pil = Image.fromarray(img)
|
| 188 |
+
images.append(im_pil)
|
| 189 |
+
except:
|
| 190 |
+
continue
|
| 191 |
+
count += 1
|
| 192 |
+
looked += 1
|
| 193 |
+
|
| 194 |
+
if len(images) == 0:
|
| 195 |
+
empty_video_frames = int(random.uniform(2, 8 * max_fps))
|
| 196 |
+
return [
|
| 197 |
+
Image.new("RGB", (720, 720)),
|
| 198 |
+
] * empty_video_frames, 0
|
| 199 |
+
else:
|
| 200 |
+
return images, len(images)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None):
|
| 204 |
+
"""
|
| 205 |
+
Extract frames from a video using OpenCV.
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
|
| 209 |
+
frames (int): Number of frames to extract from the video.
|
| 210 |
+
fps (float): Frames per second of the video. If 0.0, the function will extract frames at equal intervals.
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
list: List of PIL Images extracted from the video.
|
| 214 |
+
|
| 215 |
+
Raises:
|
| 216 |
+
NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
|
| 217 |
+
"""
|
| 218 |
+
import cv2
|
| 219 |
+
|
| 220 |
+
if isinstance(vpath_or_bytesio, str):
|
| 221 |
+
vidcap = cv2.VideoCapture(vpath_or_bytesio)
|
| 222 |
+
if max_fps > 0.0:
|
| 223 |
+
return get_frame_from_vcap_with_fps(
|
| 224 |
+
vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
|
| 225 |
+
)
|
| 226 |
+
return get_frame_from_vcap(
|
| 227 |
+
vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
|
| 228 |
+
)
|
| 229 |
+
elif isinstance(vpath_or_bytesio, (BytesIO,)):
|
| 230 |
+
# assuming mp4
|
| 231 |
+
with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
|
| 232 |
+
temp_video.write(vpath_or_bytesio.read())
|
| 233 |
+
temp_video_name = temp_video.name
|
| 234 |
+
vidcap = cv2.VideoCapture(temp_video_name)
|
| 235 |
+
if max_fps > 0.0:
|
| 236 |
+
return get_frame_from_vcap_with_fps(
|
| 237 |
+
vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
|
| 238 |
+
)
|
| 239 |
+
return get_frame_from_vcap(
|
| 240 |
+
vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
|
| 241 |
+
)
|
| 242 |
+
else:
|
| 243 |
+
raise NotImplementedError(type(vpath_or_bytesio))
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def load_image_from_base64(image):
|
| 247 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def expand2square(pil_img, background_color):
|
| 251 |
+
"""
|
| 252 |
+
Expand the given PIL image to a square shape by adding padding.
|
| 253 |
+
|
| 254 |
+
Parameters:
|
| 255 |
+
- pil_img: The PIL image to be expanded.
|
| 256 |
+
- background_color: The color of the padding to be added.
|
| 257 |
+
|
| 258 |
+
Returns:
|
| 259 |
+
- The expanded PIL image.
|
| 260 |
+
|
| 261 |
+
If the image is already square, it is returned as is.
|
| 262 |
+
If the image is wider than it is tall, padding is added to the top and bottom.
|
| 263 |
+
If the image is taller than it is wide, padding is added to the left and right.
|
| 264 |
+
"""
|
| 265 |
+
width, height = pil_img.size
|
| 266 |
+
if pil_img.mode == "L":
|
| 267 |
+
background_color = background_color[0]
|
| 268 |
+
if width == height:
|
| 269 |
+
return pil_img
|
| 270 |
+
elif width > height:
|
| 271 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
| 272 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
| 273 |
+
return result
|
| 274 |
+
else:
|
| 275 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
| 276 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
| 277 |
+
return result
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
| 281 |
+
best_ratio_diff = float("inf")
|
| 282 |
+
best_ratio = (1, 1)
|
| 283 |
+
area = width * height
|
| 284 |
+
for ratio in target_ratios:
|
| 285 |
+
target_aspect_ratio = ratio[0] / ratio[1]
|
| 286 |
+
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
| 287 |
+
if ratio_diff < best_ratio_diff:
|
| 288 |
+
best_ratio_diff = ratio_diff
|
| 289 |
+
best_ratio = ratio
|
| 290 |
+
elif ratio_diff == best_ratio_diff:
|
| 291 |
+
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
| 292 |
+
best_ratio = ratio
|
| 293 |
+
return best_ratio
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=384, use_thumbnail=True):
|
| 297 |
+
orig_width, orig_height = image.size
|
| 298 |
+
aspect_ratio = orig_width / orig_height
|
| 299 |
+
|
| 300 |
+
# calculate the existing image aspect ratio
|
| 301 |
+
target_ratios = {
|
| 302 |
+
(i, j)
|
| 303 |
+
for n in range(min_num, max_num + 1)
|
| 304 |
+
for i in range(1, n + 1)
|
| 305 |
+
for j in range(1, n + 1)
|
| 306 |
+
if i * j <= max_num and i * j >= min_num
|
| 307 |
+
}
|
| 308 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
| 309 |
+
|
| 310 |
+
# find the closest aspect ratio to the target
|
| 311 |
+
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
| 312 |
+
|
| 313 |
+
# calculate the target width and height
|
| 314 |
+
target_width = image_size * target_aspect_ratio[0]
|
| 315 |
+
target_height = image_size * target_aspect_ratio[1]
|
| 316 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
| 317 |
+
|
| 318 |
+
# resize the image
|
| 319 |
+
resized_img = image.resize((target_width, target_height))
|
| 320 |
+
processed_images = []
|
| 321 |
+
for i in range(blocks):
|
| 322 |
+
box = (
|
| 323 |
+
(i % (target_width // image_size)) * image_size,
|
| 324 |
+
(i // (target_width // image_size)) * image_size,
|
| 325 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
| 326 |
+
((i // (target_width // image_size)) + 1) * image_size,
|
| 327 |
+
)
|
| 328 |
+
# split the image
|
| 329 |
+
split_img = resized_img.crop(box)
|
| 330 |
+
processed_images.append(split_img)
|
| 331 |
+
assert len(processed_images) == blocks
|
| 332 |
+
if use_thumbnail and len(processed_images) != 1:
|
| 333 |
+
thumbnail_img = image.resize((image_size, image_size))
|
| 334 |
+
processed_images.append(thumbnail_img)
|
| 335 |
+
return processed_images
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def dynamic_s2_preprocess(image, s2_scales=[384, 768, 1152], max_num=12, image_size=384):
|
| 339 |
+
orig_width, orig_height = image.size
|
| 340 |
+
aspect_ratio = orig_width / orig_height
|
| 341 |
+
min_num = (s2_scales[-1] // s2_scales[0]) ** 2 # at least use number of tiles as the largest scale
|
| 342 |
+
|
| 343 |
+
processed_images = []
|
| 344 |
+
|
| 345 |
+
##########################################################################################
|
| 346 |
+
############# Add tiles for all but the last scale using fixed squre ratio ###############
|
| 347 |
+
##########################################################################################
|
| 348 |
+
|
| 349 |
+
for scale in s2_scales[:-1]:
|
| 350 |
+
target_width = image_size * (scale // s2_scales[0])
|
| 351 |
+
target_height = image_size * (scale // s2_scales[0])
|
| 352 |
+
blocks = (scale // s2_scales[0]) ** 2
|
| 353 |
+
|
| 354 |
+
# resize the image
|
| 355 |
+
resized_img = image.resize((target_width, target_height))
|
| 356 |
+
for i in range(blocks):
|
| 357 |
+
box = (
|
| 358 |
+
(i % (target_width // image_size)) * image_size,
|
| 359 |
+
(i // (target_width // image_size)) * image_size,
|
| 360 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
| 361 |
+
((i // (target_width // image_size)) + 1) * image_size,
|
| 362 |
+
)
|
| 363 |
+
# split the image
|
| 364 |
+
split_img = resized_img.crop(box)
|
| 365 |
+
processed_images.append(split_img)
|
| 366 |
+
|
| 367 |
+
##########################################################################################
|
| 368 |
+
################ Add tiles for the last scale using dynamic aspect ratio #################
|
| 369 |
+
##########################################################################################
|
| 370 |
+
|
| 371 |
+
# calculate the existing image aspect ratio
|
| 372 |
+
target_ratios = {
|
| 373 |
+
(i, j)
|
| 374 |
+
for n in range(min_num, max_num + 1)
|
| 375 |
+
for i in range(1, n + 1)
|
| 376 |
+
for j in range(1, n + 1)
|
| 377 |
+
if i * j <= max_num and i * j >= min_num
|
| 378 |
+
}
|
| 379 |
+
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
| 380 |
+
|
| 381 |
+
# find the closest aspect ratio to the target
|
| 382 |
+
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
| 383 |
+
|
| 384 |
+
# calculate the target width and height
|
| 385 |
+
target_width = image_size * target_aspect_ratio[0]
|
| 386 |
+
target_height = image_size * target_aspect_ratio[1]
|
| 387 |
+
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
| 388 |
+
|
| 389 |
+
# resize the image
|
| 390 |
+
resized_img = image.resize((target_width, target_height))
|
| 391 |
+
for i in range(blocks):
|
| 392 |
+
box = (
|
| 393 |
+
(i % (target_width // image_size)) * image_size,
|
| 394 |
+
(i // (target_width // image_size)) * image_size,
|
| 395 |
+
((i % (target_width // image_size)) + 1) * image_size,
|
| 396 |
+
((i // (target_width // image_size)) + 1) * image_size,
|
| 397 |
+
)
|
| 398 |
+
# split the image
|
| 399 |
+
split_img = resized_img.crop(box)
|
| 400 |
+
processed_images.append(split_img)
|
| 401 |
+
|
| 402 |
+
return processed_images, (target_aspect_ratio[1], target_aspect_ratio[0])
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def dynamic_process_images_and_prompt(images, prompt, data_args, image_folder=None, max_tiles=None):
|
| 406 |
+
prompt = prompt.split(DEFAULT_IMAGE_TOKEN)
|
| 407 |
+
idx = 0
|
| 408 |
+
all_images = []
|
| 409 |
+
for img in images:
|
| 410 |
+
processed_images = process_image(img, data_args, image_folder, enable_dynamic_res=True, max_tiles=max_tiles)
|
| 411 |
+
all_images.append(processed_images)
|
| 412 |
+
prompt.insert(idx + 1, f"{DEFAULT_IMAGE_TOKEN}\n" * processed_images.shape[0])
|
| 413 |
+
idx += 2
|
| 414 |
+
prompt = "".join(prompt)
|
| 415 |
+
if all_images:
|
| 416 |
+
all_images = torch.cat(all_images)
|
| 417 |
+
else:
|
| 418 |
+
all_images = None
|
| 419 |
+
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, "")
|
| 420 |
+
return all_images, prompt
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def dynamic_s2_process_images_and_prompt(images, prompt, data_args, image_folder=None):
|
| 424 |
+
idx = 0
|
| 425 |
+
all_images = []
|
| 426 |
+
all_block_size = []
|
| 427 |
+
for img in images:
|
| 428 |
+
processed_images, block_size = process_image(img, data_args, image_folder, enable_dynamic_s2=True)
|
| 429 |
+
all_images.append(processed_images)
|
| 430 |
+
all_block_size.append(block_size)
|
| 431 |
+
idx += 2
|
| 432 |
+
if all_images:
|
| 433 |
+
all_images = torch.cat(all_images)
|
| 434 |
+
else:
|
| 435 |
+
all_images = None
|
| 436 |
+
return all_images, all_block_size
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def process_image(
|
| 440 |
+
image_file, data_args, image_folder, enable_dynamic_res=False, enable_dynamic_s2=False, max_tiles=None
|
| 441 |
+
):
|
| 442 |
+
processor = data_args.image_processor
|
| 443 |
+
if isinstance(image_file, str):
|
| 444 |
+
if image_folder is not None:
|
| 445 |
+
image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
|
| 446 |
+
else:
|
| 447 |
+
image = Image.open(image_file).convert("RGB")
|
| 448 |
+
else:
|
| 449 |
+
# image is stored in bytearray
|
| 450 |
+
image = image_file
|
| 451 |
+
image = image.convert("RGB")
|
| 452 |
+
if hasattr(data_args.image_processor, "crop_size"):
|
| 453 |
+
# CLIP vision tower
|
| 454 |
+
crop_size = data_args.image_processor.crop_size
|
| 455 |
+
else:
|
| 456 |
+
# SIGLIP vision tower
|
| 457 |
+
assert hasattr(data_args.image_processor, "size")
|
| 458 |
+
crop_size = data_args.image_processor.size
|
| 459 |
+
if "dynamic_s2" in data_args.image_aspect_ratio and enable_dynamic_s2:
|
| 460 |
+
assert crop_size["height"] == crop_size["width"]
|
| 461 |
+
images, block_size = dynamic_s2_preprocess(
|
| 462 |
+
image, s2_scales=data_args.s2_scales, max_num=data_args.max_tiles, image_size=crop_size["height"]
|
| 463 |
+
)
|
| 464 |
+
images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
|
| 465 |
+
return torch.stack(images), block_size
|
| 466 |
+
if "dynamic" in data_args.image_aspect_ratio and enable_dynamic_res:
|
| 467 |
+
assert crop_size["height"] == crop_size["width"]
|
| 468 |
+
if max_tiles is not None:
|
| 469 |
+
max_num = max_tiles
|
| 470 |
+
else:
|
| 471 |
+
max_num = data_args.max_tiles
|
| 472 |
+
images = dynamic_preprocess(image, min_num=data_args.min_tiles, max_num=max_num, image_size=crop_size["height"])
|
| 473 |
+
images = [processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
|
| 474 |
+
return torch.stack(images)
|
| 475 |
+
|
| 476 |
+
if data_args.image_aspect_ratio == "resize":
|
| 477 |
+
image = image.resize((crop_size["width"], crop_size["height"]))
|
| 478 |
+
if data_args.image_aspect_ratio == "pad":
|
| 479 |
+
|
| 480 |
+
def expand2square(pil_img, background_color):
|
| 481 |
+
width, height = pil_img.size
|
| 482 |
+
if width == height:
|
| 483 |
+
return pil_img
|
| 484 |
+
elif width > height:
|
| 485 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
| 486 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
| 487 |
+
return result
|
| 488 |
+
else:
|
| 489 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
| 490 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
| 491 |
+
return result
|
| 492 |
+
|
| 493 |
+
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
|
| 494 |
+
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
| 495 |
+
else:
|
| 496 |
+
# Using default behavior of the vision encoder
|
| 497 |
+
# For CLIP, default is central crop
|
| 498 |
+
# For Radio, default is central crop
|
| 499 |
+
# For Siglip, default is resize
|
| 500 |
+
# For InternVIT, default is resize
|
| 501 |
+
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
| 502 |
+
return image
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def process_images(images, image_processor, model_cfg, enable_dynamic_res=False, max_tiles=None):
|
| 506 |
+
model_cfg.image_processor = image_processor
|
| 507 |
+
new_images = [
|
| 508 |
+
process_image(image, model_cfg, None, enable_dynamic_res=enable_dynamic_res, max_tiles=max_tiles)
|
| 509 |
+
for image in images
|
| 510 |
+
]
|
| 511 |
+
|
| 512 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
| 513 |
+
if len(new_images[0].shape) == 4:
|
| 514 |
+
new_images = torch.cat(new_images, dim=0)
|
| 515 |
+
elif len(new_images[0].shape) == 3:
|
| 516 |
+
new_images = torch.stack(new_images, dim=0)
|
| 517 |
+
else:
|
| 518 |
+
raise ValueError(f"new_images rank does not equal to 4, rank: {len(new_images[0].shape)}")
|
| 519 |
+
else:
|
| 520 |
+
raise ValueError("The shape of images in new_images is different!")
|
| 521 |
+
return new_images
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def tokenizer_image_token(prompt, tokenizer, return_tensors=None, return_ids=True):
|
| 525 |
+
if return_ids:
|
| 526 |
+
return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]
|
| 527 |
+
else:
|
| 528 |
+
return tokenizer(prompt, return_tensors=return_tensors)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def is_gemma_tokenizer(tokenizer):
|
| 532 |
+
return "gemma" in tokenizer.__class__.__name__.lower()
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
def get_model_name_from_path(model_path):
|
| 536 |
+
model_path = model_path.strip("/")
|
| 537 |
+
model_paths = model_path.split("/")
|
| 538 |
+
if model_paths[-1].startswith("checkpoint-"):
|
| 539 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
| 540 |
+
else:
|
| 541 |
+
return model_paths[-1]
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
| 545 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
| 546 |
+
self.keywords = keywords
|
| 547 |
+
self.keyword_ids = []
|
| 548 |
+
self.max_keyword_len = 0
|
| 549 |
+
for keyword in keywords:
|
| 550 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
| 551 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
| 552 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
| 553 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
| 554 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
| 555 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
| 556 |
+
self.tokenizer = tokenizer
|
| 557 |
+
self.start_len = input_ids.shape[1]
|
| 558 |
+
|
| 559 |
+
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 560 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
| 561 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
| 562 |
+
for keyword_id in self.keyword_ids:
|
| 563 |
+
if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
|
| 564 |
+
return True
|
| 565 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
| 566 |
+
for keyword in self.keywords:
|
| 567 |
+
if keyword in outputs:
|
| 568 |
+
return True
|
| 569 |
+
return False
|
| 570 |
+
|
| 571 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 572 |
+
outputs = []
|
| 573 |
+
for i in range(output_ids.shape[0]):
|
| 574 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
| 575 |
+
return all(outputs)
|
model_utils_packing.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from importlib import import_module
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import transformers
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
__all__ = ["patch"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _get_unpad_data(attention_mask: torch.Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
| 13 |
+
if hasattr(_get_unpad_data, "seqlens_in_batch"):
|
| 14 |
+
seqlens_in_batch = _get_unpad_data.seqlens_in_batch
|
| 15 |
+
else:
|
| 16 |
+
seqlens_in_batch = torch.sum(attention_mask, dim=1)
|
| 17 |
+
|
| 18 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 19 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 20 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 21 |
+
return indices, cu_seqlens, max_seqlen_in_batch
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def set_seqlens_in_batch(seqlens_in_batch: torch.Tensor) -> None:
|
| 25 |
+
_get_unpad_data.seqlens_in_batch = seqlens_in_batch
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def patch(model: nn.Module) -> None:
|
| 29 |
+
if transformers.__version__ < "4.43.0":
|
| 30 |
+
m = import_module(model.__module__)
|
| 31 |
+
if not hasattr(m, "_get_unpad_data"):
|
| 32 |
+
raise ValueError(f"Module {m} does not have function '_get_unpad_data' for packing")
|
| 33 |
+
m._get_unpad_data = _get_unpad_data
|
| 34 |
+
else:
|
| 35 |
+
transformers.modeling_flash_attention_utils._get_unpad_data = _get_unpad_data
|
modeling_vila.py
ADDED
|
@@ -0,0 +1,1256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import os.path
|
| 7 |
+
import os.path as osp
|
| 8 |
+
import shutil
|
| 9 |
+
import warnings
|
| 10 |
+
from abc import ABC
|
| 11 |
+
from collections import OrderedDict, defaultdict, deque
|
| 12 |
+
from copy import deepcopy
|
| 13 |
+
from itertools import chain
|
| 14 |
+
from threading import Thread
|
| 15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
import torchvision
|
| 22 |
+
from einops import rearrange
|
| 23 |
+
from PIL import Image
|
| 24 |
+
from transformers import (
|
| 25 |
+
AutoConfig,
|
| 26 |
+
AutoModel,
|
| 27 |
+
AutoProcessor,
|
| 28 |
+
AutoTokenizer,
|
| 29 |
+
GenerationConfig,
|
| 30 |
+
LogitsProcessor,
|
| 31 |
+
PretrainedConfig,
|
| 32 |
+
PreTrainedModel,
|
| 33 |
+
Qwen2Config,
|
| 34 |
+
Qwen2ForCausalLM,
|
| 35 |
+
Qwen2PreTrainedModel,
|
| 36 |
+
TextIteratorStreamer,
|
| 37 |
+
)
|
| 38 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 39 |
+
from transformers.modeling_utils import ContextManagers, no_init_weights
|
| 40 |
+
|
| 41 |
+
from .auto_processor import VILAProcessor
|
| 42 |
+
from .base_projector import MultimodalProjector, MultimodalProjectorConfig
|
| 43 |
+
from .builder import build_llm_and_tokenizer
|
| 44 |
+
from .configuration_vila import VILAConfig
|
| 45 |
+
from .constants import *
|
| 46 |
+
from .conversation import SeparatorStyle, default_conversation
|
| 47 |
+
from .distributed import all_gather as vila_all_gather
|
| 48 |
+
from .loss import soft_cross_entropy
|
| 49 |
+
from .media import extract_media
|
| 50 |
+
from .media_encoder import BasicImageEncoder, BasicVideoEncoder, TSPVideoEncoder
|
| 51 |
+
from .mm_utils import process_image, process_images
|
| 52 |
+
from .model_utils_packing import set_seqlens_in_batch
|
| 53 |
+
from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2
|
| 54 |
+
from .tokenizer_utils import tokenize_conversation
|
| 55 |
+
from .utils import get_model_config, load_tokenizer_then_handle_media_tokens_and_chat_template
|
| 56 |
+
|
| 57 |
+
# from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS
|
| 58 |
+
|
| 59 |
+
# ease debugging
|
| 60 |
+
python_input = input
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# quick hack for remote code
|
| 64 |
+
def get_pg_manager():
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def get_model_weights_dtype(model: nn.Module):
|
| 69 |
+
pass
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
|
| 73 |
+
if model_type_or_path is None:
|
| 74 |
+
return None
|
| 75 |
+
## load from pretrained model
|
| 76 |
+
if config.resume_path:
|
| 77 |
+
assert os.path.exists(model_type_or_path), f"Resume mm projector path {model_type_or_path} does not exist!"
|
| 78 |
+
return MultimodalProjector.from_pretrained(model_type_or_path, config)
|
| 79 |
+
## build from scratch
|
| 80 |
+
else:
|
| 81 |
+
mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
|
| 82 |
+
mm_projector = MultimodalProjector(mm_projector_cfg, config)
|
| 83 |
+
return mm_projector
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def check_dot_in_model_path(model_path: str):
|
| 87 |
+
"""Check if the model path contains dot, which will affect the remote code loading."""
|
| 88 |
+
if osp.isdir(model_path): # local model
|
| 89 |
+
if "." in osp.abspath(model_path):
|
| 90 |
+
return True
|
| 91 |
+
else: # remote model
|
| 92 |
+
if "." in model_path:
|
| 93 |
+
return True
|
| 94 |
+
return False
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_vila_version(model_path: str) -> str:
|
| 98 |
+
VERSIONS = ["vila1.5", "vila-u", "longvila", "nvila", "vila-m3"]
|
| 99 |
+
for version in VERSIONS:
|
| 100 |
+
if version in model_path.lower():
|
| 101 |
+
return version
|
| 102 |
+
return None
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def generate_jinja_template(conv_mode: str) -> str:
|
| 106 |
+
if conv_mode == "vicuna_v1":
|
| 107 |
+
return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. " %}
|
| 108 |
+
{% set roles = ["user", "assistant"] %}
|
| 109 |
+
{% set sep = " " %}
|
| 110 |
+
|
| 111 |
+
{{ system_prompt }}
|
| 112 |
+
|
| 113 |
+
{% for message in messages %}
|
| 114 |
+
{% if message['role'] == roles[0] %}
|
| 115 |
+
{{ "USER: " }}{{ sep }}{{ message['content'] }}{{ sep }}
|
| 116 |
+
{% else %}
|
| 117 |
+
{{ "ASSISTANT: " }}{{ sep }}{{ message['content'] }}{{ sep }}
|
| 118 |
+
{% endif %}
|
| 119 |
+
{% endfor %}
|
| 120 |
+
{% if messages[-1]['role'] == 'user' %}
|
| 121 |
+
{{ "ASSISTANT:" }}
|
| 122 |
+
{% endif %}
|
| 123 |
+
"""
|
| 124 |
+
elif conv_mode == "llama_3":
|
| 125 |
+
return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|>" %}
|
| 126 |
+
{% set roles = ["<|start_header_id|>user<|end_header_id|>\\n\\n", "<|start_header_id|>assistant<|end_header_id|>\\n\\n"]%}
|
| 127 |
+
{% set sep = "<|eot_id|>" %}
|
| 128 |
+
|
| 129 |
+
{{ system_prompt }}
|
| 130 |
+
{% for message in messages %}
|
| 131 |
+
{% if message['role'] == 'user' %}
|
| 132 |
+
{{ roles[0] }}{{ message['content'] }}{{ sep }}
|
| 133 |
+
{% else %}
|
| 134 |
+
{{ roles[1] }}{{ message['content'] }}{{ sep }}
|
| 135 |
+
{% endif %}
|
| 136 |
+
{% endfor %}
|
| 137 |
+
{% if messages[-1]['role'] == 'user' %}
|
| 138 |
+
{{ roles[1] }}
|
| 139 |
+
{% endif %}
|
| 140 |
+
"""
|
| 141 |
+
elif conv_mode == "hermes_2":
|
| 142 |
+
return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %}
|
| 143 |
+
{% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %}
|
| 144 |
+
{% set sep = "<|im_end|>" %}
|
| 145 |
+
|
| 146 |
+
{{ system_prompt }}{{ sep }}
|
| 147 |
+
|
| 148 |
+
{% for message in messages %}
|
| 149 |
+
{% if message['role'] == 'user' %}
|
| 150 |
+
{{ roles[0] }}{{ message['content'] }}{{ sep }}
|
| 151 |
+
{% else %}
|
| 152 |
+
{{ roles[1] }}{{ message['content'] }}{{ sep }}
|
| 153 |
+
{% endif %}
|
| 154 |
+
{% endfor %}"""
|
| 155 |
+
else:
|
| 156 |
+
raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel:
|
| 160 |
+
## skip vision tower instantiation
|
| 161 |
+
if model_name_or_path is None:
|
| 162 |
+
return None
|
| 163 |
+
|
| 164 |
+
vision_tower_arch = None
|
| 165 |
+
if config.resume_path and "radio" not in model_name_or_path:
|
| 166 |
+
assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!"
|
| 167 |
+
vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
|
| 168 |
+
vision_tower_arch = vision_tower_cfg.architectures[0].lower()
|
| 169 |
+
vision_tower_name = vision_tower_arch if vision_tower_arch is not None else model_name_or_path
|
| 170 |
+
|
| 171 |
+
use_s2 = getattr(config, "s2", False)
|
| 172 |
+
use_dynamic_s2 = getattr(config, "dynamic_s2", False)
|
| 173 |
+
|
| 174 |
+
if "siglip" in vision_tower_name:
|
| 175 |
+
if use_dynamic_s2:
|
| 176 |
+
vision_tower = SiglipVisionTowerDynamicS2(model_name_or_path, config)
|
| 177 |
+
elif use_s2:
|
| 178 |
+
vision_tower = SiglipVisionTowerS2(model_name_or_path, config)
|
| 179 |
+
else:
|
| 180 |
+
vision_tower = SiglipVisionTower(model_name_or_path, config)
|
| 181 |
+
else:
|
| 182 |
+
raise NotImplementedError(f"Unknown vision tower: {model_name_or_path}")
|
| 183 |
+
|
| 184 |
+
config.mm_hidden_size = (
|
| 185 |
+
vision_tower.config.hidden_size if not (use_s2 or use_dynamic_s2) else vision_tower.hidden_size
|
| 186 |
+
)
|
| 187 |
+
return vision_tower
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class VILAPretrainedModel(PreTrainedModel):
|
| 191 |
+
config_class = VILAConfig
|
| 192 |
+
main_input_name = "input_embeds"
|
| 193 |
+
supports_gradient_checkpointing = True
|
| 194 |
+
_supports_flash_attn_2 = True
|
| 195 |
+
_no_split_modules = ["Qwen2DecoderLayer", "SiglipEncoderLayer"]
|
| 196 |
+
|
| 197 |
+
def __init__(self, config: VILAConfig, *args, **kwargs):
|
| 198 |
+
super().__init__(config)
|
| 199 |
+
self.config = config
|
| 200 |
+
cfgs = get_model_config(config)
|
| 201 |
+
if len(cfgs) == 3:
|
| 202 |
+
llm_cfg, vision_tower_cfg, mm_projector_cfg = cfgs
|
| 203 |
+
else:
|
| 204 |
+
raise ValueError("`llm_cfg` `mm_projector_cfg` `vision_tower_cfg` not found in the config.")
|
| 205 |
+
|
| 206 |
+
# loading on auto by default
|
| 207 |
+
device_map = kwargs.get("device_map", "auto")
|
| 208 |
+
self.mm_projector = build_mm_projector(mm_projector_cfg, config)
|
| 209 |
+
self.vision_tower = build_vision_tower(vision_tower_cfg, config)
|
| 210 |
+
if device_map in ["auto", "cuda"]:
|
| 211 |
+
self.mm_projector = self.mm_projector.cuda()
|
| 212 |
+
self.vision_tower = self.vision_tower.cuda()
|
| 213 |
+
# set device_map auto can autoamtically shard llm to different devices
|
| 214 |
+
self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map)
|
| 215 |
+
self.llm_model_embed_tokens = self.llm.model.embed_tokens
|
| 216 |
+
|
| 217 |
+
try:
|
| 218 |
+
use_tsp_encoder = "TSPVideoEncoder" in getattr(config, "video_encoder", None)["_target_"]
|
| 219 |
+
except:
|
| 220 |
+
use_tsp_encoder = False
|
| 221 |
+
print("use_tsp_encoder", use_tsp_encoder)
|
| 222 |
+
self.tokenizer.padding_side = "left"
|
| 223 |
+
self.encoders = {"image": BasicImageEncoder(self), "video": TSPVideoEncoder(self) if use_tsp_encoder else BasicVideoEncoder(self)}
|
| 224 |
+
|
| 225 |
+
self.post_config()
|
| 226 |
+
self.is_loaded = True
|
| 227 |
+
self.llm_only_need_embed = kwargs.get("llm_only_need_embed", False)
|
| 228 |
+
if self.llm_only_need_embed:
|
| 229 |
+
print("We only need the embed_tokens in llm.")
|
| 230 |
+
del self.llm
|
| 231 |
+
self.llm = None
|
| 232 |
+
torch.cuda.empty_cache()
|
| 233 |
+
|
| 234 |
+
assert (
|
| 235 |
+
self.llm is not None or self.vision_tower is not None or self.mm_projector is not None
|
| 236 |
+
), "At least one of the components must be instantiated."
|
| 237 |
+
|
| 238 |
+
@classmethod
|
| 239 |
+
def convert_vila_dev_ckpt_to_remote(
|
| 240 |
+
self,
|
| 241 |
+
model_path: str,
|
| 242 |
+
output_dir: str = None,
|
| 243 |
+
vila_version: str | None = None,
|
| 244 |
+
conv_mode: str | None = None,
|
| 245 |
+
copy: bool = False,
|
| 246 |
+
copy_weights: bool = True,
|
| 247 |
+
copy_code: bool = True,
|
| 248 |
+
*model_args,
|
| 249 |
+
**kwargs,
|
| 250 |
+
):
|
| 251 |
+
# assert type(self) == VILAForCasualLM, "This method is only available for VILAForCasualLM."
|
| 252 |
+
assert model_path != output_dir, "model_path and output_dir cannot be the same"
|
| 253 |
+
if os.path.isdir(model_path):
|
| 254 |
+
model_path = model_path
|
| 255 |
+
else:
|
| 256 |
+
from huggingface_hub import HfApi, snapshot_download
|
| 257 |
+
|
| 258 |
+
model_path = snapshot_download(model_path)
|
| 259 |
+
print("downloading HF model to", model_path)
|
| 260 |
+
|
| 261 |
+
if check_dot_in_model_path(model_path) and output_dir is None:
|
| 262 |
+
raise ValueError(
|
| 263 |
+
f"Model path {model_path} contains a dot, which will affect the remote code loading. Please specify the output directory without dot in the path to fix this issue."
|
| 264 |
+
)
|
| 265 |
+
if output_dir is not None and "." in output_dir:
|
| 266 |
+
raise ValueError(
|
| 267 |
+
f"Output directory {output_dir} contains a dot, which will affect the remote code loading. Please specify a valid output directory without dots."
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
if copy:
|
| 271 |
+
print("copy is set to True, copying weights and code to output_dir")
|
| 272 |
+
copy_weights = copy_code = True
|
| 273 |
+
# copy weights and code to output_dir
|
| 274 |
+
self.copy_or_symlink_directory(model_path, output_dir, copy=copy_weights)
|
| 275 |
+
self.copy_remote_py_files(output_dir, copy=copy_code)
|
| 276 |
+
|
| 277 |
+
if vila_version is None:
|
| 278 |
+
vila_version = get_vila_version(output_dir)
|
| 279 |
+
|
| 280 |
+
cfg_path = os.path.join(output_dir, "config.json")
|
| 281 |
+
config = json.load(open(cfg_path))
|
| 282 |
+
config["version"] = "2.0" # nvila tag
|
| 283 |
+
config["architectures"] = ["VILAForCausalLM"]
|
| 284 |
+
config["auto_map"] = {
|
| 285 |
+
"AutoProcessor": "auto_processor.VILAProcessor",
|
| 286 |
+
"AutoConfig": "modeling_vila.VILAConfig",
|
| 287 |
+
"AutoModel": "modeling_vila.VILAForCausalLM",
|
| 288 |
+
"AutoModelForCausalLM": "modeling_vila.VILAForCausalLM",
|
| 289 |
+
}
|
| 290 |
+
# vila1.5 legacy support
|
| 291 |
+
config["model_type"] = "vila"
|
| 292 |
+
if vila_version in ["vila1.5", "vila-m3"]:
|
| 293 |
+
if conv_mode is None:
|
| 294 |
+
raise ValueError(f"Please specify the conversation mode for {output_dir}.")
|
| 295 |
+
config["chat_template"] = conv_mode
|
| 296 |
+
jinja_template = generate_jinja_template(conv_mode)
|
| 297 |
+
jinja_path = os.path.join(output_dir, f"{conv_mode}.jinja")
|
| 298 |
+
with open(jinja_path, "w") as f:
|
| 299 |
+
f.write(jinja_template)
|
| 300 |
+
json.dump(config, open(cfg_path, "w"), indent=2)
|
| 301 |
+
|
| 302 |
+
##########################################################################################
|
| 303 |
+
config = AutoConfig.from_pretrained(output_dir, trust_remote_code=True)
|
| 304 |
+
tokenizer = load_tokenizer_then_handle_media_tokens_and_chat_template(output_dir, config)
|
| 305 |
+
tokenizer.save_pretrained(osp.join(output_dir, "llm"))
|
| 306 |
+
##########################################################################################
|
| 307 |
+
|
| 308 |
+
@classmethod
|
| 309 |
+
def copy_or_symlink_directory(cls, model_path, output_dir, copy=True):
|
| 310 |
+
# Create output directory if it doesn't exist
|
| 311 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 312 |
+
# Create symlinks for all files in model_path to output_dir
|
| 313 |
+
for item in os.listdir(model_path):
|
| 314 |
+
src_path = os.path.join(model_path, item)
|
| 315 |
+
dst_path = os.path.join(output_dir, item)
|
| 316 |
+
|
| 317 |
+
# Remove existing file/directory at destination if it exists
|
| 318 |
+
if os.path.exists(dst_path):
|
| 319 |
+
if os.path.islink(dst_path):
|
| 320 |
+
os.unlink(dst_path)
|
| 321 |
+
elif os.path.isdir(dst_path):
|
| 322 |
+
shutil.rmtree(dst_path)
|
| 323 |
+
else:
|
| 324 |
+
os.remove(dst_path)
|
| 325 |
+
|
| 326 |
+
# Create symlink
|
| 327 |
+
if copy:
|
| 328 |
+
if os.path.isdir(src_path):
|
| 329 |
+
shutil.copytree(src_path, dst_path)
|
| 330 |
+
else:
|
| 331 |
+
shutil.copy2(src_path, dst_path)
|
| 332 |
+
print(f"Copied {src_path} to {dst_path}")
|
| 333 |
+
else:
|
| 334 |
+
os.symlink(src_path, dst_path)
|
| 335 |
+
print(f"Created symlink from {src_path} to {dst_path}")
|
| 336 |
+
|
| 337 |
+
@classmethod
|
| 338 |
+
def copy_remote_py_files(cls, output_dir, copy=True):
|
| 339 |
+
## copy .py and REAMDE for next loading remote code
|
| 340 |
+
current_file_path = os.path.abspath(__file__)
|
| 341 |
+
current_folder = os.path.dirname(current_file_path)
|
| 342 |
+
for file_name in os.listdir(current_folder):
|
| 343 |
+
if file_name == "INSTRUCTIONS.md":
|
| 344 |
+
src_fname = os.path.join(current_folder, file_name)
|
| 345 |
+
dst_fname = os.path.join(output_dir, "README.md")
|
| 346 |
+
if os.path.exists(dst_fname):
|
| 347 |
+
old_reamde = open(dst_fname).read()
|
| 348 |
+
else:
|
| 349 |
+
old_reamde = ""
|
| 350 |
+
with open(src_fname) as src, open(dst_fname, "w") as dst:
|
| 351 |
+
dst.write(src.read())
|
| 352 |
+
dst.write(old_reamde)
|
| 353 |
+
print("[HF remote code] REAMDE ", src_fname, "to", dst_fname)
|
| 354 |
+
if file_name.endswith(".py") or file_name.endswith(".jinja"):
|
| 355 |
+
full_file_name = os.path.join(current_folder, file_name)
|
| 356 |
+
if os.path.isfile(full_file_name):
|
| 357 |
+
if copy:
|
| 358 |
+
shutil.copy(full_file_name, output_dir)
|
| 359 |
+
print("[HF remote code] copying", full_file_name, "to", output_dir)
|
| 360 |
+
else:
|
| 361 |
+
# symlink to ease development
|
| 362 |
+
if os.path.exists(os.path.join(output_dir, file_name)):
|
| 363 |
+
os.remove(os.path.join(output_dir, file_name))
|
| 364 |
+
os.symlink(full_file_name, os.path.join(output_dir, file_name))
|
| 365 |
+
print("[HF remote code] linking", full_file_name, "to", output_dir)
|
| 366 |
+
|
| 367 |
+
def save_pretrained(self, output_dir, state_dict=None, **kwargs):
|
| 368 |
+
if state_dict is None:
|
| 369 |
+
# other wise fetch from deepspeed
|
| 370 |
+
# state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
|
| 371 |
+
state_dict = self.state_dict()
|
| 372 |
+
|
| 373 |
+
if getattr(self, "tokenizer", None):
|
| 374 |
+
self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))
|
| 375 |
+
|
| 376 |
+
if self.get_llm():
|
| 377 |
+
print(f"saving llm to {osp.join(output_dir, 'llm')}")
|
| 378 |
+
self.llm.config._name_or_path = osp.join(output_dir, "llm")
|
| 379 |
+
llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k})
|
| 380 |
+
self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict)
|
| 381 |
+
self.config.llm_cfg = self.llm.config
|
| 382 |
+
|
| 383 |
+
if self.get_vision_tower():
|
| 384 |
+
print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
|
| 385 |
+
self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower")
|
| 386 |
+
vision_tower_state_dict = OrderedDict(
|
| 387 |
+
{k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k}
|
| 388 |
+
)
|
| 389 |
+
self.vision_tower.vision_tower.save_pretrained(
|
| 390 |
+
os.path.join(output_dir, "vision_tower"),
|
| 391 |
+
state_dict=vision_tower_state_dict,
|
| 392 |
+
)
|
| 393 |
+
self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower"))
|
| 394 |
+
self.config.vision_tower_cfg = self.vision_tower.config
|
| 395 |
+
if hasattr(self.config.vision_tower_cfg, "auto_map"):
|
| 396 |
+
if "radio" not in self.get_vision_tower().__class__.__name__.lower():
|
| 397 |
+
delattr(self.config.vision_tower_cfg, "auto_map")
|
| 398 |
+
|
| 399 |
+
if self.get_mm_projector():
|
| 400 |
+
print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
|
| 401 |
+
self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector")
|
| 402 |
+
mm_projector_state_dict = OrderedDict(
|
| 403 |
+
{k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k}
|
| 404 |
+
)
|
| 405 |
+
self.mm_projector.save_pretrained(
|
| 406 |
+
os.path.join(output_dir, "mm_projector"),
|
| 407 |
+
state_dict=mm_projector_state_dict,
|
| 408 |
+
)
|
| 409 |
+
self.config.mm_projector_cfg = self.mm_projector.config
|
| 410 |
+
|
| 411 |
+
## update and save top-level config
|
| 412 |
+
self.config._name_or_path = output_dir
|
| 413 |
+
self.config.architectures = [self.__class__.__name__]
|
| 414 |
+
self.config.save_pretrained(output_dir)
|
| 415 |
+
|
| 416 |
+
## copy .py and REAMDE for next loading remote code
|
| 417 |
+
self.copy_remote_py_files(output_dir)
|
| 418 |
+
|
| 419 |
+
@classmethod
|
| 420 |
+
def from_pretrained(
|
| 421 |
+
cls,
|
| 422 |
+
pretrained_model_name_or_path: Optional[str] = None,
|
| 423 |
+
*model_args,
|
| 424 |
+
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
|
| 425 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
| 426 |
+
ignore_mismatched_sizes: bool = False,
|
| 427 |
+
force_download: bool = False,
|
| 428 |
+
local_files_only: bool = False,
|
| 429 |
+
token: Optional[Union[str, bool]] = None,
|
| 430 |
+
revision: str = "main",
|
| 431 |
+
use_safetensors: Optional[bool] = None,
|
| 432 |
+
weights_only: bool = True,
|
| 433 |
+
**kwargs,
|
| 434 |
+
):
|
| 435 |
+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
|
| 436 |
+
return cls._from_config(config, **kwargs)
|
| 437 |
+
|
| 438 |
+
def init_llm(self, llm_config, config, *args, **kwargs):
|
| 439 |
+
self.llm, self.tokenizer = build_llm_and_tokenizer(llm_config, config, *args, **kwargs)
|
| 440 |
+
# hard coded for NVILA
|
| 441 |
+
# variables for XGrammar
|
| 442 |
+
NUM_EXTRA_TOKENS = len(self.tokenizer.added_tokens_encoder.keys())
|
| 443 |
+
|
| 444 |
+
self.pad_token_list = (
|
| 445 |
+
self.tokenizer.pad_token_id,
|
| 446 |
+
self.tokenizer.eos_token_id,
|
| 447 |
+
self.tokenizer.tokenize("<|endoftext|>")[0], # for qwen
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
self.vocab_size = self.tokenizer.vocab_size + NUM_EXTRA_TOKENS
|
| 451 |
+
# XGrammar tokenizer and grammar compiler
|
| 452 |
+
# lazy init only when specified json output during inference
|
| 453 |
+
self.grammar_compiler = None
|
| 454 |
+
self.llm.resize_token_embeddings(len(self.tokenizer))
|
| 455 |
+
return self.llm, self.tokenizer
|
| 456 |
+
|
| 457 |
+
def post_config(self):
|
| 458 |
+
######################################################################
|
| 459 |
+
self.llm = self.llm.to(torch.float16)
|
| 460 |
+
self.mm_projector = self.mm_projector.to(torch.float16)
|
| 461 |
+
self.vision_tower = self.vision_tower.to(torch.float16)
|
| 462 |
+
######################################################################
|
| 463 |
+
self.training = self.llm.training
|
| 464 |
+
if self.training:
|
| 465 |
+
self.train()
|
| 466 |
+
else:
|
| 467 |
+
self.eval()
|
| 468 |
+
## configuration
|
| 469 |
+
if getattr(self.config, "llm_cfg", None) is None:
|
| 470 |
+
self.config.llm_cfg = self.llm.config
|
| 471 |
+
if getattr(self.config, "vision_tower_cfg", None) is None:
|
| 472 |
+
self.config.vision_tower_cfg = self.vision_tower.config
|
| 473 |
+
if getattr(self.config, "mm_projector_cfg", None) is None:
|
| 474 |
+
self.config.mm_projector_cfg = self.mm_projector.config
|
| 475 |
+
|
| 476 |
+
def get_llm(self):
|
| 477 |
+
llm = getattr(self, "llm", None)
|
| 478 |
+
if type(llm) is list:
|
| 479 |
+
llm = llm[0]
|
| 480 |
+
return llm
|
| 481 |
+
|
| 482 |
+
def get_lm_head(self):
|
| 483 |
+
lm_head = getattr(self.get_llm(), "lm_head", None)
|
| 484 |
+
return lm_head
|
| 485 |
+
|
| 486 |
+
def get_vision_tower(self):
|
| 487 |
+
vision_tower = getattr(self, "vision_tower", None)
|
| 488 |
+
if type(vision_tower) is list:
|
| 489 |
+
vision_tower = vision_tower[0]
|
| 490 |
+
return vision_tower
|
| 491 |
+
|
| 492 |
+
def get_mm_projector(self):
|
| 493 |
+
mm_projector = getattr(self, "mm_projector", None)
|
| 494 |
+
if type(mm_projector) is list:
|
| 495 |
+
mm_projector = mm_projector[0]
|
| 496 |
+
return mm_projector
|
| 497 |
+
|
| 498 |
+
def freezed_module_patch(self):
|
| 499 |
+
"""
|
| 500 |
+
Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules.
|
| 501 |
+
"""
|
| 502 |
+
if self.training:
|
| 503 |
+
if self.get_llm() and not getattr(self.config, "tune_language_model", False):
|
| 504 |
+
pass
|
| 505 |
+
# logging.warning("Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations.")
|
| 506 |
+
if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False):
|
| 507 |
+
self.get_vision_tower().eval()
|
| 508 |
+
if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False):
|
| 509 |
+
self.get_mm_projector().eval()
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
class VILAForCausalLM(VILAPretrainedModel):
|
| 513 |
+
def __init__(self, config: VILAConfig, *args, **kwargs):
|
| 514 |
+
super().__init__(config, *args, **kwargs)
|
| 515 |
+
|
| 516 |
+
def merge_features_for_dynamic_s2(self, image_features, block_sizes):
|
| 517 |
+
scales = self.get_vision_tower().scales
|
| 518 |
+
resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx
|
| 519 |
+
|
| 520 |
+
image_features_each_image = []
|
| 521 |
+
new_block_sizes = []
|
| 522 |
+
block_cnt = 0
|
| 523 |
+
for block_size_each_image in block_sizes:
|
| 524 |
+
if block_size_each_image is None:
|
| 525 |
+
cur_features = image_features[block_cnt : block_cnt + 1]
|
| 526 |
+
cur_features = rearrange(cur_features, "1 (h w) c -> 1 c h w", h=int(cur_features.shape[1] ** 0.5))
|
| 527 |
+
cur_features = cur_features.repeat(1, len(scales), 1, 1)
|
| 528 |
+
image_features_each_image.append(cur_features)
|
| 529 |
+
new_block_sizes.append((1, 1))
|
| 530 |
+
block_cnt += 1
|
| 531 |
+
else:
|
| 532 |
+
cur_features_each_scale = []
|
| 533 |
+
for scale in scales[:-1]:
|
| 534 |
+
num_blocks_this_scale = (scale // scales[0]) ** 2
|
| 535 |
+
cur_features_each_scale.append(
|
| 536 |
+
self.merge_chessboard(
|
| 537 |
+
image_features[block_cnt : block_cnt + num_blocks_this_scale],
|
| 538 |
+
num_split_h=scale // scales[0],
|
| 539 |
+
num_split_w=scale // scales[0],
|
| 540 |
+
)
|
| 541 |
+
) # 1 * C * H * W
|
| 542 |
+
block_cnt += num_blocks_this_scale
|
| 543 |
+
num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
|
| 544 |
+
cur_features_each_scale.append(
|
| 545 |
+
self.merge_chessboard(
|
| 546 |
+
image_features[block_cnt : block_cnt + num_blocks_last_scale],
|
| 547 |
+
num_split_h=block_size_each_image[0],
|
| 548 |
+
num_split_w=block_size_each_image[1],
|
| 549 |
+
)
|
| 550 |
+
) # 1 * C * H * W
|
| 551 |
+
block_cnt += num_blocks_last_scale
|
| 552 |
+
|
| 553 |
+
# resize and concat features from different scales
|
| 554 |
+
output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
|
| 555 |
+
cur_features = torch.cat(
|
| 556 |
+
[
|
| 557 |
+
F.interpolate(cur_features_each_scale[i].to(torch.float32), size=output_size, mode="area").to(
|
| 558 |
+
cur_features_each_scale[i].dtype
|
| 559 |
+
)
|
| 560 |
+
for i in range(len(cur_features_each_scale))
|
| 561 |
+
],
|
| 562 |
+
dim=1,
|
| 563 |
+
)
|
| 564 |
+
# cur_features = rearrange(cur_features, "1 c h w -> (h w) c")
|
| 565 |
+
|
| 566 |
+
image_features_each_image.append(cur_features)
|
| 567 |
+
|
| 568 |
+
if resize_output_to_scale_idx == len(scales) - 1 or resize_output_to_scale_idx == -1:
|
| 569 |
+
new_block_sizes.append(block_size_each_image)
|
| 570 |
+
else:
|
| 571 |
+
new_block_sizes.append(
|
| 572 |
+
(
|
| 573 |
+
scales[resize_output_to_scale_idx] // scales[0],
|
| 574 |
+
scales[resize_output_to_scale_idx] // scales[0],
|
| 575 |
+
)
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
assert block_cnt == len(image_features)
|
| 579 |
+
|
| 580 |
+
return image_features_each_image, new_block_sizes
|
| 581 |
+
|
| 582 |
+
def encode_images(self, images, block_sizes: Optional[Optional[Tuple[int, ...]]] = None):
|
| 583 |
+
if block_sizes is None:
|
| 584 |
+
block_sizes = [None] * len(images)
|
| 585 |
+
if getattr(self.config, "dynamic_s2", False):
|
| 586 |
+
image_features = self.get_vision_tower()(images)
|
| 587 |
+
image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes)
|
| 588 |
+
|
| 589 |
+
image_features = [
|
| 590 |
+
self.split_chessboard(x, block_size[0], block_size[1])
|
| 591 |
+
for x, block_size in zip(image_features, new_block_sizes)
|
| 592 |
+
] # list of B * C * H * W tensors
|
| 593 |
+
image_features = torch.cat(
|
| 594 |
+
[rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0
|
| 595 |
+
) # B * N * C
|
| 596 |
+
image_features = self.get_mm_projector()(image_features)
|
| 597 |
+
image_features = list(
|
| 598 |
+
image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0)
|
| 599 |
+
)
|
| 600 |
+
image_features = [
|
| 601 |
+
self.merge_chessboard(x, block_size[0], block_size[1])
|
| 602 |
+
for x, block_size in zip(image_features, new_block_sizes)
|
| 603 |
+
] # list of 1 * C * H * W tensors
|
| 604 |
+
image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features] # list of N * C tensors
|
| 605 |
+
if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]):
|
| 606 |
+
image_features = torch.stack(image_features, dim=0)
|
| 607 |
+
else:
|
| 608 |
+
image_features = self.get_vision_tower()(images)
|
| 609 |
+
image_features = self.get_mm_projector()(image_features)
|
| 610 |
+
return image_features
|
| 611 |
+
|
| 612 |
+
def train(self, mode: bool = True):
|
| 613 |
+
super().train(mode)
|
| 614 |
+
return self
|
| 615 |
+
|
| 616 |
+
@torch.inference_mode()
|
| 617 |
+
def _embed(
|
| 618 |
+
self,
|
| 619 |
+
input_ids: torch.Tensor,
|
| 620 |
+
media: Dict[str, List[torch.Tensor]],
|
| 621 |
+
media_config: Dict[str, Dict[str, Any]],
|
| 622 |
+
labels: Optional[torch.Tensor],
|
| 623 |
+
attention_mask: Optional[torch.Tensor],
|
| 624 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 625 |
+
media = copy.deepcopy(media)
|
| 626 |
+
media_config = copy.deepcopy(media_config)
|
| 627 |
+
|
| 628 |
+
labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX)
|
| 629 |
+
attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool)
|
| 630 |
+
|
| 631 |
+
PROCESS_GROUP_MANAGER = get_pg_manager()
|
| 632 |
+
if PROCESS_GROUP_MANAGER is not None:
|
| 633 |
+
for name in media:
|
| 634 |
+
self.encoders[name].end_tokens = None
|
| 635 |
+
|
| 636 |
+
# Extract text and media embeddings
|
| 637 |
+
text_embeds = self.llm_model_embed_tokens(input_ids)
|
| 638 |
+
|
| 639 |
+
use_cache = False
|
| 640 |
+
if "use_cache" in media_config:
|
| 641 |
+
use_cache = media_config.pop("use_cache")
|
| 642 |
+
|
| 643 |
+
if use_cache:
|
| 644 |
+
print("Use cached embedding")
|
| 645 |
+
if media is not None:
|
| 646 |
+
media_embeds = media if use_cache else self.__embed_media_tokens(media, media_config)
|
| 647 |
+
else:
|
| 648 |
+
# no media was provided, so we just return an empty dict
|
| 649 |
+
media_embeds = {}
|
| 650 |
+
|
| 651 |
+
# This is a workaround to make sure the dummy embeddings are consumed
|
| 652 |
+
while media_embeds.get("dummy"):
|
| 653 |
+
dummy_embed = media_embeds["dummy"].popleft()
|
| 654 |
+
text_embeds += torch.sum(dummy_embed) * 0
|
| 655 |
+
|
| 656 |
+
# Remove padding
|
| 657 |
+
batch_size = labels.shape[0]
|
| 658 |
+
text_embeds = [text_embeds[k][attention_mask[k]] for k in range(batch_size)]
|
| 659 |
+
labels = [labels[k][attention_mask[k]] for k in range(batch_size)]
|
| 660 |
+
|
| 661 |
+
# Build inverse mapping from token ID to media name
|
| 662 |
+
media_tokens = {}
|
| 663 |
+
for name, token_id in self.tokenizer.media_token_ids.items():
|
| 664 |
+
media_tokens[token_id] = name
|
| 665 |
+
|
| 666 |
+
# Fuse text and media embeddings
|
| 667 |
+
inputs_m, labels_m = [], []
|
| 668 |
+
for k in range(batch_size):
|
| 669 |
+
inputs_mk, labels_mk = [], []
|
| 670 |
+
pos = 0
|
| 671 |
+
while pos < len(labels[k]):
|
| 672 |
+
if input_ids[k][pos].item() in media_tokens:
|
| 673 |
+
end = pos + 1
|
| 674 |
+
name = media_tokens[input_ids[k][pos].item()]
|
| 675 |
+
input = media_embeds[name].popleft()
|
| 676 |
+
label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype)
|
| 677 |
+
elif input_ids[k][pos].item() in self.pad_token_list:
|
| 678 |
+
# skip pad tokens
|
| 679 |
+
end = pos + 1
|
| 680 |
+
pos = end
|
| 681 |
+
continue
|
| 682 |
+
else:
|
| 683 |
+
end = pos
|
| 684 |
+
while end < len(labels[k]) and input_ids[k][end].item() not in media_tokens:
|
| 685 |
+
end += 1
|
| 686 |
+
input = text_embeds[k][pos:end]
|
| 687 |
+
label = labels[k][pos:end]
|
| 688 |
+
|
| 689 |
+
inputs_mk.append(input)
|
| 690 |
+
labels_mk.append(label)
|
| 691 |
+
pos = end
|
| 692 |
+
inputs_m.append(torch.cat(inputs_mk, dim=0))
|
| 693 |
+
labels_m.append(torch.cat(labels_mk, dim=0))
|
| 694 |
+
inputs, labels = inputs_m, labels_m
|
| 695 |
+
|
| 696 |
+
# Check if all media embeddings are consumed
|
| 697 |
+
for name in media_embeds:
|
| 698 |
+
if media_embeds[name]:
|
| 699 |
+
raise ValueError(f"Not all {name} embeddings are consumed! Still {len(media_embeds[name])} left.")
|
| 700 |
+
|
| 701 |
+
# Truncate sequences to `model_max_length` as media embeddings are inserted
|
| 702 |
+
inputs, labels = self.__truncate_sequence(inputs, labels)
|
| 703 |
+
|
| 704 |
+
# Pad sequences to the longest one in the batch
|
| 705 |
+
return self.__batchify_sequence(inputs, labels)
|
| 706 |
+
|
| 707 |
+
def __embed_media_tokens(
|
| 708 |
+
self,
|
| 709 |
+
media: Dict[str, List[torch.Tensor]],
|
| 710 |
+
media_config: Dict[str, Dict[str, Any]],
|
| 711 |
+
) -> Dict[str, List[torch.Tensor]]:
|
| 712 |
+
embeds = defaultdict(deque)
|
| 713 |
+
for name in media:
|
| 714 |
+
if self.training:
|
| 715 |
+
# Gather metainfo of media objects from all ranks
|
| 716 |
+
info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])]
|
| 717 |
+
infos = list(chain(vila_all_gather(info)))
|
| 718 |
+
|
| 719 |
+
# The entire batch does not contain any media objects of this type.
|
| 720 |
+
if not infos:
|
| 721 |
+
continue
|
| 722 |
+
|
| 723 |
+
# Create a dummy tensor to ensure the encoder is called, otherwise the training will hang.
|
| 724 |
+
if media.get(name) is None or len(media[name]) == 0:
|
| 725 |
+
dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device)
|
| 726 |
+
embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name]))
|
| 727 |
+
continue
|
| 728 |
+
embeds[name] = deque(self.encoders[name](media[name], media_config[name]))
|
| 729 |
+
return embeds
|
| 730 |
+
|
| 731 |
+
def __truncate_sequence(
|
| 732 |
+
self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
|
| 733 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 734 |
+
if self.training and any(len(input) > self.tokenizer.model_max_length for input in inputs):
|
| 735 |
+
warnings.warn(f"Truncating sequences to `model_max_length` ({self.tokenizer.model_max_length}).")
|
| 736 |
+
inputs = [input[: self.tokenizer.model_max_length] for input in inputs]
|
| 737 |
+
labels = [label[: self.tokenizer.model_max_length] for label in labels]
|
| 738 |
+
return inputs, labels
|
| 739 |
+
|
| 740 |
+
def __batchify_sequence(
|
| 741 |
+
self, inputs: List[torch.Tensor], labels: List[torch.Tensor]
|
| 742 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 743 |
+
batch_size = len(inputs)
|
| 744 |
+
device = inputs[0].device
|
| 745 |
+
hidden_size = inputs[0].shape[1]
|
| 746 |
+
max_length = max(inputs[k].shape[0] for k in range(batch_size))
|
| 747 |
+
attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device)
|
| 748 |
+
|
| 749 |
+
inputs_p, labels_p = [], []
|
| 750 |
+
for k in range(batch_size):
|
| 751 |
+
size_pk = max_length - inputs[k].shape[0]
|
| 752 |
+
inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device)
|
| 753 |
+
labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device)
|
| 754 |
+
if self.tokenizer.padding_side == "right":
|
| 755 |
+
attention_mask[k, inputs[k].shape[0] :] = False
|
| 756 |
+
inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0)
|
| 757 |
+
labels_pk = torch.cat([labels[k], labels_pk], dim=0)
|
| 758 |
+
else:
|
| 759 |
+
attention_mask[k, : -inputs[k].shape[0]] = False
|
| 760 |
+
inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0)
|
| 761 |
+
labels_pk = torch.cat([labels_pk, labels[k]], dim=0)
|
| 762 |
+
inputs_p.append(inputs_pk)
|
| 763 |
+
labels_p.append(labels_pk)
|
| 764 |
+
|
| 765 |
+
inputs = torch.stack(inputs_p, dim=0)
|
| 766 |
+
labels = torch.stack(labels_p, dim=0)
|
| 767 |
+
return inputs, labels, attention_mask
|
| 768 |
+
|
| 769 |
+
def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels):
|
| 770 |
+
# Handle sequence parallelism
|
| 771 |
+
PROCESS_GROUP_MANAGER = get_pg_manager()
|
| 772 |
+
|
| 773 |
+
# We do re-sharding instead of packing here to ensure the sequence length is the same across all ranks.
|
| 774 |
+
if PROCESS_GROUP_MANAGER is not None:
|
| 775 |
+
sp_degree = PROCESS_GROUP_MANAGER.sp_degree
|
| 776 |
+
sp_rank = PROCESS_GROUP_MANAGER.sp_rank
|
| 777 |
+
sp_group = PROCESS_GROUP_MANAGER.sp_pg
|
| 778 |
+
ring_degree = PROCESS_GROUP_MANAGER.ring_degree
|
| 779 |
+
ring_rank = PROCESS_GROUP_MANAGER.ring_rank
|
| 780 |
+
ring_type = PROCESS_GROUP_MANAGER.ring_type
|
| 781 |
+
ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree
|
| 782 |
+
ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank
|
| 783 |
+
|
| 784 |
+
bs, shard_seqlen = position_ids.shape
|
| 785 |
+
sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)]
|
| 786 |
+
dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group)
|
| 787 |
+
sp_seq_len_cat = torch.cat(sp_seq_len, dim=0)
|
| 788 |
+
|
| 789 |
+
if sp_rank == 0:
|
| 790 |
+
original_start_id = 0
|
| 791 |
+
else:
|
| 792 |
+
original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item()
|
| 793 |
+
original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item()
|
| 794 |
+
|
| 795 |
+
# Gather attention_mask, position_ids, labels and input_embeds
|
| 796 |
+
all_inputs_embeds = torch.zeros(
|
| 797 |
+
bs,
|
| 798 |
+
torch.sum(sp_seq_len_cat),
|
| 799 |
+
inputs_embeds.shape[-1],
|
| 800 |
+
dtype=inputs_embeds.dtype,
|
| 801 |
+
device=inputs_embeds.device,
|
| 802 |
+
).contiguous()
|
| 803 |
+
all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds
|
| 804 |
+
dist.barrier(group=sp_group)
|
| 805 |
+
dist.all_reduce(all_inputs_embeds, group=sp_group)
|
| 806 |
+
dist.barrier(group=sp_group)
|
| 807 |
+
|
| 808 |
+
attention_mask_list = [
|
| 809 |
+
torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device)
|
| 810 |
+
for i in range(sp_degree)
|
| 811 |
+
]
|
| 812 |
+
position_ids_list = [
|
| 813 |
+
torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device)
|
| 814 |
+
for i in range(sp_degree)
|
| 815 |
+
]
|
| 816 |
+
labels_list = [
|
| 817 |
+
torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree)
|
| 818 |
+
]
|
| 819 |
+
|
| 820 |
+
dist.all_gather(attention_mask_list, attention_mask, group=sp_group)
|
| 821 |
+
dist.all_gather(position_ids_list, position_ids, group=sp_group)
|
| 822 |
+
dist.all_gather(labels_list, labels, group=sp_group)
|
| 823 |
+
|
| 824 |
+
effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)]
|
| 825 |
+
effective_seqlen = torch.stack(effective_seqlen_list, dim=-1)
|
| 826 |
+
effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0)
|
| 827 |
+
|
| 828 |
+
global_attention_mask_list = []
|
| 829 |
+
global_position_ids_list = []
|
| 830 |
+
global_labels_list = []
|
| 831 |
+
global_inputs_embeds_list = []
|
| 832 |
+
for i in range(bs):
|
| 833 |
+
global_attention_mask_batch_list = []
|
| 834 |
+
global_position_ids_batch_list = []
|
| 835 |
+
global_labels_batch_list = []
|
| 836 |
+
global_inputs_embeds_batch_list = []
|
| 837 |
+
for j in range(sp_degree):
|
| 838 |
+
eff_len = effective_seqlen_batch_list[i][j]
|
| 839 |
+
prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0
|
| 840 |
+
|
| 841 |
+
global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len])
|
| 842 |
+
global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len])
|
| 843 |
+
global_labels_batch_list.append(labels_list[j][i, :eff_len])
|
| 844 |
+
global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :])
|
| 845 |
+
global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0))
|
| 846 |
+
global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0))
|
| 847 |
+
global_labels_list.append(torch.cat(global_labels_batch_list, dim=0))
|
| 848 |
+
global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))
|
| 849 |
+
|
| 850 |
+
global_attention_mask = torch.nn.utils.rnn.pad_sequence(
|
| 851 |
+
global_attention_mask_list, batch_first=True, padding_value=False
|
| 852 |
+
)
|
| 853 |
+
global_position_ids = torch.nn.utils.rnn.pad_sequence(
|
| 854 |
+
global_position_ids_list, batch_first=True, padding_value=-1
|
| 855 |
+
)
|
| 856 |
+
global_labels = torch.nn.utils.rnn.pad_sequence(
|
| 857 |
+
global_labels_list, batch_first=True, padding_value=IGNORE_INDEX
|
| 858 |
+
)
|
| 859 |
+
global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
|
| 860 |
+
global_inputs_embeds_list, batch_first=True, padding_value=0
|
| 861 |
+
)
|
| 862 |
+
|
| 863 |
+
# Re-shard the inputs
|
| 864 |
+
if ring_degree > 1:
|
| 865 |
+
total_effective_seqlen = torch.sum(effective_seqlen, dim=1)
|
| 866 |
+
new_seqlen_per_rank = total_effective_seqlen // sp_degree
|
| 867 |
+
assert torch.all(
|
| 868 |
+
total_effective_seqlen % sp_degree == 0
|
| 869 |
+
), "total_effective_seqlen must be divisible by sp_degree"
|
| 870 |
+
|
| 871 |
+
max_new_seqlen = torch.max(new_seqlen_per_rank).item()
|
| 872 |
+
|
| 873 |
+
new_attention_mask = torch.zeros(
|
| 874 |
+
(bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device
|
| 875 |
+
)
|
| 876 |
+
new_position_ids = torch.zeros(
|
| 877 |
+
(bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device
|
| 878 |
+
)
|
| 879 |
+
new_labels = torch.full(
|
| 880 |
+
(bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device
|
| 881 |
+
)
|
| 882 |
+
new_inputs_embeds = torch.zeros(
|
| 883 |
+
(bs, max_new_seqlen, global_inputs_embeds.shape[-1]),
|
| 884 |
+
dtype=global_inputs_embeds.dtype,
|
| 885 |
+
device=global_inputs_embeds.device,
|
| 886 |
+
)
|
| 887 |
+
|
| 888 |
+
if ring_type == "ring_varlen":
|
| 889 |
+
for i in range(bs):
|
| 890 |
+
start_idx = new_seqlen_per_rank[i] * sp_rank
|
| 891 |
+
end_idx = start_idx + new_seqlen_per_rank[i]
|
| 892 |
+
new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx]
|
| 893 |
+
new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx]
|
| 894 |
+
new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx]
|
| 895 |
+
new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[
|
| 896 |
+
i, start_idx:end_idx, :
|
| 897 |
+
]
|
| 898 |
+
elif ring_type == "zigzag_ring_varlen":
|
| 899 |
+
chunk_size = total_effective_seqlen // (2 * sp_degree)
|
| 900 |
+
for i in range(bs):
|
| 901 |
+
# Zigzag pattern indices
|
| 902 |
+
if sp_degree == ring_degree:
|
| 903 |
+
forward_rank_idx = sp_rank
|
| 904 |
+
backward_rank_idx = 2 * sp_degree - sp_rank - 1
|
| 905 |
+
else:
|
| 906 |
+
ulysses_offset = ulysses_rank * ring_degree * 2
|
| 907 |
+
forward_rank_idx = ring_rank + ulysses_offset
|
| 908 |
+
backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset
|
| 909 |
+
|
| 910 |
+
# Calculate start and end indices for the forward and backward zigzag
|
| 911 |
+
start_idx_fwd = forward_rank_idx * chunk_size[i]
|
| 912 |
+
end_idx_fwd = start_idx_fwd + chunk_size[i]
|
| 913 |
+
|
| 914 |
+
start_idx_bwd = backward_rank_idx * chunk_size[i]
|
| 915 |
+
end_idx_bwd = start_idx_bwd + chunk_size[i]
|
| 916 |
+
|
| 917 |
+
# Fill new tensors with zigzag data
|
| 918 |
+
new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd]
|
| 919 |
+
new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[
|
| 920 |
+
i, start_idx_bwd:end_idx_bwd
|
| 921 |
+
]
|
| 922 |
+
|
| 923 |
+
new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd]
|
| 924 |
+
new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[
|
| 925 |
+
i, start_idx_bwd:end_idx_bwd
|
| 926 |
+
]
|
| 927 |
+
|
| 928 |
+
new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd]
|
| 929 |
+
new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd]
|
| 930 |
+
|
| 931 |
+
new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :]
|
| 932 |
+
new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[
|
| 933 |
+
i, start_idx_bwd:end_idx_bwd, :
|
| 934 |
+
]
|
| 935 |
+
else:
|
| 936 |
+
raise ValueError(f"Invalid ring_type: {ring_type}")
|
| 937 |
+
else:
|
| 938 |
+
global_seq_len = global_attention_mask.shape[-1]
|
| 939 |
+
seq_len_sharded = global_seq_len // sp_degree
|
| 940 |
+
start_idx_reshard = seq_len_sharded * sp_rank
|
| 941 |
+
end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len
|
| 942 |
+
|
| 943 |
+
new_attention_mask = torch.narrow(
|
| 944 |
+
global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
|
| 945 |
+
)
|
| 946 |
+
new_position_ids = torch.narrow(
|
| 947 |
+
global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
|
| 948 |
+
)
|
| 949 |
+
new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)
|
| 950 |
+
new_inputs_embeds = torch.narrow(
|
| 951 |
+
global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard
|
| 952 |
+
)
|
| 953 |
+
|
| 954 |
+
return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels
|
| 955 |
+
|
| 956 |
+
device = inputs_embeds.device
|
| 957 |
+
batch_size = inputs_embeds.shape[0]
|
| 958 |
+
seqlens = [attention_mask[k].sum().item() for k in range(batch_size)]
|
| 959 |
+
|
| 960 |
+
# Pack all sequences together
|
| 961 |
+
inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)]
|
| 962 |
+
attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
|
| 963 |
+
position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)]
|
| 964 |
+
labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)]
|
| 965 |
+
|
| 966 |
+
# Add one dummy token at the end of the packed sequence to ensure that `_get_unpacked_data` will be called
|
| 967 |
+
inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device))
|
| 968 |
+
attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device))
|
| 969 |
+
position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device))
|
| 970 |
+
labels_p.append(torch.tensor([IGNORE_INDEX], dtype=torch.int, device=device))
|
| 971 |
+
|
| 972 |
+
# Mask the first token of each sequence to avoid contamination
|
| 973 |
+
for label in labels_p:
|
| 974 |
+
label[0] = IGNORE_INDEX
|
| 975 |
+
|
| 976 |
+
# Batch the data
|
| 977 |
+
inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0)
|
| 978 |
+
attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0)
|
| 979 |
+
position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0)
|
| 980 |
+
labels_p = torch.cat(labels_p, dim=0).unsqueeze(0)
|
| 981 |
+
|
| 982 |
+
if hasattr(
|
| 983 |
+
self, "pad_to_multiple_of"
|
| 984 |
+
): # related to quantization, please refer to ModelArguments for more information.
|
| 985 |
+
assert len(labels_p.shape) == 2
|
| 986 |
+
batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1]
|
| 987 |
+
hidden_size = inputs_embeds_p.shape[-1]
|
| 988 |
+
|
| 989 |
+
if max_length % self.pad_to_multiple_of != 0:
|
| 990 |
+
max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of
|
| 991 |
+
difference = max_length - cur_length
|
| 992 |
+
|
| 993 |
+
inputs_embeds_p = torch.cat(
|
| 994 |
+
(
|
| 995 |
+
inputs_embeds_p,
|
| 996 |
+
torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p),
|
| 997 |
+
),
|
| 998 |
+
dim=1,
|
| 999 |
+
)
|
| 1000 |
+
labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1)
|
| 1001 |
+
attention_mask_p = torch.cat(
|
| 1002 |
+
(
|
| 1003 |
+
attention_mask_p,
|
| 1004 |
+
torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p),
|
| 1005 |
+
),
|
| 1006 |
+
dim=1,
|
| 1007 |
+
)
|
| 1008 |
+
position_ids_p = torch.cat(
|
| 1009 |
+
(position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1
|
| 1010 |
+
)
|
| 1011 |
+
|
| 1012 |
+
return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p
|
| 1013 |
+
|
| 1014 |
+
def get_xgr_logits_processor(self, response_format) -> List[LogitsProcessor]:
|
| 1015 |
+
raise NotImplementedError("This method is not implemented for VILA model.")
|
| 1016 |
+
# Convert response format to logits processor
|
| 1017 |
+
import xgrammar as xgr
|
| 1018 |
+
|
| 1019 |
+
logging.info("[XGrammar] Compiling grammar for contrained output")
|
| 1020 |
+
|
| 1021 |
+
if self.grammar_compiler is None:
|
| 1022 |
+
# logging.info(f"[XGrammar] {self.tokenizer}, {self.tokenizer.vocab_size}, {self.vocab_size}")
|
| 1023 |
+
self.grammar_compiler = xgr.GrammarCompiler(
|
| 1024 |
+
xgr.TokenizerInfo.from_huggingface(self.tokenizer, vocab_size=self.vocab_size)
|
| 1025 |
+
)
|
| 1026 |
+
|
| 1027 |
+
if response_format.type == "json_schema":
|
| 1028 |
+
compiled_grammar = self.grammar_compiler.compile_json_schema(
|
| 1029 |
+
response_format.json_schema.schema_,
|
| 1030 |
+
indent=2,
|
| 1031 |
+
)
|
| 1032 |
+
else:
|
| 1033 |
+
compiled_grammar = self.grammar_compiler.compile_builtin_json_grammar()
|
| 1034 |
+
|
| 1035 |
+
return [xgr.contrib.hf.LogitsProcessor(compiled_grammar)]
|
| 1036 |
+
|
| 1037 |
+
def forward(
|
| 1038 |
+
self,
|
| 1039 |
+
input_ids: torch.LongTensor = None,
|
| 1040 |
+
media: Optional[Dict[str, List[torch.Tensor]]] = None,
|
| 1041 |
+
images: Optional[torch.FloatTensor] = None,
|
| 1042 |
+
media_config: Optional[List] = None,
|
| 1043 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1044 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1045 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1046 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1047 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1048 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1049 |
+
packing: bool = True,
|
| 1050 |
+
force_packing: bool = False,
|
| 1051 |
+
seqlens_in_batch: Optional[torch.LongTensor] = None,
|
| 1052 |
+
dpo_forward: bool = False,
|
| 1053 |
+
**kwargs,
|
| 1054 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 1055 |
+
self.freezed_module_patch()
|
| 1056 |
+
|
| 1057 |
+
if images is not None:
|
| 1058 |
+
if media is not None:
|
| 1059 |
+
raise ValueError("Both 'media' and 'images' are provided. Please provide only one.")
|
| 1060 |
+
print("The 'images' argument is deprecated. Please use 'media' instead.")
|
| 1061 |
+
media = {"image": images}
|
| 1062 |
+
|
| 1063 |
+
if media_config is None:
|
| 1064 |
+
media_config = defaultdict(dict)
|
| 1065 |
+
|
| 1066 |
+
if inputs_embeds is None:
|
| 1067 |
+
inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask)
|
| 1068 |
+
|
| 1069 |
+
if force_packing or (packing and self.training and not dpo_forward):
|
| 1070 |
+
if seqlens_in_batch is None:
|
| 1071 |
+
seqlens_in_batch = torch.sum(attention_mask, dim=1)
|
| 1072 |
+
set_seqlens_in_batch(seqlens_in_batch)
|
| 1073 |
+
|
| 1074 |
+
(inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data(
|
| 1075 |
+
inputs_embeds, attention_mask, position_ids, labels
|
| 1076 |
+
)
|
| 1077 |
+
|
| 1078 |
+
outputs = self.llm(
|
| 1079 |
+
inputs_embeds=inputs_embeds,
|
| 1080 |
+
attention_mask=attention_mask,
|
| 1081 |
+
position_ids=position_ids,
|
| 1082 |
+
past_key_values=past_key_values,
|
| 1083 |
+
labels=labels,
|
| 1084 |
+
**kwargs,
|
| 1085 |
+
)
|
| 1086 |
+
|
| 1087 |
+
if self.training and getattr(self.config, "time_token_ids", []):
|
| 1088 |
+
outputs.loss = soft_cross_entropy(
|
| 1089 |
+
outputs.logits,
|
| 1090 |
+
labels,
|
| 1091 |
+
soft_tokens=self.config.time_token_ids,
|
| 1092 |
+
std=self.config.soft_ce_std,
|
| 1093 |
+
)
|
| 1094 |
+
|
| 1095 |
+
if dpo_forward:
|
| 1096 |
+
return outputs.logits, labels
|
| 1097 |
+
|
| 1098 |
+
return outputs
|
| 1099 |
+
|
| 1100 |
+
# @torch.inference_mode()
|
| 1101 |
+
def generate(
|
| 1102 |
+
self,
|
| 1103 |
+
input_ids: Optional[torch.FloatTensor] = None,
|
| 1104 |
+
media: Optional[Dict[str, List[torch.Tensor]]] = None,
|
| 1105 |
+
media_config: Dict[str, Dict[str, Any]] = None,
|
| 1106 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 1107 |
+
return_output_ids_only: bool = True,
|
| 1108 |
+
**generation_kwargs,
|
| 1109 |
+
) -> torch.LongTensor:
|
| 1110 |
+
"""
|
| 1111 |
+
input_tokens: <image> describe the image
|
| 1112 |
+
media: [Tensor(1, 3, 384, 384), ]
|
| 1113 |
+
----------->
|
| 1114 |
+
input_tokens: 36000 001 002 003 004
|
| 1115 |
+
input_emds: <media emd> 001 002 003 004
|
| 1116 |
+
"""
|
| 1117 |
+
inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask)
|
| 1118 |
+
output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
|
| 1119 |
+
|
| 1120 |
+
if return_output_ids_only:
|
| 1121 |
+
return_value = output_ids
|
| 1122 |
+
else:
|
| 1123 |
+
# by default, return the input_ids and output_ids concatenated to keep consistency with the community VLMs like qwen
|
| 1124 |
+
generation_config = generation_kwargs.get("generation_config", None)
|
| 1125 |
+
if generation_config is not None:
|
| 1126 |
+
num_generations = generation_config.num_return_sequences
|
| 1127 |
+
repeat_input_ids = input_ids.repeat_interleave(num_generations, dim=0)
|
| 1128 |
+
return_value = torch.cat([repeat_input_ids, output_ids], dim=-1)
|
| 1129 |
+
else:
|
| 1130 |
+
return_value = torch.cat([input_ids, output_ids], dim=-1)
|
| 1131 |
+
|
| 1132 |
+
return return_value
|
| 1133 |
+
|
| 1134 |
+
@torch.inference_mode()
|
| 1135 |
+
def generate_content(
|
| 1136 |
+
self,
|
| 1137 |
+
prompt: Union[str, List],
|
| 1138 |
+
generation_config: Optional[GenerationConfig] = None,
|
| 1139 |
+
response_format=None,
|
| 1140 |
+
) -> str:
|
| 1141 |
+
conversation = [{"from": "human", "value": prompt}]
|
| 1142 |
+
|
| 1143 |
+
# Convert response format to logits processor
|
| 1144 |
+
xgr_logits_processor = None
|
| 1145 |
+
|
| 1146 |
+
# Extract media from the conversation
|
| 1147 |
+
|
| 1148 |
+
media = extract_media(conversation, self.config)
|
| 1149 |
+
|
| 1150 |
+
# Process media
|
| 1151 |
+
media_config = defaultdict(dict)
|
| 1152 |
+
for name in media:
|
| 1153 |
+
if name == "image":
|
| 1154 |
+
if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
|
| 1155 |
+
self.config.image_processor = self.vision_tower.image_processor
|
| 1156 |
+
if self.config.image_aspect_ratio == "dynamic":
|
| 1157 |
+
images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
|
| 1158 |
+
conversation[0]["value"] = conversation[0]["value"].replace(
|
| 1159 |
+
DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
|
| 1160 |
+
)
|
| 1161 |
+
else:
|
| 1162 |
+
if type(self.config.s2_scales) is str:
|
| 1163 |
+
self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
|
| 1164 |
+
images, block_sizes = process_image(
|
| 1165 |
+
media["image"][0], self.config, None, enable_dynamic_s2=True
|
| 1166 |
+
)
|
| 1167 |
+
images = images.half()
|
| 1168 |
+
media_config[name]["block_sizes"] = [block_sizes]
|
| 1169 |
+
else:
|
| 1170 |
+
images = process_images(media["image"], self.vision_tower.image_processor, self.config).half()
|
| 1171 |
+
media[name] = [image for image in images]
|
| 1172 |
+
elif name == "video":
|
| 1173 |
+
if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1:
|
| 1174 |
+
media[name] = [
|
| 1175 |
+
process_images(
|
| 1176 |
+
images,
|
| 1177 |
+
self.vision_tower.image_processor,
|
| 1178 |
+
self.config,
|
| 1179 |
+
enable_dynamic_res=True,
|
| 1180 |
+
max_tiles=self.config.video_max_tiles,
|
| 1181 |
+
).half()
|
| 1182 |
+
for images in media[name]
|
| 1183 |
+
]
|
| 1184 |
+
elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1:
|
| 1185 |
+
self.config.image_processor = self.vision_tower.image_processor
|
| 1186 |
+
if type(self.config.s2_scales) is str:
|
| 1187 |
+
self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
|
| 1188 |
+
media[name] = [
|
| 1189 |
+
torch.cat(
|
| 1190 |
+
[
|
| 1191 |
+
process_image(
|
| 1192 |
+
image,
|
| 1193 |
+
self.config,
|
| 1194 |
+
None,
|
| 1195 |
+
enable_dynamic_s2=True,
|
| 1196 |
+
max_tiles=self.config.video_max_tiles,
|
| 1197 |
+
)[0].half()
|
| 1198 |
+
for image in images
|
| 1199 |
+
]
|
| 1200 |
+
)
|
| 1201 |
+
for images in media[name]
|
| 1202 |
+
]
|
| 1203 |
+
else:
|
| 1204 |
+
media[name] = [
|
| 1205 |
+
process_images(images, self.vision_tower.image_processor, self.config).half()
|
| 1206 |
+
for images in media[name]
|
| 1207 |
+
]
|
| 1208 |
+
else:
|
| 1209 |
+
raise ValueError(f"Unsupported media type: {name}")
|
| 1210 |
+
|
| 1211 |
+
# Tokenize the conversation
|
| 1212 |
+
input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda()
|
| 1213 |
+
|
| 1214 |
+
# Set up the generation config
|
| 1215 |
+
generation_config = generation_config or self.default_generation_config
|
| 1216 |
+
|
| 1217 |
+
# Generate the response
|
| 1218 |
+
try:
|
| 1219 |
+
output_ids = self.generate(
|
| 1220 |
+
input_ids=input_ids,
|
| 1221 |
+
media=media,
|
| 1222 |
+
media_config=media_config,
|
| 1223 |
+
generation_config=generation_config,
|
| 1224 |
+
logits_processor=xgr_logits_processor, # structured generation
|
| 1225 |
+
)
|
| 1226 |
+
except ValueError:
|
| 1227 |
+
if not generation_config.do_sample:
|
| 1228 |
+
raise
|
| 1229 |
+
logging.warning("Generation failed with sampling, retrying with greedy decoding.")
|
| 1230 |
+
generation_config.do_sample = False
|
| 1231 |
+
output_ids = self.generate(
|
| 1232 |
+
input_ids=input_ids,
|
| 1233 |
+
media=media,
|
| 1234 |
+
media_config=media_config,
|
| 1235 |
+
generation_config=generation_config,
|
| 1236 |
+
logits_processor=xgr_logits_processor,
|
| 1237 |
+
)
|
| 1238 |
+
|
| 1239 |
+
# Decode the response
|
| 1240 |
+
response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
|
| 1241 |
+
return response
|
| 1242 |
+
|
| 1243 |
+
@property
|
| 1244 |
+
def default_generation_config(self) -> GenerationConfig:
|
| 1245 |
+
generation_config = copy.deepcopy(self.generation_config or GenerationConfig())
|
| 1246 |
+
if self.tokenizer.eos_token_id is None:
|
| 1247 |
+
raise ValueError("Tokenizer must have an EOS token")
|
| 1248 |
+
if generation_config.max_length == GenerationConfig().max_length:
|
| 1249 |
+
generation_config.max_length = self.tokenizer.model_max_length
|
| 1250 |
+
if generation_config.pad_token_id is None:
|
| 1251 |
+
generation_config.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
|
| 1252 |
+
if generation_config.bos_token_id is None:
|
| 1253 |
+
generation_config.bos_token_id = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
|
| 1254 |
+
if generation_config.eos_token_id is None:
|
| 1255 |
+
generation_config.eos_token_id = self.tokenizer.eos_token_id
|
| 1256 |
+
return generation_config
|
siglip_encoder.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from accelerate.hooks import add_hook_to_module
|
| 21 |
+
from einops import rearrange
|
| 22 |
+
from s2wrapper import forward as multiscale_forward
|
| 23 |
+
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, SiglipImageProcessor
|
| 24 |
+
from transformers.image_processing_utils import BaseImageProcessor
|
| 25 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 26 |
+
from transformers.models.siglip import SiglipVisionModel
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class VisionTower(nn.Module):
|
| 30 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
| 31 |
+
super().__init__()
|
| 32 |
+
|
| 33 |
+
self.is_loaded = False
|
| 34 |
+
|
| 35 |
+
self.vision_tower_name = vision_tower
|
| 36 |
+
self.select_layer = getattr(args, "mm_vision_select_layer", -2)
|
| 37 |
+
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
|
| 38 |
+
|
| 39 |
+
self.cfg_only = None
|
| 40 |
+
|
| 41 |
+
def feature_select(self, image_forward_outs):
|
| 42 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
| 43 |
+
if self.select_feature == "patch":
|
| 44 |
+
image_features = image_features[:, 1:]
|
| 45 |
+
elif self.select_feature == "cls_patch":
|
| 46 |
+
image_features = image_features
|
| 47 |
+
else:
|
| 48 |
+
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
| 49 |
+
return image_features
|
| 50 |
+
|
| 51 |
+
def _maybe_resize_pos_embeds(
|
| 52 |
+
self,
|
| 53 |
+
model: PreTrainedModel,
|
| 54 |
+
image_processor: BaseImageProcessor,
|
| 55 |
+
resolution: int = -1,
|
| 56 |
+
interpolate_mode: str = "linear",
|
| 57 |
+
):
|
| 58 |
+
if resolution in [model.config.image_size, -1]:
|
| 59 |
+
return
|
| 60 |
+
print(
|
| 61 |
+
f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
|
| 62 |
+
)
|
| 63 |
+
embeddings = model.vision_model.embeddings
|
| 64 |
+
patch_size = embeddings.patch_size
|
| 65 |
+
num_new_tokens = int((resolution // patch_size) ** 2)
|
| 66 |
+
|
| 67 |
+
old_embeddings = embeddings.position_embedding
|
| 68 |
+
match interpolate_mode:
|
| 69 |
+
case "linear":
|
| 70 |
+
## Step 1: Calculate the corresponding patch ID (pid) in the current resolution (M patches) based on the target resolution (N patches). Formula: pid = pid / N * M
|
| 71 |
+
## Step 2: Obtain new embeddings by interpolating between the embeddings of the two nearest calculated patch IDs. Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)] + (ceil(pid) - pid) * embeds[floor(pid)]
|
| 72 |
+
import torch
|
| 73 |
+
import torch.nn as nn
|
| 74 |
+
|
| 75 |
+
if is_deepspeed_zero3_enabled():
|
| 76 |
+
try:
|
| 77 |
+
import deepspeed
|
| 78 |
+
except ImportError:
|
| 79 |
+
raise ImportError("DeepSpeed is not installed. Please install it with `pip install deepspeed`.")
|
| 80 |
+
with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
|
| 81 |
+
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
| 82 |
+
else:
|
| 83 |
+
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
| 84 |
+
new_embeddings = nn.Embedding(
|
| 85 |
+
num_new_tokens,
|
| 86 |
+
old_embedding_dim,
|
| 87 |
+
dtype=old_embeddings.weight.dtype,
|
| 88 |
+
device=old_embeddings.weight.device,
|
| 89 |
+
)
|
| 90 |
+
mapped_indices = (
|
| 91 |
+
torch.arange(num_new_tokens).to(old_embeddings.weight.device)
|
| 92 |
+
/ (num_new_tokens - 1)
|
| 93 |
+
* (old_num_tokens - 1)
|
| 94 |
+
)
|
| 95 |
+
floor_indices = torch.clamp(mapped_indices.floor().long(), min=0, max=old_num_tokens - 1)
|
| 96 |
+
ceil_indices = torch.clamp(mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1)
|
| 97 |
+
if is_deepspeed_zero3_enabled():
|
| 98 |
+
params = [old_embeddings.weight, new_embeddings.weight]
|
| 99 |
+
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
|
| 100 |
+
interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
|
| 101 |
+
ceil_indices, :
|
| 102 |
+
] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
|
| 103 |
+
else:
|
| 104 |
+
interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
|
| 105 |
+
ceil_indices, :
|
| 106 |
+
] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
|
| 107 |
+
new_embeddings.weight.data = interpolated_embeds
|
| 108 |
+
case _:
|
| 109 |
+
raise NotImplementedError
|
| 110 |
+
|
| 111 |
+
if hasattr(old_embeddings, "_hf_hook"):
|
| 112 |
+
hook = old_embeddings._hf_hook
|
| 113 |
+
add_hook_to_module(new_embeddings, hook)
|
| 114 |
+
new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
|
| 115 |
+
## update vision encoder's configurations
|
| 116 |
+
model.config.image_size = resolution
|
| 117 |
+
if hasattr(image_processor, "crop_size"):
|
| 118 |
+
# CLIP vision tower
|
| 119 |
+
image_processor.crop_size = resolution
|
| 120 |
+
else:
|
| 121 |
+
# SIGLIP vision tower
|
| 122 |
+
assert hasattr(image_processor, "size")
|
| 123 |
+
image_processor.size = {"height": resolution, "width": resolution}
|
| 124 |
+
embeddings.position_embedding = new_embeddings
|
| 125 |
+
embeddings.image_size = resolution
|
| 126 |
+
embeddings.num_patches = embeddings.num_positions = num_new_tokens
|
| 127 |
+
embeddings.position_ids = (
|
| 128 |
+
torch.arange(embeddings.num_positions).expand((1, -1)).to(old_embeddings.weight.device)
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
def forward(self, images):
|
| 132 |
+
if type(images) is list:
|
| 133 |
+
image_features = []
|
| 134 |
+
for image in images:
|
| 135 |
+
image_forward_out = self.vision_tower(
|
| 136 |
+
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
|
| 137 |
+
output_hidden_states=True,
|
| 138 |
+
)
|
| 139 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
| 140 |
+
image_features.append(image_feature)
|
| 141 |
+
else:
|
| 142 |
+
image_forward_outs = self.vision_tower(
|
| 143 |
+
images.to(device=self.device, dtype=self.dtype),
|
| 144 |
+
output_hidden_states=True,
|
| 145 |
+
)
|
| 146 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
| 147 |
+
|
| 148 |
+
return image_features
|
| 149 |
+
|
| 150 |
+
@property
|
| 151 |
+
def dummy_feature(self):
|
| 152 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
| 153 |
+
|
| 154 |
+
@property
|
| 155 |
+
def dtype(self):
|
| 156 |
+
return self.vision_tower.dtype
|
| 157 |
+
|
| 158 |
+
@property
|
| 159 |
+
def device(self):
|
| 160 |
+
return self.vision_tower.device
|
| 161 |
+
|
| 162 |
+
@property
|
| 163 |
+
def config(self):
|
| 164 |
+
if self.is_loaded:
|
| 165 |
+
return self.vision_tower.config
|
| 166 |
+
else:
|
| 167 |
+
return self.cfg_only
|
| 168 |
+
|
| 169 |
+
@property
|
| 170 |
+
def hidden_size(self):
|
| 171 |
+
return self.config.hidden_size
|
| 172 |
+
|
| 173 |
+
@property
|
| 174 |
+
def num_patches(self):
|
| 175 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class VisionTowerS2(VisionTower):
|
| 179 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
| 180 |
+
super().__init__(vision_tower, args, delay_load)
|
| 181 |
+
|
| 182 |
+
self.scales = list(map(int, args.s2_scales.split(",")))
|
| 183 |
+
self.scales.sort()
|
| 184 |
+
self.max_split_size = args.s2_max_split_size
|
| 185 |
+
self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
|
| 186 |
+
|
| 187 |
+
def forward_feature(self, images):
|
| 188 |
+
image_forward_outs = self.vision_tower(
|
| 189 |
+
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
|
| 190 |
+
)
|
| 191 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
| 192 |
+
return image_features
|
| 193 |
+
|
| 194 |
+
def forward(self, images):
|
| 195 |
+
if type(images) is list:
|
| 196 |
+
image_features = []
|
| 197 |
+
for image in images:
|
| 198 |
+
image_feature = multiscale_forward(
|
| 199 |
+
self.forward_feature,
|
| 200 |
+
image.unsqueeze(0),
|
| 201 |
+
img_sizes=self.scales,
|
| 202 |
+
max_split_size=self.max_split_size,
|
| 203 |
+
resize_output_to_idx=self.resize_output_to_scale_idx,
|
| 204 |
+
)
|
| 205 |
+
image_features.append(image_feature)
|
| 206 |
+
else:
|
| 207 |
+
image_features = multiscale_forward(
|
| 208 |
+
self.forward_feature,
|
| 209 |
+
images,
|
| 210 |
+
img_sizes=self.scales,
|
| 211 |
+
max_split_size=self.max_split_size,
|
| 212 |
+
resize_output_to_idx=self.resize_output_to_scale_idx,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
return image_features
|
| 216 |
+
|
| 217 |
+
@property
|
| 218 |
+
def hidden_size(self):
|
| 219 |
+
return self.config.hidden_size * len(self.scales)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class VisionTowerDynamicS2(VisionTower):
|
| 223 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
| 224 |
+
super().__init__(vision_tower, args, delay_load)
|
| 225 |
+
|
| 226 |
+
self.scales = list(map(int, args.s2_scales.split(",")))
|
| 227 |
+
self.scales.sort()
|
| 228 |
+
self.max_split_size = args.s2_max_split_size
|
| 229 |
+
self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)
|
| 230 |
+
|
| 231 |
+
def forward_feature(self, images):
|
| 232 |
+
image_forward_outs = self.vision_tower(
|
| 233 |
+
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
|
| 234 |
+
)
|
| 235 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
| 236 |
+
return image_features
|
| 237 |
+
|
| 238 |
+
def forward(self, images):
|
| 239 |
+
assert type(images) is not list
|
| 240 |
+
image_features = self.forward_feature(images)
|
| 241 |
+
|
| 242 |
+
return image_features
|
| 243 |
+
|
| 244 |
+
@property
|
| 245 |
+
def hidden_size(self):
|
| 246 |
+
return self.config.hidden_size * len(self.scales)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class SiglipVisionTower(VisionTower):
|
| 250 |
+
def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
|
| 251 |
+
super().__init__(model_name_or_path, config)
|
| 252 |
+
self.vision_tower = SiglipVisionModel.from_pretrained(
|
| 253 |
+
model_name_or_path,
|
| 254 |
+
attn_implementation=config._attn_implementation,
|
| 255 |
+
torch_dtype=eval(config.model_dtype),
|
| 256 |
+
)
|
| 257 |
+
self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
|
| 258 |
+
self.is_loaded = True
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class SiglipVisionTowerS2(VisionTowerS2):
|
| 262 |
+
def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
|
| 263 |
+
super().__init__(model_name_or_path, config)
|
| 264 |
+
self.vision_tower = SiglipVisionModel.from_pretrained(
|
| 265 |
+
model_name_or_path,
|
| 266 |
+
attn_implementation=config._attn_implementation,
|
| 267 |
+
torch_dtype=eval(config.model_dtype),
|
| 268 |
+
)
|
| 269 |
+
self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
|
| 270 |
+
# Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
|
| 271 |
+
self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[-1]
|
| 272 |
+
self.is_loaded = True
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class SiglipVisionTowerDynamicS2(VisionTowerDynamicS2):
|
| 276 |
+
def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
|
| 277 |
+
super().__init__(model_name_or_path, config)
|
| 278 |
+
self.vision_tower = SiglipVisionModel.from_pretrained(
|
| 279 |
+
model_name_or_path,
|
| 280 |
+
attn_implementation=config._attn_implementation,
|
| 281 |
+
torch_dtype=eval(config.model_dtype),
|
| 282 |
+
)
|
| 283 |
+
self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
|
| 284 |
+
# Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
|
| 285 |
+
self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[0]
|
| 286 |
+
self.is_loaded = True
|
tokenizer_utils.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
from typing import Any, Dict, List, Optional, Sequence
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import transformers
|
| 21 |
+
|
| 22 |
+
from .constants import IGNORE_INDEX, SENTINEL_TOKEN
|
| 23 |
+
from .conversation import SeparatorStyle, default_conversation
|
| 24 |
+
from .mm_utils import tokenizer_image_token
|
| 25 |
+
|
| 26 |
+
DUMMY_CONVERSATION = [
|
| 27 |
+
{"from": "human", "value": "question"},
|
| 28 |
+
{"from": "gpt", "value": "answer"},
|
| 29 |
+
] * 10
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def tokenize_conversation_legacy(
|
| 33 |
+
messages: Sequence[Dict[str, str]],
|
| 34 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 35 |
+
add_generation_prompt: bool = False,
|
| 36 |
+
overrides: Optional[Dict[str, str]] = None,
|
| 37 |
+
no_system_prompt: bool = False,
|
| 38 |
+
) -> torch.Tensor:
|
| 39 |
+
conv = default_conversation.copy()
|
| 40 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
| 41 |
+
|
| 42 |
+
if no_system_prompt:
|
| 43 |
+
conv.system = ""
|
| 44 |
+
|
| 45 |
+
# Skip the first message if it is not from human
|
| 46 |
+
if messages[0]["from"] != "human":
|
| 47 |
+
messages = messages[1:]
|
| 48 |
+
|
| 49 |
+
# Add a generation prompt if needed
|
| 50 |
+
if add_generation_prompt:
|
| 51 |
+
messages.append({"from": "gpt", "value": None})
|
| 52 |
+
|
| 53 |
+
conv.messages = []
|
| 54 |
+
for turn, message in enumerate(messages):
|
| 55 |
+
role = roles[message["from"]]
|
| 56 |
+
assert role == conv.roles[turn % 2]
|
| 57 |
+
if overrides is not None and message["from"] in overrides:
|
| 58 |
+
conv.append_message(role, overrides[message["from"]])
|
| 59 |
+
else:
|
| 60 |
+
conv.append_message(role, message["value"])
|
| 61 |
+
|
| 62 |
+
return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def tokenize_conversation(
|
| 66 |
+
messages: Sequence[Dict[str, str]],
|
| 67 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 68 |
+
add_generation_prompt: bool = False,
|
| 69 |
+
overrides: Optional[Dict[str, str]] = None,
|
| 70 |
+
no_system_prompt: bool = False,
|
| 71 |
+
return_ids_only=True,
|
| 72 |
+
) -> torch.Tensor:
|
| 73 |
+
# Normalize the conversation before tokenization
|
| 74 |
+
for message in messages:
|
| 75 |
+
message["value"] = message["value"].strip()
|
| 76 |
+
|
| 77 |
+
if default_conversation.sep_style != SeparatorStyle.AUTO:
|
| 78 |
+
return tokenize_conversation_legacy(
|
| 79 |
+
messages,
|
| 80 |
+
tokenizer,
|
| 81 |
+
add_generation_prompt=add_generation_prompt,
|
| 82 |
+
overrides=overrides,
|
| 83 |
+
no_system_prompt=no_system_prompt,
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
conversation = []
|
| 87 |
+
for m in messages:
|
| 88 |
+
message = {}
|
| 89 |
+
if m["from"] == "human":
|
| 90 |
+
message["role"] = "user"
|
| 91 |
+
elif m["from"] == "gpt":
|
| 92 |
+
message["role"] = "assistant"
|
| 93 |
+
elif m["from"] == "system":
|
| 94 |
+
message["role"] = "system"
|
| 95 |
+
if no_system_prompt:
|
| 96 |
+
raise ValueError("message[role]=system is not allowed when no_system_prompt is set to True.")
|
| 97 |
+
else:
|
| 98 |
+
raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")
|
| 99 |
+
|
| 100 |
+
message["content"] = m["value"]
|
| 101 |
+
if overrides is not None and m["from"] in overrides:
|
| 102 |
+
message["content"] = overrides[m["from"]]
|
| 103 |
+
conversation.append(message)
|
| 104 |
+
|
| 105 |
+
if no_system_prompt:
|
| 106 |
+
conversation = [{"role": "system", "content": ""}] + conversation
|
| 107 |
+
|
| 108 |
+
text = tokenizer.apply_chat_template(
|
| 109 |
+
conversation,
|
| 110 |
+
add_generation_prompt=add_generation_prompt,
|
| 111 |
+
tokenize=False,
|
| 112 |
+
)
|
| 113 |
+
return tokenizer_image_token(text, tokenizer, return_tensors="pt", return_ids=return_ids_only)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
|
| 117 |
+
if not hasattr(tokenizer, "sentinel_token"):
|
| 118 |
+
tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
|
| 119 |
+
tokenizer.sentinel_token = SENTINEL_TOKEN
|
| 120 |
+
tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def preprocess_conversation(
|
| 124 |
+
conversation: Sequence[Dict[str, str]],
|
| 125 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
| 126 |
+
no_system_prompt: bool = False,
|
| 127 |
+
retried: bool = False,
|
| 128 |
+
**kwargs: Any,
|
| 129 |
+
) -> Dict[str, Any]:
|
| 130 |
+
inputs = tokenize_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt)
|
| 131 |
+
labels = torch.ones_like(inputs) * IGNORE_INDEX
|
| 132 |
+
|
| 133 |
+
# Generate the template by replacing the assistant's response with a sentinel.
|
| 134 |
+
_maybe_add_sentinel_token(tokenizer)
|
| 135 |
+
template = tokenize_conversation(
|
| 136 |
+
conversation, tokenizer, overrides={"gpt": SENTINEL_TOKEN}, no_system_prompt=no_system_prompt
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# Remove sentinel tokens from the template.
|
| 140 |
+
mask = torch.ones_like(template, dtype=torch.bool)
|
| 141 |
+
for k in range(template.size(0) - 1):
|
| 142 |
+
if template[k] == tokenizer.sentinel_token_id:
|
| 143 |
+
mask[k : k + 2] = False
|
| 144 |
+
if k > 0 and retried:
|
| 145 |
+
mask[k - 1] = False
|
| 146 |
+
template = template[mask]
|
| 147 |
+
|
| 148 |
+
# Match the tokenized conversation with the template (with no assistant's response).
|
| 149 |
+
# Every token that is not matched will be included in the label for training.
|
| 150 |
+
p = 0
|
| 151 |
+
for k in range(inputs.size(0)):
|
| 152 |
+
if p < template.size(0) and inputs[k] == template[p]:
|
| 153 |
+
p += 1
|
| 154 |
+
else:
|
| 155 |
+
labels[k] = inputs[k]
|
| 156 |
+
|
| 157 |
+
# Mask all tokens in the label if the template is not fully matched.
|
| 158 |
+
if p < template.size(0):
|
| 159 |
+
if not retried:
|
| 160 |
+
return preprocess_conversation(
|
| 161 |
+
conversation,
|
| 162 |
+
tokenizer,
|
| 163 |
+
no_system_prompt=no_system_prompt,
|
| 164 |
+
retried=True,
|
| 165 |
+
)
|
| 166 |
+
print(f"Failed to process the conversation: '{conversation}'. All tokens will be masked in the label.")
|
| 167 |
+
labels[:] = IGNORE_INDEX
|
| 168 |
+
|
| 169 |
+
return {"input_ids": inputs, "labels": labels}
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
|
| 173 |
+
_maybe_add_sentinel_token(tokenizer)
|
| 174 |
+
template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})
|
| 175 |
+
|
| 176 |
+
stop_tokens = {tokenizer.eos_token}
|
| 177 |
+
for k in range(template.size(0) - 1):
|
| 178 |
+
if template[k] == tokenizer.sentinel_token_id:
|
| 179 |
+
stop_token = tokenizer.decode(template[k + 1])
|
| 180 |
+
stop_tokens.add(stop_token)
|
| 181 |
+
return list(stop_tokens)
|
utils.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 17 |
+
import os
|
| 18 |
+
import os.path as osp
|
| 19 |
+
|
| 20 |
+
from huggingface_hub import repo_exists, snapshot_download
|
| 21 |
+
from huggingface_hub.utils import HFValidationError, validate_repo_id
|
| 22 |
+
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig
|
| 23 |
+
|
| 24 |
+
from .configuration_vila import VILAConfig
|
| 25 |
+
from .constants import MEDIA_TOKENS
|
| 26 |
+
from .tokenizer_utils import infer_stop_tokens
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_tokenizer_then_handle_media_tokens_and_chat_template(
|
| 30 |
+
model_name_or_path, config: VILAConfig, model_max_length=None
|
| 31 |
+
):
|
| 32 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 33 |
+
osp.join(model_name_or_path, "llm"), padding_side="right", use_fast=True, legacy=False
|
| 34 |
+
)
|
| 35 |
+
if model_max_length is not None:
|
| 36 |
+
tokenizer.model_max_length = model_max_length
|
| 37 |
+
|
| 38 |
+
# Load chat template if specified.
|
| 39 |
+
if getattr(config, "chat_template", None) is not None:
|
| 40 |
+
print(f"Using chat template: {config.chat_template}")
|
| 41 |
+
fpath = os.path.join(os.path.dirname(__file__), "chat_templates", f"{config.chat_template}.jinja")
|
| 42 |
+
if not os.path.exists(fpath):
|
| 43 |
+
fpath = os.path.join(model_name_or_path, f"{config.chat_template}.jinja")
|
| 44 |
+
with open(fpath) as fd:
|
| 45 |
+
chat_template = fd.read()
|
| 46 |
+
tokenizer.chat_template = chat_template.replace(" ", "").replace("\n", "")
|
| 47 |
+
|
| 48 |
+
# Set stop tokens for the tokenizer
|
| 49 |
+
tokenizer.stop_tokens = infer_stop_tokens(tokenizer)
|
| 50 |
+
tokenizer.stop_token_ids = tokenizer.convert_tokens_to_ids(tokenizer.stop_tokens)
|
| 51 |
+
|
| 52 |
+
# Add media tokens to the tokenizer
|
| 53 |
+
tokenizer.media_tokens = MEDIA_TOKENS
|
| 54 |
+
tokenizer.media_token_ids = {}
|
| 55 |
+
for name, token in MEDIA_TOKENS.items():
|
| 56 |
+
tokenizer.add_tokens([token], special_tokens=True)
|
| 57 |
+
tokenizer.media_token_ids[name] = tokenizer.convert_tokens_to_ids(token)
|
| 58 |
+
|
| 59 |
+
return tokenizer
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_model_config(config):
|
| 63 |
+
default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]
|
| 64 |
+
|
| 65 |
+
if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
|
| 66 |
+
root_path = config._name_or_path
|
| 67 |
+
else:
|
| 68 |
+
root_path = config.resume_path
|
| 69 |
+
|
| 70 |
+
# download from huggingface
|
| 71 |
+
if root_path is not None and not osp.exists(root_path):
|
| 72 |
+
try:
|
| 73 |
+
valid_hf_repo = repo_exists(root_path)
|
| 74 |
+
except HFValidationError as e:
|
| 75 |
+
valid_hf_repo = False
|
| 76 |
+
if valid_hf_repo:
|
| 77 |
+
root_path = snapshot_download(root_path)
|
| 78 |
+
|
| 79 |
+
return_list = []
|
| 80 |
+
for key in default_keys:
|
| 81 |
+
cfg = getattr(config, key, None)
|
| 82 |
+
if isinstance(cfg, dict):
|
| 83 |
+
try:
|
| 84 |
+
return_list.append(os.path.join(root_path, key[:-4]))
|
| 85 |
+
except:
|
| 86 |
+
raise ValueError(f"Cannot find resume path in config for {key}!")
|
| 87 |
+
elif isinstance(cfg, PretrainedConfig):
|
| 88 |
+
return_list.append(os.path.join(root_path, key[:-4]))
|
| 89 |
+
elif isinstance(cfg, str):
|
| 90 |
+
return_list.append(cfg)
|
| 91 |
+
|
| 92 |
+
return return_list
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def get_model_config_fp8(config):
|
| 96 |
+
default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]
|
| 97 |
+
|
| 98 |
+
if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
|
| 99 |
+
root_path = config._name_or_path
|
| 100 |
+
else:
|
| 101 |
+
root_path = config.resume_path
|
| 102 |
+
|
| 103 |
+
# download from huggingface
|
| 104 |
+
if root_path is not None and not osp.exists(root_path):
|
| 105 |
+
try:
|
| 106 |
+
valid_hf_repo = repo_exists(root_path)
|
| 107 |
+
except HFValidationError as e:
|
| 108 |
+
valid_hf_repo = False
|
| 109 |
+
if valid_hf_repo:
|
| 110 |
+
root_path = snapshot_download(root_path)
|
| 111 |
+
|
| 112 |
+
return_list = []
|
| 113 |
+
for key in default_keys:
|
| 114 |
+
cfg = getattr(config, key, None)
|
| 115 |
+
if isinstance(cfg, dict):
|
| 116 |
+
try:
|
| 117 |
+
return_list.append(os.path.join(root_path, key[:-4]))
|
| 118 |
+
except:
|
| 119 |
+
raise ValueError(f"Cannot find resume path in config for {key}!")
|
| 120 |
+
elif isinstance(cfg, PretrainedConfig):
|
| 121 |
+
return_list.append(os.path.join(root_path, key[:-4]))
|
| 122 |
+
elif isinstance(cfg, str):
|
| 123 |
+
return_list.append(cfg)
|
| 124 |
+
|
| 125 |
+
# fp8_llm
|
| 126 |
+
key = "fp8_llm_cfg"
|
| 127 |
+
directory_path = os.path.join(root_path, key[:-4])
|
| 128 |
+
assert os.path.isdir(directory_path) and os.listdir(
|
| 129 |
+
directory_path
|
| 130 |
+
), "You need to first convert the model weights to FP8 explicitly."
|
| 131 |
+
return_list.append(directory_path)
|
| 132 |
+
|
| 133 |
+
return return_list
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def get_model_config_fp8(config):
|
| 137 |
+
default_keys = ["llm_cfg", "vision_tower_cfg", "mm_projector_cfg"]
|
| 138 |
+
|
| 139 |
+
if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
|
| 140 |
+
root_path = config._name_or_path
|
| 141 |
+
else:
|
| 142 |
+
root_path = config.resume_path
|
| 143 |
+
|
| 144 |
+
# download from huggingface
|
| 145 |
+
if root_path is not None and not osp.exists(root_path):
|
| 146 |
+
try:
|
| 147 |
+
valid_hf_repo = repo_exists(root_path)
|
| 148 |
+
except HFValidationError as e:
|
| 149 |
+
valid_hf_repo = False
|
| 150 |
+
if valid_hf_repo:
|
| 151 |
+
root_path = snapshot_download(root_path)
|
| 152 |
+
|
| 153 |
+
return_list = []
|
| 154 |
+
for key in default_keys:
|
| 155 |
+
cfg = getattr(config, key, None)
|
| 156 |
+
if isinstance(cfg, dict):
|
| 157 |
+
try:
|
| 158 |
+
return_list.append(os.path.join(root_path, key[:-4]))
|
| 159 |
+
except:
|
| 160 |
+
raise ValueError(f"Cannot find resume path in config for {key}!")
|
| 161 |
+
elif isinstance(cfg, PretrainedConfig):
|
| 162 |
+
return_list.append(os.path.join(root_path, key[:-4]))
|
| 163 |
+
elif isinstance(cfg, str):
|
| 164 |
+
return_list.append(cfg)
|
| 165 |
+
|
| 166 |
+
# fp8_llm
|
| 167 |
+
key = "fp8_llm_cfg"
|
| 168 |
+
directory_path = os.path.join(root_path, key[:-4])
|
| 169 |
+
assert os.path.isdir(directory_path) and os.listdir(
|
| 170 |
+
directory_path
|
| 171 |
+
), "You need to first convert the model weights to FP8 explicitly."
|
| 172 |
+
return_list.append(directory_path)
|
| 173 |
+
|
| 174 |
+
return return_list
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def is_mm_model(model_path):
|
| 178 |
+
"""
|
| 179 |
+
Check if the model at the given path is a visual language model.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
model_path (str): The path to the model.
|
| 183 |
+
|
| 184 |
+
Returns:
|
| 185 |
+
bool: True if the model is an MM model, False otherwise.
|
| 186 |
+
"""
|
| 187 |
+
config = AutoConfig.from_pretrained(model_path)
|
| 188 |
+
architectures = config.architectures
|
| 189 |
+
for architecture in architectures:
|
| 190 |
+
if "llava" in architecture.lower():
|
| 191 |
+
return True
|
| 192 |
+
return False
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def auto_upgrade(config):
|
| 196 |
+
cfg = AutoConfig.from_pretrained(config)
|
| 197 |
+
if "llava" in config and "llava" not in cfg.model_type:
|
| 198 |
+
assert cfg.model_type == "llama"
|
| 199 |
+
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
|
| 200 |
+
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
|
| 201 |
+
confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
|
| 202 |
+
if confirm.lower() in ["y", "yes"]:
|
| 203 |
+
print("Upgrading checkpoint...")
|
| 204 |
+
assert len(cfg.architectures) == 1
|
| 205 |
+
setattr(cfg.__class__, "model_type", "llava")
|
| 206 |
+
cfg.architectures[0] = "LlavaLlamaForCausalLM"
|
| 207 |
+
cfg.save_pretrained(config)
|
| 208 |
+
print("Checkpoint upgraded.")
|
| 209 |
+
else:
|
| 210 |
+
print("Checkpoint upgrade aborted.")
|
| 211 |
+
exit(1)
|