Add files using upload-large-folder tool
Browse files- .gitignore +1 -0
- LICENSE +34 -0
- README.md +108 -6
- README_EN.md +111 -0
- checklist.chk +99 -0
- config.json +40 -0
- configuration_openpangu_moe.py +82 -0
- doc/docker.md +31 -0
- doc/docker_EN.md +31 -0
- doc/vllm_ascend_for_openpangu_ultra_moe_718b.md +215 -0
- doc/vllm_ascend_for_openpangu_ultra_moe_718b_EN.md +216 -0
- generation_config.json +11 -0
- inference/generate.py +106 -0
- inference/generate.sh +70 -0
- inference/model.py +918 -0
- inference/runner.py +411 -0
- inference/runner_config/tp1.yaml +30 -0
- inference/runner_config/tp32.yaml +30 -0
- inference/split_weight.py +387 -0
- inference/split_weight.sh +13 -0
- inference/vllm_ascend/_build_info.py +3 -0
- inference/vllm_ascend/attention/attention.py +1220 -0
- inference/vllm_ascend/attention/mla_v1.py +1224 -0
- inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py +6 -0
- inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py +171 -0
- inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py +6 -0
- inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py +300 -0
- inference/vllm_ascend/envs.py +153 -0
- inference/vllm_ascend/models/__init__.py +68 -0
- inference/vllm_ascend/models/open_pangu.py +1127 -0
- inference/vllm_ascend/ops/fused_moe.py +1530 -0
- inference/vllm_ascend/patch/worker/patch_common/__init__.py +27 -0
- inference/vllm_ascend/patch/worker/patch_common/patch_config.py +97 -0
- inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py +26 -0
- inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py +159 -0
- inference/vllm_ascend/quantization/w8a8.py +757 -0
- inference/vllm_ascend/quantization/w8a8_dynamic.py +831 -0
- inference/vllm_ascend/utils.py +563 -0
- inference/vllm_ascend/worker/model_runner_v1.py +0 -0
- inference/vllm_ascend/worker/npu_input_batch.py +796 -0
- model-00002-of-000062.safetensors +3 -0
- model-00003-of-000062.safetensors +3 -0
- model-00005-of-000062.safetensors +3 -0
- model-00045-of-000062.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_openpangu_moe.py +653 -0
- special_tokens_map.json +30 -0
- tokenization_openpangu.py +273 -0
- tokenizer_config.json +1 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.DS_Store
|
LICENSE
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0
|
| 2 |
+
|
| 3 |
+
This OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0 (the "Agreement") is a legal agreement between You and Huawei Technologies Co., Ltd. ("Huawei", "We" or "Us"), and it governs Your reproducing, use, modification, and distribution of openPangu as made available by Huawei under this Agreement.
|
| 4 |
+
|
| 5 |
+
By using, reproducing, modifying, distributing, performing or displaying any portion or element of openPangu, or otherwise accepting the terms of this Agreement, You agree to be bound by this Agreement.
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
1.1. “openPangu” or “Model” means openPangu large language models and software, including trained model weights, parameters (including optimizer states), accompanying source code and scripts released under this Agreement.
|
| 9 |
+
1.2. “Derivative Model” means all (1) modifications to the Model, (2) works based on the Model, and (3) any other derivative works of the Model. For clarity, information or content results from operating or otherwise using the Model is not a Derivative Model.
|
| 10 |
+
1.3. “You” or “Your” means an individual or Legal Entity exercising permissions granted by this Agreement and/or using the Model for any purpose.
|
| 11 |
+
1.4. “Third Party” or “Third Parties” means individuals or legal entities that are not under common control with Us or You.
|
| 12 |
+
|
| 13 |
+
2. License Grant. Subject to Your full compliance with the terms and conditions of this Agreement, We hereby grant to You a perpetual, worldwide, non-exclusive, non-transferable, no-charge, royalty-free license (except as stated in Section 3) to use, reproduce, modify, and distribute the Model.
|
| 14 |
+
|
| 15 |
+
3. Conditions for License Grant. You represent and warrant that You will not, access, download, install, run, deploy, integrate, modify, or otherwise use the Model, directly or indirectly, within the European Union.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
4. Redistribution.
|
| 19 |
+
4.1. If You distribute the Model or Derivative Model, You shall retain in Your distribution (1) a copy of this agreement, and (2) all copyright notices and other notices of origin included in the Model that are applicable to Your distribution.
|
| 20 |
+
4.2. Further, if You distribute or make available to Third Parties a product or service (including another AI model) based on the Model, You are required to (1) display the acknowledgement “Powered by openPangu” and (2) include a trademark notice “openPangu is a trademark of Huawei Technologies Co., Ltd.” on related webpages, user manuals, product documentations or other advertising materials mentioning features of the Model.
|
| 21 |
+
4.3. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for Derivative Model made by You as a whole, provided Your use, reproduction, and distribution of the Model otherwise complies with the terms and conditions of this Agreement.
|
| 22 |
+
|
| 23 |
+
5. Ownership. We do not claim ownership to any information or content generated using the Model or Derivative Model that are made by You. You are solely responsible for evaluating the accuracy and appropriateness of such information or content for Your use case.
|
| 24 |
+
|
| 25 |
+
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of Huawei, except as required for complying with Section 4.2.
|
| 26 |
+
|
| 27 |
+
7. Indemnity. You will indemnify and hold harmless Huawei from and against any claim by any third party arising out of or related to Your use or distribution of the Model or Derivative Model made by You (e.g. a violation against Section 3). For avoidance of doubt, “third party” in this clause include supervisory authorities.
|
| 28 |
+
|
| 29 |
+
8. THE MODEL IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, NONINFRINGEMENT, ACCURACY, OR THE ABSENCE OF LATENT OR OTHER DEFECTS OR ERRORS, WHETHER OR NOT DISCOVERABLE, ALL TO THE GREATEST EXTENT PERMISSIBLE UNDER APPLICABLE LAW.
|
| 30 |
+
|
| 31 |
+
9. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MODEL, IN WHOLE OR IN PART, NO MATTER HOW IT’S CAUSED OR THE LEGAL THEORY IT IS BASED ON, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
END OF THE TERMS AND CONDITIONS
|
README.md
CHANGED
|
@@ -1,6 +1,108 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 开源盘古 Ultra-MoE-718B
|
| 2 |
+
中文 | [English](README_EN.md)
|
| 3 |
+
|
| 4 |
+
## 1. 简介
|
| 5 |
+
openPangu-Ultra-MoE-718B 是基于昇腾NPU从零训练的大规模混合专家语言模型,总参数量为718B,激活参数量为39B。openPangu-Ultra-MoE-718B 训练了约19T tokens,具备快慢思考融合能力。
|
| 6 |
+
|
| 7 |
+
## 2. 模型架构
|
| 8 |
+
openPangu-Ultra-MoE-718B 的模型架构采用了业界主流的Multi-head Latent Attention (MLA)、Multi-Token Prediction (MTP)、大稀疏比等架构,以及一些特有的设计:
|
| 9 |
+
|
| 10 |
+
- Depth-Scaled Sandwich-Norm和TinyInit:通过调整层归一化结构与参数初始化,提升训练稳定性。
|
| 11 |
+
- 基于EP-Group的负载均衡策略:通过优化负载均衡损失函数,改善专家特化效果。
|
| 12 |
+
|
| 13 |
+
## 3. 测评结果
|
| 14 |
+
|
| 15 |
+
| 测评集 | 测评指标 | 慢思考 |
|
| 16 |
+
|:----------------:|:----------------------------:|:-----:|
|
| 17 |
+
| **通用能力** | | |
|
| 18 |
+
| C-Eval | Acc | 91.06 |
|
| 19 |
+
| CLUEWSC | Acc | 94.67 |
|
| 20 |
+
| MMLU-Pro | Exact Match | 82.40 |
|
| 21 |
+
| ArenaHard_v0.1 | w/o Style Control | 96.80 |
|
| 22 |
+
| GPQA-Diamond | Avg@4 | 76.77 |
|
| 23 |
+
| SuperGPQA | Acc | 61.67 |
|
| 24 |
+
| IF-Eval | Prompt Strict | 80.59 |
|
| 25 |
+
| SysBench | Constraint Satisfaction Rate | 91.43 |
|
| 26 |
+
| **数学能力** | | |
|
| 27 |
+
| CNMO 2024 | Avg@32 | 80.73 |
|
| 28 |
+
| AIME25 | Avg@16 | 75.21 |
|
| 29 |
+
| AIME24 | Avg@16 | 80.21 |
|
| 30 |
+
| MATH-500 | Avg@1 | 97.40 |
|
| 31 |
+
| **代码能力** | | |
|
| 32 |
+
| LiveCodeBench | Avg@3 (01/25~05/25) | 61.14 |
|
| 33 |
+
| MBPP+ | Avg@2 | 81.48 |
|
| 34 |
+
|
| 35 |
+
**注:** 评测过程中,system prompt 为空。
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
## 4. 部署和使用
|
| 39 |
+
### 4.1 环境准备
|
| 40 |
+
#### 硬件规格
|
| 41 |
+
Atlas 800T A2 (64GB, >=32卡),驱动与固件安装包获取请参照[[Atlas 800T A2](https://www.hiascend.com/hardware/firmware-drivers/community?product=4&model=26&cann=8.2.RC1.alpha003&driver=Ascend+HDK+25.0.RC1)]
|
| 42 |
+
|
| 43 |
+
#### 软件环境
|
| 44 |
+
- 方式一:基于裸机环境安装以下配套软件
|
| 45 |
+
- 操作系统:Linux(推荐openEuler>=24.03)
|
| 46 |
+
- CANN==8.1.RC1,安装准备及流程请参照[[CANN Install](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/softwareinst/instg/instg_0001.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)]
|
| 47 |
+
- python==3.10
|
| 48 |
+
- torch==2.1.0
|
| 49 |
+
- torch-npu==2.1.0.post12
|
| 50 |
+
- transformers>=4.48.2
|
| 51 |
+
|
| 52 |
+
- 方式二:从docker镜像启动容器
|
| 53 |
+
|
| 54 |
+
参考[[Docker使用指南](doc/docker.md)]
|
| 55 |
+
|
| 56 |
+
以上软件配套经过验证,理论可以支持更高的版本,如有疑问,可以提交issue。
|
| 57 |
+
|
| 58 |
+
### 4.2 权重完整性校验
|
| 59 |
+
请参考以下方法对下载内容进行完整性校验,hash 值存储在 checklist.chk 文件中。
|
| 60 |
+
|
| 61 |
+
```
|
| 62 |
+
#!/usr/bin/env bash
|
| 63 |
+
ARCH=$(uname -m)
|
| 64 |
+
MODEL_PATH="${TARGET_FOLDER}/${MODEL_FOLDER_PATH}"
|
| 65 |
+
cd "$MODEL_PATH" || exit 1
|
| 66 |
+
if [ "$ARCH" = "arm64" ]; then
|
| 67 |
+
sha256sum checklist.chk
|
| 68 |
+
else
|
| 69 |
+
sha256sum -c checklist.chk
|
| 70 |
+
fi
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
### 4.3 推理权重转换
|
| 74 |
+
本次样例 openPangu-Ultra-MoE-718B 推理采用 Tensor Parallel 并行策略,叠加昇腾 NPU 融合大算子,需要提前对 safetensors 权重进行切分,下述内容提供32卡并行推理的权重切分示例,切分后的权重会保存在`model/`目录下:
|
| 75 |
+
```bash
|
| 76 |
+
cd inference
|
| 77 |
+
bash split_weight.sh
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
### 4.4 推理样例
|
| 81 |
+
openPangu-Ultra-MoE-718B 在 Atlas 800T A2 上4机32卡bfloat16推理示例,主节点选取节点IP0:
|
| 82 |
+
```bash
|
| 83 |
+
cd inference
|
| 84 |
+
# 主节点IP0: ${NNODES} ${NODE_RANK} ${NPROC_PER_NODE} ${MASTER_ADDR} ${PROMPT}
|
| 85 |
+
bash generate.sh 4 0 8 IP0 "3*7=?"
|
| 86 |
+
# 从节点IP1
|
| 87 |
+
bash generate.sh 4 1 8 IP0 "3*7=?"
|
| 88 |
+
# 从节点IP2
|
| 89 |
+
bash generate.sh 4 2 8 IP0 "3*7=?"
|
| 90 |
+
# 从节点IP3
|
| 91 |
+
bash generate.sh 4 3 8 IP0 "3*7=?"
|
| 92 |
+
```
|
| 93 |
+
模型默认为慢思考模式,可以通过以下手段切换至快思考模式:如`generate.py`示例中`fast_thinking_template`所示,在用户输入结尾添加` /no_think`标记可以将当前轮次切换至快思考模式。
|
| 94 |
+
|
| 95 |
+
### 4.5 使用推理框架
|
| 96 |
+
vllm_ascend:参考[[vllm_ascend_for_openPangu_ultra_moe_718b](doc/vllm_ascend_for_openpangu_ultra_moe_718b.md)]
|
| 97 |
+
|
| 98 |
+
## 5. 模型许可证
|
| 99 |
+
除文件中对开源许可证另有约定外,openPangu-Ultra-MoE-718B 模型根据 OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0 授权,旨在允许使用并促进人工智能技术的进一步发展。有关详细信息,请参阅模型存储库根目录中的 [LICENSE](LICENSE) 文件。
|
| 100 |
+
|
| 101 |
+
## 6. 免责声明
|
| 102 |
+
由于 openPangu-Ultra-MoE-718B (“模型”)所依赖的技术固有的限制,以及人工智能生成的内容是由盘古自动生成的,华为无法对以下事项做出任何保证:
|
| 103 |
+
- 该模型的输出通过AI算法自动生成,不能排除某些信息可能存在缺陷、不合理或引起不适的可能性,生成的内容不代表华为的态度或立场;
|
| 104 |
+
- 无法保证该模型100%准确、可靠、功能齐全、及时、安全、无错误、不间断、持续稳定或无任何故障;
|
| 105 |
+
- 该模型的输出内容不构成任何建议或决策,也不保证生成的内容的真实性、完整性、准确性、及时性、合法性、功能性或实用性。生成的内容不能替代医疗、法律等领域的专业人士回答您的问题。生成的内容仅供参考,不代表华为的任何态度、立场或观点。您需要根据实际情况做出独立判断,华为不承担任何责任。
|
| 106 |
+
|
| 107 |
+
## 7. 反馈
|
| 108 |
+
如果有任何意见和建议,请提交issue或联系[openPangu@huawei.com](url)。
|
README_EN.md
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# openPangu-Ultra-MoE-718B
|
| 2 |
+
English | [中文](README.md)
|
| 3 |
+
|
| 4 |
+
## 1. Introduction
|
| 5 |
+
The openPangu-Ultra-MoE-718B is a large-scale mixture-of-experts language model trained from scratch on Ascend NPU, with a total parameter count of 718B and 39B activated parameters per token. The openPangu-Ultra-MoE-718B is trained on approximately 19 trillion tokens, and equipped with the capability to switch between fast and slow thinking.
|
| 6 |
+
|
| 7 |
+
## 2. Model Architecture
|
| 8 |
+
The architecture of the openPangu-Ultra-MoE-718B adopts the mainstream Multi-head Latent Attention (MLA), Multi-Token Prediction (MTP), high MoE sparsity, and features several different designs:
|
| 9 |
+
|
| 10 |
+
- Depth-Scaled Sandwich-Norm and TinyInit: These techniques adjust the layer normalization structure and parameter initialization for improved training stability.
|
| 11 |
+
|
| 12 |
+
- EP-Group load balancing loss: This technique optimizes the load balancing loss, achieving better expert specialization.
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
## 3. Results
|
| 16 |
+
|
| 17 |
+
| Benchmark | Metric | Slow-thinking |
|
| 18 |
+
|:-------------------------:|:------------------------------:|:-----------------:|
|
| 19 |
+
| **General** | | |
|
| 20 |
+
| C-Eval | Acc | 91.06 |
|
| 21 |
+
| CLUEWSC | Acc | 94.67 |
|
| 22 |
+
| MMLU-Pro | Exact Match | 82.40 |
|
| 23 |
+
| ArenaHard_v0.1 | w/o Style Control | 96.80 |
|
| 24 |
+
| GPQA-Diamond | Avg@4 | 76.77 |
|
| 25 |
+
| SuperGPQA | Acc | 61.67 |
|
| 26 |
+
| IF-Eval | Prompt Strict | 80.59 |
|
| 27 |
+
| SysBench | Constraint Satisfaction Rate | 91.43 |
|
| 28 |
+
| **Math** | | |
|
| 29 |
+
| CNMO 2024 | Avg@32 | 80.73 |
|
| 30 |
+
| AIME25 | Avg@16 | 75.21 |
|
| 31 |
+
| AIME24 | Avg@16 | 80.21 |
|
| 32 |
+
| MATH-500 | Avg@1 | 97.40 |
|
| 33 |
+
| **Coding** | | |
|
| 34 |
+
| LiveCodeBench | Avg@3 (01/25~05/25) | 61.14 |
|
| 35 |
+
| MBPP+ | Avg@2 | 81.48 |
|
| 36 |
+
|
| 37 |
+
**Note:** The system prompt is empty during the evaluation process.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
## 4. Deployment
|
| 41 |
+
### 4.1 Environment
|
| 42 |
+
#### Hardware Requirements
|
| 43 |
+
Atlas 800T A2 (64GB, >=32 NPUs), please refer to [[Atlas 800T A2](https://www.hiascend.com/hardware/firmware-drivers/community?product=4&model=26&cann=8.2.RC1.alpha003&driver=Ascend+HDK+25.0.RC1)] for obtaining the driver and firmware installation packages.
|
| 44 |
+
|
| 45 |
+
#### System Requirements & Dependencies
|
| 46 |
+
- Method 1: Install the following supporting software in a bare-metal environment.
|
| 47 |
+
- System: Linux (openEuler ≥ 24.03 recommended)
|
| 48 |
+
- CANN==8.1.RC1, please refer to [[CANN Install](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/softwareinst/instg/instg_0001.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)] for installation
|
| 49 |
+
- python==3.10
|
| 50 |
+
- torch==2.1.0
|
| 51 |
+
- torch-npu==2.1.0.post12
|
| 52 |
+
- transformers>=4.48.2
|
| 53 |
+
|
| 54 |
+
- Method 2: Start a container from a docker image.
|
| 55 |
+
|
| 56 |
+
Refer to the [[Docker User Guide](doc/docker_EN.md)]
|
| 57 |
+
|
| 58 |
+
The above software environment has been verified, and theoretically supports newer versions. For any questions, please submit an issue.
|
| 59 |
+
|
| 60 |
+
### 4.2 Integrity Check
|
| 61 |
+
Please refer to the following methods to verify the integrity of the downloaded content. The hash values are stored in the `checklist.chk` file.
|
| 62 |
+
|
| 63 |
+
```
|
| 64 |
+
#!/usr/bin/env bash
|
| 65 |
+
ARCH=$(uname -m)
|
| 66 |
+
MODEL_PATH="${TARGET_FOLDER}/${MODEL_FOLDER_PATH}"
|
| 67 |
+
cd "$MODEL_PATH" || exit 1
|
| 68 |
+
if [ "$ARCH" = "arm64" ]; then
|
| 69 |
+
sha256sum checklist.chk
|
| 70 |
+
else
|
| 71 |
+
sha256sum -c checklist.chk
|
| 72 |
+
fi
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
### 4.3 Model Weights Conversion
|
| 76 |
+
This inference example of the openPangu-Ultra-MoE-718B adopts Tensor Parallel strategy with fused operators on Ascend NPU. It requires pre-sharding of the model weights. The following provides an example of weight sharding for 32-NPU parallel inference, with the split weights saved in the `model/` directory.
|
| 77 |
+
```bash
|
| 78 |
+
cd inference
|
| 79 |
+
bash split_weight.sh
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
### 4.4 Inference Examples
|
| 83 |
+
The following provides a simple bfloat16 inference example of the openPangu-Ultra-MoE-718B deployed on a 4-node 32-NPU Atlas 800T A2 cluster, for which the node IP0 is selected as the master node:
|
| 84 |
+
```bash
|
| 85 |
+
cd inference
|
| 86 |
+
# Master node IP0: ${NNODES} ${NODE_RANK} ${NPROC_PER_NODE} ${MASTER_ADDR} ${PROMPT}
|
| 87 |
+
bash generate.sh 4 0 8 IP0 "3*7=?"
|
| 88 |
+
# Worker node IP1
|
| 89 |
+
bash generate.sh 4 1 8 IP0 "3*7=?"
|
| 90 |
+
# Worker node IP2
|
| 91 |
+
bash generate.sh 4 2 8 IP0 "3*7=?"
|
| 92 |
+
# Worker node IP3
|
| 93 |
+
bash generate.sh 4 3 8 IP0 "3*7=?"
|
| 94 |
+
```
|
| 95 |
+
The model operates by default in slow thinking mode and can be configured to fast thinking mode through the following method: As demonstrated in the `fast_thinking_template` within the `generate.py` example, appending the ` /no_think` flag to the end of user input.
|
| 96 |
+
|
| 97 |
+
### 4.5 Using Inference Framework
|
| 98 |
+
Vllm-ascend:please refer to [[vllm_ascend_for_openpangu_ultra_moe_718b_EN](doc/vllm_ascend_for_openpangu_ultra_moe_718b_EN.md)]
|
| 99 |
+
|
| 100 |
+
## 5. Model License
|
| 101 |
+
Unless otherwise noted, the openPangu-Ultra-MoE-718B model is licensed under the terms and conditions of OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0, which is intended to be used permissively and enable the further development of artificial intelligence technologies. Please refer to the [LICENSE](LICENSE) file located in the root directory of the model repository for details.
|
| 102 |
+
|
| 103 |
+
## 6. Disclaimer
|
| 104 |
+
Due to the technical limitations inherent in the technology on which the openPangu-Ultra-MoE-718B (“Model”) relies and the fact that the artificial intelligence generated content is automatically produced by Model, Huawei cannot make any guarantees regarding the following matters:
|
| 105 |
+
|
| 106 |
+
- The output of this Model is automatically generated via AI algorithms, it does not rule out the possibility that some of the information may be flawed, unreasonable, or cause discomfort, and the generated content does not represent Huawei's attitude or standpoint;
|
| 107 |
+
- There is no guarantee that this Model is 100% accurate, reliable, functional, timely, secure and safety, error-free, uninterrupted, continuously stable, or free of any faults;
|
| 108 |
+
- The output of this Model does not constitute any advices or decisions for you, and it does not guarantee the authenticity, completeness, accuracy, timeliness, legality, functionality, or practicality of the generated content. The generated content cannot replace professionals in medical, legal, and other fields in answering your questions. The generated content is for your reference only and does not represent any attitude, standpoint, or position of Huawei. You need to make independent judgments based on your actual situation, and Huawei does not assume any responsibilities.
|
| 109 |
+
|
| 110 |
+
## 7. Contact
|
| 111 |
+
If you have any question, please raise an issue or contact us at [openPangu@huawei.com](url).
|
checklist.chk
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
714e6540d1371ad78a5816a6c69e2f0e330594939dd14b9dac7041e784d701a4 *./tokenizer_config.json
|
| 2 |
+
6b16f1558c0cd4ae6ef1a2c605713be0a514f50e1ce2d2c878979ce988c148ec *./tokenizer.model
|
| 3 |
+
81c7a7c24ed70acdeeff295a5a091a401cacfb49dcc30308a5b451956b051223 *./tokenization_openpangu.py
|
| 4 |
+
b34cf5e7c7660889303b6e2d0a346c440356385c9db551d06f6615cf9fc600d1 *./special_tokens_map.json
|
| 5 |
+
035c0bc169b317e19096ac2f81d6fa790850533f573bedda9d3e88a1183aed9b *./modeling_openpangu_moe.py
|
| 6 |
+
6580557dcb86f3285d3c056eeb21eb9b702e0fa908577ba471cc91ec6c9b3fcc *./model.safetensors.index.json
|
| 7 |
+
bc6505adabc0498ad07b49187858788c65c13dbf9446fd0bcf177a3e1b27220d *./inference/vllm_ascend/worker/npu_input_batch.py
|
| 8 |
+
62c6734d1283e3d649a6478d2004f46bfee2f7878af7f2849c979b124e355302 *./inference/vllm_ascend/worker/model_runner_v1.py
|
| 9 |
+
e2457c558f048876afe069d1226e7080ac214478f1a9ac28ae472928b81b5a06 *./inference/vllm_ascend/utils.py
|
| 10 |
+
6adfaa8a67ea9b561dec2e6a2392f6fc85ff376fb2030d8761c34c6c6d3f4cbf *./inference/vllm_ascend/quantization/w8a8_dynamic.py
|
| 11 |
+
743bd96cfc109975a11fe5412c4b5de46f880501dcbbbdd10e11cbeb865fa4f2 *./inference/vllm_ascend/quantization/w8a8.py
|
| 12 |
+
e712ea36caf16c2a9dd21c5288f9d8e34c7fd2face444da44dca6db6c21f6c1b *./inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py
|
| 13 |
+
8c59df8086bde0cd4df674403f83000921a34403651a8ff2b31de9b28768247a *./inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py
|
| 14 |
+
8436ab93933989431160e55627b5dce5326f0fc5ec18263653902764ac8ace7b *./inference/vllm_ascend/patch/worker/patch_common/patch_config.py
|
| 15 |
+
63a6ba0d0b0158d4586219c979bf96d5fe87b74123af93f1c8d9ed842db96500 *./inference/vllm_ascend/patch/worker/patch_common/__init__.py
|
| 16 |
+
09273eb0e4696d2fb530881ba1ad9d331897dd81c0cd2f203ed3d0a475b4d39b *./inference/vllm_ascend/ops/fused_moe.py
|
| 17 |
+
b654e72ece161b3f04080e5c4d2476641c024939ac5308115fe1c65a6c5c7215 *./inference/vllm_ascend/models/open_pangu.py
|
| 18 |
+
e98aa2549f02017a35b07499216fe569e86400684087821820cf2d971c8fcbac *./inference/vllm_ascend/models/__init__.py
|
| 19 |
+
52a968f10ebaebeb626248afd3e1d1b92f8fbfcaad19ebf05cafbc0bd03192cb *./inference/vllm_ascend/envs.py
|
| 20 |
+
91eab52cdc19603b7b705b302e25345d849e18fa66875261a1135d5382392123 *./inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py
|
| 21 |
+
d07256c9014f911f81269e65aad6c0d7dd61d4e82f5cb399e05285d5c1bc8fa8 *./inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py
|
| 22 |
+
f9577c29bc4dc19a4cc41ccfcca17065402c9dd92221bef987c74808b23ed124 *./inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py
|
| 23 |
+
9070682b058a79d2b2874ba5e07ce72beff6efb870f75cdac30cdcf6ba8fadc7 *./inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py
|
| 24 |
+
2254aeca0be7b8922318e10c4a950f39afb30ba5fe3b46564a58671b237ac612 *./inference/vllm_ascend/attention/mla_v1.py
|
| 25 |
+
ba6d7edcf1cf464d6fd787b12a9bda2a16fea0ac0d5d1e54136baec503d6e696 *./inference/vllm_ascend/attention/attention.py
|
| 26 |
+
4aaf57e6f6d2e139b3847b10ee59d738398ebfc4927a22325b27dad384874aec *./inference/vllm_ascend/_build_info.py
|
| 27 |
+
5cd02a8ec3b7494e08c6cea07927c377b5da177b9cf704988d3c2ca367528c09 *./inference/split_weight.sh
|
| 28 |
+
a2ea110364d7eecf039e180b42986f1c830acc3fd6ac487f45f93f89f285c669 *./inference/split_weight.py
|
| 29 |
+
bd8f9bb4a9cd7d182a499307746c939514b4487ebbecbdf9527482f2c31aed9a *./inference/runner_config/tp32.yaml
|
| 30 |
+
379d51e424a24b6599a7e5b85187ced84a8d016c051133432d0a4eaa58a5c900 *./inference/runner_config/tp1.yaml
|
| 31 |
+
72f5bf3c6e4a990d5c9b2412f4b6bf9534b99ce7e1c196f7a42740a6f006e7ad *./inference/runner.py
|
| 32 |
+
85841e6a1bc03eff679e3cf00958864f86e3260dc90e69379132f2e3dc6674ad *./inference/model.py
|
| 33 |
+
d646343af70b851f1942ee6dbdc1807686c13a48d4d660ebac777dba29eafdd1 *./inference/generate.sh
|
| 34 |
+
d4d48848ec2c7945670a6196323a570d1830e4edcf5336919641e58bcbc9da0a *./inference/generate.py
|
| 35 |
+
feb36ae08104d00af5bd32e6e20b67a11588e02911b15b3261650b22e1db3ad8 *./generation_config.json
|
| 36 |
+
634ef0ce7809d0e44d31fecf72863146a08e609a8ba9cbe16b427f0de12fe2e0 *./configuration_openpangu_moe.py
|
| 37 |
+
9988eca928867694568154f72de3cedaca34b258954652b3c95772d9a5f5118e *./config.json
|
| 38 |
+
fca1e83c102d7a4b9c9f1bcfdc8766df2037d77550ca33dc45c8342fd5b73d0d *./model-00062-of-000062.safetensors
|
| 39 |
+
0548a05b2afa374e731f8785ad5ee7302335dcba952d3797b759ad920f3f4fce *./model-00061-of-000062.safetensors
|
| 40 |
+
4351cc1440af69fca1735baa78c09658b08c9ae67fcc076ad4fa9ea2d25e084c *./model-00060-of-000062.safetensors
|
| 41 |
+
8802bb0b554a1e6404b58b13f328c82fb4f2e2d6d7524f7b4f14d7bb6e81d0f3 *./model-00059-of-000062.safetensors
|
| 42 |
+
59a1f7667adcfff55145b4b21789e7c494aa80d8bc2fa466d0fa43cd0f3ff43e *./model-00058-of-000062.safetensors
|
| 43 |
+
65c1cd81cf879606bcddc4af60b506e9bcda80955be72693e642d3a325ffd8e7 *./model-00057-of-000062.safetensors
|
| 44 |
+
b72df9faac3f73cf7191dbb454229f74bc22e1a98d5ad7ef2aea85e0b14d4123 *./model-00056-of-000062.safetensors
|
| 45 |
+
4ded3c2baad08d7646cf3da25dcecda3bcbdd542f90acf76848d93000b3ab23b *./model-00055-of-000062.safetensors
|
| 46 |
+
86cc6e275ab190244cc2bfc0cff5689fb7c145623ec54d03ed351df6af829ec2 *./model-00054-of-000062.safetensors
|
| 47 |
+
a5bba51810248ad0a9a54e55b56b946732737ee5f5f289fab580b7c36f21d3f6 *./model-00053-of-000062.safetensors
|
| 48 |
+
cd153a6a0fd5dea192768257c71d47655d7a7a1dbdffdaffeab4dfb7ad27b3eb *./model-00052-of-000062.safetensors
|
| 49 |
+
85000bb98b77232c47542490df67faec31e3d5105b7d4de34e46466d53220bd6 *./model-00051-of-000062.safetensors
|
| 50 |
+
0da40653e3dcb3a32aa929674062224f3a74fa3eeefb6dcc5a6698cd9f596708 *./model-00050-of-000062.safetensors
|
| 51 |
+
678d2b3ac3c73f387a18a446d83dd866ffc925f82ff56541d3b71a747c6b7d06 *./model-00049-of-000062.safetensors
|
| 52 |
+
a49b5e3f1be43a6a3bff9fec25d5d8dffad5a8ebd96c8e389933f40287f1888e *./model-00048-of-000062.safetensors
|
| 53 |
+
54d984279e435df0426cb052ed99067b7a4a1637234a24a8f05bcb7bbd89d0d2 *./model-00047-of-000062.safetensors
|
| 54 |
+
4fe5e98cb4e492639bafe37c3e3144c3fe8f27a9fd45e48c92c780f39599becc *./model-00046-of-000062.safetensors
|
| 55 |
+
97458250006f5949c65b338a26503738200c5fb2415f4cc664a6b224aa9dce70 *./model-00045-of-000062.safetensors
|
| 56 |
+
fd548640dbe4cc04ef4834ac32cada57fb43b0fb9971f0e36030d94041dd1b0d *./model-00044-of-000062.safetensors
|
| 57 |
+
58847a3be6d606e21941769346e86165891f4fa7242cc84cda7edc9620298ad2 *./model-00043-of-000062.safetensors
|
| 58 |
+
f0be4dc1d9326543061e858d566955aefa9d4a757ebd8f92df648bd9c16a236b *./model-00042-of-000062.safetensors
|
| 59 |
+
d6de2cffa758d32d790c89ac7df7aa313ec1a2923faf52549357a3e8ff16d74f *./model-00041-of-000062.safetensors
|
| 60 |
+
cf9fb7c2ca6e977d9e6f19bd3204ba8b8ad15d33cef9a0ad9f6e9b867319fc8b *./model-00040-of-000062.safetensors
|
| 61 |
+
2f12c46798cd8f51740964ea58b59ef682eca6ee2ae9402283214f1f0fd4113c *./model-00039-of-000062.safetensors
|
| 62 |
+
0a531b364281f6060567549b741a1a53a81c1d9211170354bf32c88494b486e9 *./model-00038-of-000062.safetensors
|
| 63 |
+
e3fe1e5795ffb480aa934410434f30711a4dc982f7eea3df5ac1da57783bc619 *./model-00037-of-000062.safetensors
|
| 64 |
+
e36d0dbfbfbb7a792246c931c89d95be2e7afff1872ff455114c58d82b7398d2 *./model-00036-of-000062.safetensors
|
| 65 |
+
af8bfd55a58902cdd8293c67f2c9c5e069ca38f3086645a05bfb9140fe6d51ba *./model-00035-of-000062.safetensors
|
| 66 |
+
9aa7ab7d596db78af87e87a92e8111fc673000b3132cffe02fa2e164d77b8b32 *./model-00034-of-000062.safetensors
|
| 67 |
+
dfa6683035c9caca4de02afc6b3b4fc5fe786e0ee687ddfa1f72c4171eb33821 *./model-00033-of-000062.safetensors
|
| 68 |
+
cdb597fd48c542dd1508d9ce069b8e7c19007908fbdd2dc8a443d73bee6f1754 *./model-00032-of-000062.safetensors
|
| 69 |
+
457f476851463f36e3fdcb54e54dc39c9890e64b6b06b100ac5013c3c72385e4 *./model-00031-of-000062.safetensors
|
| 70 |
+
3d1ee1e248181a08f2401d314f71d342a95e8c9fbee5877800b392be54b1343f *./model-00030-of-000062.safetensors
|
| 71 |
+
1856663f2077287ba21f3bd821705c1f55bac979e768d67fc31c52727b34ae2f *./model-00029-of-000062.safetensors
|
| 72 |
+
0cfe77fd4bfac0f9623411d1234872e141b211eb0b3cf61238380f5b34c3c043 *./model-00028-of-000062.safetensors
|
| 73 |
+
933263ed6db42b1e16407b4400b70260425b1cc11f9ed8fdaad6ef5935f05fb4 *./model-00027-of-000062.safetensors
|
| 74 |
+
d15f1371a364df11676b105291e481fbbd1999ea2152ec0b14905f5d9cb854fa *./model-00026-of-000062.safetensors
|
| 75 |
+
744dedfcb6bc74624351cc299772ee6be389147301b05f4fe645ebac7bedb53b *./model-00025-of-000062.safetensors
|
| 76 |
+
fd9cf078f3a819e230a78fdf8201e37fd25696f576e1fd0d46fb122deb11c2c8 *./model-00024-of-000062.safetensors
|
| 77 |
+
3ec91acf57a576b8550d14312dda2acf345ce09769407d1596e0cbdff5a1200a *./model-00023-of-000062.safetensors
|
| 78 |
+
6df9b0923f9ec0ca7ab20dd0934fc36467d02f02d0cb996904b2f1181d37502b *./model-00022-of-000062.safetensors
|
| 79 |
+
0e4ac1edb16a327ed30dd75325fdb9d6e7da6d35724fc086c66c74822d6a1de7 *./model-00021-of-000062.safetensors
|
| 80 |
+
9837d1382bdab63f5781e6b76f3634775a94cd792adadb4d51763693aa670c36 *./model-00020-of-000062.safetensors
|
| 81 |
+
6a1de462c450f0ddc2e379224cc5ca0ebd31b9bcb9102ed4eb7effe8298de963 *./model-00019-of-000062.safetensors
|
| 82 |
+
1bf72bcf187656c13ac2f003fa534bbb79d22374711a7d5315766b2299800c4e *./model-00018-of-000062.safetensors
|
| 83 |
+
41bda994ada5d86c56f660db9350acd900201d1bc131379ce892e60564154e5f *./model-00017-of-000062.safetensors
|
| 84 |
+
009a80118b5e38fc072a5bcaf20efa832148c14e55765ae2e4176edda11d6608 *./model-00016-of-000062.safetensors
|
| 85 |
+
310713a4eeebdb60bd62b1c9c0a51bcd6efd6fdefde5c31d4adc8ce854d06f23 *./model-00015-of-000062.safetensors
|
| 86 |
+
91ef0b66652923f58267a69bb39a18da89ce71732a7a92b458447bb938fb17e7 *./model-00014-of-000062.safetensors
|
| 87 |
+
deabb0cc16d4bcaa6576583acefa13c92e0bc16c6209e8db0fdea2bb55b45501 *./model-00013-of-000062.safetensors
|
| 88 |
+
95d7980844ee1411f71358f8268d537f0f6c2d5d87a15e101e956d1b3c9c61b2 *./model-00012-of-000062.safetensors
|
| 89 |
+
4661dd325fda474f8490505bfe9b3652ae2976284825d17cfb7bd60d01aacae2 *./model-00011-of-000062.safetensors
|
| 90 |
+
9f3168f097ba85d266b062732aab8cb28daae08fef108305291819750be1b384 *./model-00010-of-000062.safetensors
|
| 91 |
+
91547f7296b08e93903a250eec9a25562f714a7bdfab511ae4fd0f358aaec832 *./model-00009-of-000062.safetensors
|
| 92 |
+
5ee6ca3f506708b453076c56027fe7366958ef4baa06fdd69f0992e15544ad17 *./model-00008-of-000062.safetensors
|
| 93 |
+
07fa95a4e6b3e9b3475076f160e839cb715ab426fb78a63d8e2decb636cb8987 *./model-00007-of-000062.safetensors
|
| 94 |
+
a0b071272706a4d4d4ed5a331d60590930a3529c13c4bc158b6f1b0bc3dd8c85 *./model-00006-of-000062.safetensors
|
| 95 |
+
4e3907f683d7f8382d2a792304155e8533ffa3a94dd4bb5ff825124b0dba3835 *./model-00005-of-000062.safetensors
|
| 96 |
+
2bd5c0012a3cedf4582160173f857480dd58426282a0e5826609a02b5aff5b3e *./model-00004-of-000062.safetensors
|
| 97 |
+
3f63aa17d947032e0a524b5798eee3becbfc9a9b6f8a352ead3232e7b34bb289 *./model-00003-of-000062.safetensors
|
| 98 |
+
1e29a512e3737d1826c80a2277a8b42021878847753aadbe5e1ae2a2df3d7f8d *./model-00002-of-000062.safetensors
|
| 99 |
+
c692b00cbc19ee5a2d4a9bb71496f12a846d14e06d7e0ac79b46abc3243ee115 *./model-00001-of-000062.safetensors
|
config.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"PanguUltraMoEForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_openpangu_moe.PanguUltraMoEConfig",
|
| 8 |
+
"AutoModel": "modeling_openpangu_moe.PanguUltraMoEModel",
|
| 9 |
+
"AutoModelForCausalLM": "modeling_openpangu_moe.PanguUltraMoEForCausalLM"
|
| 10 |
+
},
|
| 11 |
+
"num_dense_layers": 3,
|
| 12 |
+
"hidden_act": "silu",
|
| 13 |
+
"hidden_size": 7680,
|
| 14 |
+
"initializer_range": 0.02,
|
| 15 |
+
"intermediate_size": 18432,
|
| 16 |
+
"attention_kv_lora_dim": 512,
|
| 17 |
+
"max_position_embeddings": 131072,
|
| 18 |
+
"model_type": "pangu_ultra_moe",
|
| 19 |
+
"moe_intermediate_size": 2048,
|
| 20 |
+
"num_routed_experts": 256,
|
| 21 |
+
"num_shared_experts": 1,
|
| 22 |
+
"num_attention_heads": 128,
|
| 23 |
+
"num_experts_per_tok": 8,
|
| 24 |
+
"num_hidden_layers": 61,
|
| 25 |
+
"num_key_value_heads": 128,
|
| 26 |
+
"num_mtp_layers": 1,
|
| 27 |
+
"attention_q_lora_dim": 1536,
|
| 28 |
+
"attention_qk_dim": 128,
|
| 29 |
+
"attention_qk_rope_dim": 64,
|
| 30 |
+
"rms_norm_eps": 1e-05,
|
| 31 |
+
"rope_theta": 25600000,
|
| 32 |
+
"routed_scaling_factor": 2.5,
|
| 33 |
+
"sandwich_norm": true,
|
| 34 |
+
"tie_word_embeddings": false,
|
| 35 |
+
"torch_dtype": "bfloat16",
|
| 36 |
+
"transformers_version": "4.48.2",
|
| 37 |
+
"use_cache": true,
|
| 38 |
+
"attention_v_dim": 128,
|
| 39 |
+
"vocab_size": 153600
|
| 40 |
+
}
|
configuration_openpangu_moe.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
|
| 4 |
+
"""openPanguUltraMoE 718B model configuration"""
|
| 5 |
+
|
| 6 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PanguUltraMoEConfig(PretrainedConfig):
|
| 10 |
+
|
| 11 |
+
model_type = "pangu_ultra_moe"
|
| 12 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 13 |
+
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
vocab_size=153600,
|
| 17 |
+
hidden_size=7680,
|
| 18 |
+
intermediate_size=18432,
|
| 19 |
+
moe_intermediate_size=2048,
|
| 20 |
+
num_hidden_layers=61,
|
| 21 |
+
num_mtp_layers=1,
|
| 22 |
+
num_attention_heads=128,
|
| 23 |
+
num_key_value_heads=128,
|
| 24 |
+
num_shared_experts=1,
|
| 25 |
+
num_routed_experts=256,
|
| 26 |
+
routed_scaling_factor=2.5,
|
| 27 |
+
attention_kv_lora_dim=512,
|
| 28 |
+
attention_q_lora_dim=1536,
|
| 29 |
+
attention_qk_rope_dim=64,
|
| 30 |
+
attention_v_dim=128,
|
| 31 |
+
attention_qk_dim=128,
|
| 32 |
+
num_experts_per_tok=8,
|
| 33 |
+
num_dense_layers=3,
|
| 34 |
+
norm_topk_prob=True,
|
| 35 |
+
hidden_act="silu",
|
| 36 |
+
max_position_embeddings=131072,
|
| 37 |
+
initializer_range=0.02,
|
| 38 |
+
rms_norm_eps=1e-5,
|
| 39 |
+
use_cache=True,
|
| 40 |
+
pad_token_id=None,
|
| 41 |
+
bos_token_id=0,
|
| 42 |
+
eos_token_id=1,
|
| 43 |
+
tie_word_embeddings=False,
|
| 44 |
+
rope_theta=25600000,
|
| 45 |
+
attention_dropout=0.0,
|
| 46 |
+
**kwargs,
|
| 47 |
+
):
|
| 48 |
+
self.vocab_size = vocab_size
|
| 49 |
+
self.max_position_embeddings = max_position_embeddings
|
| 50 |
+
self.hidden_size = hidden_size
|
| 51 |
+
self.num_hidden_layers = num_hidden_layers
|
| 52 |
+
self.num_attention_heads = num_attention_heads
|
| 53 |
+
self.num_key_value_heads = num_key_value_heads
|
| 54 |
+
self.hidden_act = hidden_act
|
| 55 |
+
self.initializer_range = initializer_range
|
| 56 |
+
self.rms_norm_eps = rms_norm_eps
|
| 57 |
+
self.use_cache = use_cache
|
| 58 |
+
self.rope_theta = rope_theta
|
| 59 |
+
|
| 60 |
+
self.num_dense_layers = num_dense_layers
|
| 61 |
+
self.intermediate_size = intermediate_size
|
| 62 |
+
self.moe_intermediate_size = moe_intermediate_size
|
| 63 |
+
self.num_shared_experts = num_shared_experts
|
| 64 |
+
self.num_routed_experts = num_routed_experts
|
| 65 |
+
self.routed_scaling_factor = routed_scaling_factor
|
| 66 |
+
self.num_experts_per_tok = num_experts_per_tok
|
| 67 |
+
self.norm_topk_prob = norm_topk_prob
|
| 68 |
+
self.attention_kv_lora_dim = attention_kv_lora_dim
|
| 69 |
+
self.attention_q_lora_dim = attention_q_lora_dim
|
| 70 |
+
self.attention_qk_rope_dim = attention_qk_rope_dim
|
| 71 |
+
self.attention_v_dim = attention_v_dim
|
| 72 |
+
self.attention_qk_dim = attention_qk_dim
|
| 73 |
+
self.attention_dropout = attention_dropout
|
| 74 |
+
self.num_mtp_layers = num_mtp_layers
|
| 75 |
+
|
| 76 |
+
super().__init__(
|
| 77 |
+
pad_token_id=pad_token_id,
|
| 78 |
+
bos_token_id=bos_token_id,
|
| 79 |
+
eos_token_id=eos_token_id,
|
| 80 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 81 |
+
**kwargs,
|
| 82 |
+
)
|
doc/docker.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Atlas 800T A2一键部署
|
| 2 |
+
|
| 3 |
+
参考使用vllm-ascend社区开源[镜像](https://vllm-ascend.readthedocs.io/en/latest/quick_start.html)
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
export IMAGE=m.daocloud.io/quay.io/ascend/vllm-ascend:v0.9.2rc1
|
| 7 |
+
export NAME=atlas_800t_a2
|
| 8 |
+
docker run --rm \
|
| 9 |
+
--name $NAME \
|
| 10 |
+
--net=host \
|
| 11 |
+
--device /dev/davinci0 \
|
| 12 |
+
--device /dev/davinci1 \
|
| 13 |
+
--device /dev/davinci2 \
|
| 14 |
+
--device /dev/davinci3 \
|
| 15 |
+
--device /dev/davinci4 \
|
| 16 |
+
--device /dev/davinci5 \
|
| 17 |
+
--device /dev/davinci6 \
|
| 18 |
+
--device /dev/davinci7 \
|
| 19 |
+
--device /dev/davinci_manager \
|
| 20 |
+
--device /dev/devmm_svm \
|
| 21 |
+
--device /dev/hisi_hdc \
|
| 22 |
+
-v /usr/local/dcmi:/usr/local/dcmi \
|
| 23 |
+
-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
|
| 24 |
+
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
| 25 |
+
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
|
| 26 |
+
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
|
| 27 |
+
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
| 28 |
+
-v /data:/data \
|
| 29 |
+
-v /tmp:/tmp \
|
| 30 |
+
-it $IMAGE bash
|
| 31 |
+
```
|
doc/docker_EN.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Deploy on Atlas 800T A2
|
| 2 |
+
|
| 3 |
+
Please refer to the vllm-ascend community open-source [mirror](https://vllm-ascend.readthedocs.io/en/latest/quick_start.html)
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
export IMAGE=m.daocloud.io/quay.io/ascend/vllm-ascend:v0.9.2rc1
|
| 7 |
+
export NAME=atlas_800t_a2
|
| 8 |
+
docker run --rm \
|
| 9 |
+
--name $NAME \
|
| 10 |
+
--net=host \
|
| 11 |
+
--device /dev/davinci0 \
|
| 12 |
+
--device /dev/davinci1 \
|
| 13 |
+
--device /dev/davinci2 \
|
| 14 |
+
--device /dev/davinci3 \
|
| 15 |
+
--device /dev/davinci4 \
|
| 16 |
+
--device /dev/davinci5 \
|
| 17 |
+
--device /dev/davinci6 \
|
| 18 |
+
--device /dev/davinci7 \
|
| 19 |
+
--device /dev/davinci_manager \
|
| 20 |
+
--device /dev/devmm_svm \
|
| 21 |
+
--device /dev/hisi_hdc \
|
| 22 |
+
-v /usr/local/dcmi:/usr/local/dcmi \
|
| 23 |
+
-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
|
| 24 |
+
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
| 25 |
+
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
|
| 26 |
+
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
|
| 27 |
+
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
| 28 |
+
-v /data:/data \
|
| 29 |
+
-v /tmp:/tmp \
|
| 30 |
+
-it $IMAGE bash
|
| 31 |
+
```
|
doc/vllm_ascend_for_openpangu_ultra_moe_718b.md
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## openPangu-Ultra-MoE-718B在[vllm-ascend](https://github.com/vllm-project/vllm-ascend)部署指导文档
|
| 2 |
+
|
| 3 |
+
### 部署环境说明
|
| 4 |
+
|
| 5 |
+
Atlas 800T A2(64GB) 64卡可以部署openPangu-Ultra-MoE-718B(bf16),32卡可部署盘古 Ultra MoE (int8),选用vllm-ascend社区镜像v0.9.1-dev,多个节点都需拉取镜像。
|
| 6 |
+
```bash
|
| 7 |
+
docker pull quay.io/ascend/vllm-ascend:v0.9.1-dev
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
* 网络环境检测
|
| 11 |
+
在每个节点上依次执行以下命令。所有结果必须为 success 且状态必须为 UP:
|
| 12 |
+
```bash
|
| 13 |
+
# Check the remote switch ports
|
| 14 |
+
for i in {0..7}; do hccn_tool -i $i -lldp -g | grep Ifname; done
|
| 15 |
+
# Get the link status of the Ethernet ports (UP or DOWN)
|
| 16 |
+
for i in {0..7}; do hccn_tool -i $i -link -g ; done
|
| 17 |
+
# Check the network health status
|
| 18 |
+
for i in {0..7}; do hccn_tool -i $i -net_health -g ; done
|
| 19 |
+
# View the network detected IP configuration
|
| 20 |
+
for i in {0..7}; do hccn_tool -i $i -netdetect -g ; done
|
| 21 |
+
# View gateway configuration
|
| 22 |
+
for i in {0..7}; do hccn_tool -i $i -gateway -g ; done
|
| 23 |
+
# View NPU network configuration
|
| 24 |
+
cat /etc/hccn.conf
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
### 镜像启动和推理代码适配
|
| 28 |
+
|
| 29 |
+
以下操作需在每个节点都执行。
|
| 30 |
+
|
| 31 |
+
启动镜像。
|
| 32 |
+
```bash
|
| 33 |
+
# Update the vllm-ascend image
|
| 34 |
+
export IMAGE=quay.io/ascend/vllm-ascend:v0.9.1-dev # Use correct image id
|
| 35 |
+
export NAME=vllm-ascend # Custom docker name
|
| 36 |
+
|
| 37 |
+
# Run the container using the defined variables
|
| 38 |
+
# Note if you are running bridge network with docker, Please expose available ports for multiple nodes communication in advance
|
| 39 |
+
# To prevent device interference from other docker containers, add the argument "--privileged"
|
| 40 |
+
docker run --rm \
|
| 41 |
+
--name $NAME \
|
| 42 |
+
--network host \
|
| 43 |
+
--device /dev/davinci0 \
|
| 44 |
+
--device /dev/davinci1 \
|
| 45 |
+
--device /dev/davinci2 \
|
| 46 |
+
--device /dev/davinci3 \
|
| 47 |
+
--device /dev/davinci4 \
|
| 48 |
+
--device /dev/davinci5 \
|
| 49 |
+
--device /dev/davinci6 \
|
| 50 |
+
--device /dev/davinci7 \
|
| 51 |
+
--device /dev/davinci_manager \
|
| 52 |
+
--device /dev/devmm_svm \
|
| 53 |
+
--device /dev/hisi_hdc \
|
| 54 |
+
-v /usr/local/dcmi:/usr/local/dcmi \
|
| 55 |
+
-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
|
| 56 |
+
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
| 57 |
+
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
|
| 58 |
+
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
|
| 59 |
+
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
| 60 |
+
-v /mnt/sfs_turbo/.cache:/root/.cache \
|
| 61 |
+
-it $IMAGE bash
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
如果未进入容器,需以root用户进入容器。
|
| 65 |
+
```
|
| 66 |
+
docker exec -itu root $NAME /bin/bash
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
下载vllm(v0.9.2),替换镜像内置的vllm代码。
|
| 70 |
+
```bash
|
| 71 |
+
pip install --no-deps vllm==0.9.2 pybase64==1.4.1
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
下载[vllm-ascend (v0.9.2rc1)](https://github.com/vllm-project/vllm-ascend/releases/tag/v0.9.2rc1),替换镜像内置的vllm-ascend代码(`/vllm-workspace/vllm-ascend/`)。例如下载Assets中的[Source code
|
| 75 |
+
(tar.gz)](https://github.com/vllm-project/vllm-ascend/archive/refs/tags/v0.9.2rc1.tar.gz)得到v0.9.2rc1.tar.gz,然后解压并替换:
|
| 76 |
+
```bash
|
| 77 |
+
tar -zxvf vllm-ascend-0.9.2rc1.tar.gz -C /vllm-workspace/vllm-ascend/ --strip-components=1
|
| 78 |
+
export PYTHONPATH=/vllm-workspace/vllm-ascend/:${PYTHONPATH}
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
使用当前代码仓中适配盘古模型的vllm-ascend代码替换`/vllm-workspace/vllm-ascend/vllm_ascend/`中的部分代码。
|
| 82 |
+
```bash
|
| 83 |
+
yes | cp -r inference/vllm_ascend/* /vllm-workspace/vllm-ascend/vllm_ascend/
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
### BF16推理
|
| 87 |
+
|
| 88 |
+
以下操作需在每个节点都执行。
|
| 89 |
+
|
| 90 |
+
运行命令:
|
| 91 |
+
```bash
|
| 92 |
+
# This obtained through ifconfig
|
| 93 |
+
# nic_name is the network interface name corresponding to local_ip
|
| 94 |
+
local_ip=`hostname -I | cut -d' ' -f1`
|
| 95 |
+
nic_name=$(ifconfig | grep -B 1 "$local_ip" | head -n 1 | awk '{print $1}' | sed 's/://')
|
| 96 |
+
export HCCL_IF_IP=$local_ip
|
| 97 |
+
export GLOO_SOCKET_IFNAME=$nic_name
|
| 98 |
+
export TP_SOCKET_IFNAME=$nic_name
|
| 99 |
+
export HCCL_SOCKET_IFNAME=$nic_name
|
| 100 |
+
export OMP_PROC_BIND=false
|
| 101 |
+
export OMP_NUM_THREADS=100
|
| 102 |
+
export VLLM_USE_V1=1
|
| 103 |
+
export HCCL_BUFFSIZE=1024
|
| 104 |
+
export VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP=1
|
| 105 |
+
export VLLM_ASCEND_ENABLE_TOP_N_SIGMA=1 # enable top-n-sigma sampling
|
| 106 |
+
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
|
| 107 |
+
|
| 108 |
+
MASTER_NODE_IP=xxx.xxx.xxx.xxx # master/head node ip
|
| 109 |
+
NODE_RANK=xxx # current node rank (0~7)
|
| 110 |
+
NUM_NODES=8 # number of nodes
|
| 111 |
+
NUM_NPUS_LOCAL=8 # number of NPUs per node
|
| 112 |
+
DATA_PARALLEL_SIZE_LOCAL=4 # DP size per node, can be set to 1, 2, or 4
|
| 113 |
+
LOCAL_CKPT_DIR=/root/.cache/pangu_ultra_moe # The pangu_ultra_moe bf16 weight
|
| 114 |
+
# Specifying HOST=127.0.0.1 (localhost) means the server can only be accessed from the master device.
|
| 115 |
+
# Specifying HOST=0.0.0.0 allows the vLLM server to be accessed from other devices on the same network or even from the internet, provided proper network configuration (e.g., firewall rules, port forwarding) is in place.
|
| 116 |
+
HOST=xxx.xxx.xxx.xxx
|
| 117 |
+
|
| 118 |
+
if [[ $NODE_RANK -ne 0 ]]; then
|
| 119 |
+
headless="--headless"
|
| 120 |
+
else
|
| 121 |
+
headless=""
|
| 122 |
+
fi
|
| 123 |
+
|
| 124 |
+
vllm serve $LOCAL_CKPT_DIR \
|
| 125 |
+
--host $HOST \
|
| 126 |
+
--port 8004 \
|
| 127 |
+
--data-parallel-size $((NUM_NODES*DATA_PARALLEL_SIZE_LOCAL)) \
|
| 128 |
+
--data-parallel-size-local $DATA_PARALLEL_SIZE_LOCAL \
|
| 129 |
+
--data-parallel-start-rank $((DATA_PARALLEL_SIZE_LOCAL*NODE_RANK)) \
|
| 130 |
+
--data-parallel-address $MASTER_NODE_IP \
|
| 131 |
+
--data-parallel-rpc-port 13389 \
|
| 132 |
+
--tensor-parallel-size $((NUM_NPUS_LOCAL/DATA_PARALLEL_SIZE_LOCAL)) \
|
| 133 |
+
--seed 1024 \
|
| 134 |
+
--served-model-name pangu_ultra_moe \
|
| 135 |
+
--enable-expert-parallel \
|
| 136 |
+
--max-num-seqs 8 \
|
| 137 |
+
--max-model-len 32768 \
|
| 138 |
+
--max-num-batched-tokens 4096 \
|
| 139 |
+
--trust-remote-code \
|
| 140 |
+
--no-enable-prefix-caching \
|
| 141 |
+
--gpu-memory-utilization 0.9 \
|
| 142 |
+
${headless} \
|
| 143 |
+
--additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
### 发请求测试
|
| 147 |
+
|
| 148 |
+
服务启动后,在主节点或者其他节点向主节点发送测试请求:
|
| 149 |
+
|
| 150 |
+
```bash
|
| 151 |
+
curl http://${MASTER_NODE_IP}:8004/v1/chat/completions \
|
| 152 |
+
-H "Content-Type: application/json" \
|
| 153 |
+
-d '{
|
| 154 |
+
"model": "pangu_ultra_moe",
|
| 155 |
+
"messages": [
|
| 156 |
+
{
|
| 157 |
+
"role": "user",
|
| 158 |
+
"content": "Who are you?"
|
| 159 |
+
}
|
| 160 |
+
],
|
| 161 |
+
"max_tokens": 512,
|
| 162 |
+
"temperature": 0.7,
|
| 163 |
+
"top_p": 1.0,
|
| 164 |
+
"top_k": -1,
|
| 165 |
+
"vllm_xargs": {"top_n_sigma": 0.05}
|
| 166 |
+
}'
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
### Int8推理
|
| 170 |
+
|
| 171 |
+
#### ModelSlim量化
|
| 172 |
+
|
| 173 |
+
openPangu-Ultra-MoE-718B模型支持使用开源量化框架[ModelSlim](https://gitcode.com/Ascend/msit/blob/br_noncom_pangu_ultra_moe_8.1.RC1_POC_20251231/msmodelslim/example/Pangu/README.md)进行量化,当前模型支持W8A8权重激活量化。
|
| 174 |
+
|
| 175 |
+
##### openPangu-Ultra-MoE-718B W8A8 动态量化
|
| 176 |
+
|
| 177 |
+
```bash
|
| 178 |
+
python3 quant_pangu_ultra_moe_w8a8.py --model_path {浮点权重路径} --save_path {W8A8量化权重路径} --dynamic
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
##### openPangu-Ultra-MoE-718B W8A8 混合量化 + MTP 量化
|
| 182 |
+
|
| 183 |
+
生成openPangu-Ultra-MoE-718B模型W8A8量化权重(含MTP)
|
| 184 |
+
```bash
|
| 185 |
+
python3 quant_pangu_ultra_moe_w8a8.py --model_path {浮点权重路径} --save_path {W8A8量化权重路径} --dynamic --quant_mtp mix
|
| 186 |
+
```
|
| 187 |
+
|
| 188 |
+
相较于BF16模型,int8量化模型的config.json增加以下字段:
|
| 189 |
+
```
|
| 190 |
+
"mla_quantize": "w8a8",
|
| 191 |
+
"quantize": "w8a8_dynamic",
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
如果MTP量化,增加字段:
|
| 195 |
+
```
|
| 196 |
+
"mtp_quantize": "w8a8_dynamic",
|
| 197 |
+
```
|
| 198 |
+
ModelSlim量化脚本生成量化模型后会自动追加上述字段到config.json中。
|
| 199 |
+
|
| 200 |
+
#### Int8推理
|
| 201 |
+
|
| 202 |
+
相较于BF16模型推理,int8量化模型推理仅需使用4节点(32卡),修改变量
|
| 203 |
+
```bash
|
| 204 |
+
NUM_NODES=4
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
启动命令需要修改为对应的量化权重路径,另外增加`--quantization ascend`:
|
| 208 |
+
```bash
|
| 209 |
+
LOCAL_CKPT_DIR=/root/.cache/pang_ultra_moe_w8a8
|
| 210 |
+
|
| 211 |
+
vllm serve $LOCAL_CKPT_DIR \
|
| 212 |
+
...
|
| 213 |
+
--quantization ascend
|
| 214 |
+
...
|
| 215 |
+
```
|
doc/vllm_ascend_for_openpangu_ultra_moe_718b_EN.md
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Deployment Guide of the openPangu-Ultra-MoE-718B Based on [vllm-ascend](https://github.com/vllm-project/vllm-ascend)
|
| 2 |
+
|
| 3 |
+
### Deployment Environment Description
|
| 4 |
+
|
| 5 |
+
The Atlas 800T A2 (64 GB) supports the deployment of the openPangu-Ultra-MoE-718B (bf16) with 64 cards and the deployment of the openPangu-Ultra-MoE-718B (int8) with 32 cards. The vllm-ascend community image v0.9.1-dev is used and needs to be pulled on multiple nodes.
|
| 6 |
+
```bash
|
| 7 |
+
docker pull quay.io/ascend/vllm-ascend:v0.9.1-dev
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
* Network Detection
|
| 11 |
+
Run the following commands on each node: All the results must be success and the status must be UP.
|
| 12 |
+
```bash
|
| 13 |
+
# Check the remote switch ports
|
| 14 |
+
for i in {0..7}; do hccn_tool -i $i -lldp -g | grep Ifname; done
|
| 15 |
+
# Get the link status of the Ethernet ports (UP or DOWN)
|
| 16 |
+
for i in {0..7}; do hccn_tool -i $i -link -g ; done
|
| 17 |
+
# Check the network health status
|
| 18 |
+
for i in {0..7}; do hccn_tool -i $i -net_health -g ; done
|
| 19 |
+
# View the network detected IP configuration
|
| 20 |
+
for i in {0..7}; do hccn_tool -i $i -netdetect -g ; done
|
| 21 |
+
# View gateway configuration
|
| 22 |
+
for i in {0..7}; do hccn_tool -i $i -gateway -g ; done
|
| 23 |
+
# View NPU network configuration
|
| 24 |
+
cat /etc/hccn.conf
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
### Docker Startup and Inference Code Adaptation
|
| 28 |
+
|
| 29 |
+
Perform the following operations on all nodes.
|
| 30 |
+
|
| 31 |
+
Run the following command to start the docker:
|
| 32 |
+
```bash
|
| 33 |
+
# Update the vllm-ascend image
|
| 34 |
+
export IMAGE=quay.io/ascend/vllm-ascend:v0.9.1-dev # Use correct image id
|
| 35 |
+
export NAME=vllm-ascend # Custom docker name
|
| 36 |
+
|
| 37 |
+
# Run the container using the defined variables
|
| 38 |
+
# Note if you are running bridge network with docker, Please expose available ports for multiple nodes communication in advance
|
| 39 |
+
# To prevent device interference from other docker containers, add the argument "--privileged"
|
| 40 |
+
docker run --rm \
|
| 41 |
+
--name $NAME \
|
| 42 |
+
--network host \
|
| 43 |
+
--device /dev/davinci0 \
|
| 44 |
+
--device /dev/davinci1 \
|
| 45 |
+
--device /dev/davinci2 \
|
| 46 |
+
--device /dev/davinci3 \
|
| 47 |
+
--device /dev/davinci4 \
|
| 48 |
+
--device /dev/davinci5 \
|
| 49 |
+
--device /dev/davinci6 \
|
| 50 |
+
--device /dev/davinci7 \
|
| 51 |
+
--device /dev/davinci_manager \
|
| 52 |
+
--device /dev/devmm_svm \
|
| 53 |
+
--device /dev/hisi_hdc \
|
| 54 |
+
-v /usr/local/dcmi:/usr/local/dcmi \
|
| 55 |
+
-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
|
| 56 |
+
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
| 57 |
+
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
|
| 58 |
+
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
|
| 59 |
+
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
| 60 |
+
-v /mnt/sfs_turbo/.cache:/root/.cache \
|
| 61 |
+
-it $IMAGE bash
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
If not inside the container, enter the container as the root user:
|
| 65 |
+
```
|
| 66 |
+
docker exec -itu root $NAME /bin/bash
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
Download vllm (v0.9.2) to replace the built-in vllm code of the image.
|
| 70 |
+
```bash
|
| 71 |
+
pip install --no-deps vllm==0.9.2 pybase64==1.4.1
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
Download [vllm-ascend (v0.9.2rc1)](https://github.com/vllm-project/vllm-ascend/releases/tag/v0.9.2rc1) and replace the built-in vllm-ascend code in the image (/vllm-workspace/vllm-ascend/). For example, download [Source code (tar.gz)](https://github.com/vllm-project/vllm-ascend/archive/refs/tags/v0.9.2rc1.tar.gz) from Assets to get v0.9.2rc1.tar.gz, then extract and replace:
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
tar -zxvf vllm-ascend-0.9.2rc1.tar.gz -C /vllm-workspace/vllm-ascend/ --strip-components=1
|
| 78 |
+
export PYTHONPATH=/vllm-workspace/vllm-ascend/:${PYTHONPATH}
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
Use the Pangu model-adapted vllm-ascend code from the current repository to replace parts of the code in `/vllm-workspace/vllm-ascend/vllm_ascend/`:
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
yes | cp -r inference/vllm_ascend/* /vllm-workspace/vllm-ascend/vllm_ascend/
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
### BF16 Inference
|
| 88 |
+
|
| 89 |
+
Perform the following operations on all nodes.
|
| 90 |
+
|
| 91 |
+
Run command:
|
| 92 |
+
```bash
|
| 93 |
+
# This obtained through ifconfig
|
| 94 |
+
# nic_name is the network interface name corresponding to local_ip
|
| 95 |
+
local_ip=`hostname -I | cut -d' ' -f1`
|
| 96 |
+
nic_name=$(ifconfig | grep -B 1 "$local_ip" | head -n 1 | awk '{print $1}' | sed 's/://')
|
| 97 |
+
export HCCL_IF_IP=$local_ip
|
| 98 |
+
export GLOO_SOCKET_IFNAME=$nic_name
|
| 99 |
+
export TP_SOCKET_IFNAME=$nic_name
|
| 100 |
+
export HCCL_SOCKET_IFNAME=$nic_name
|
| 101 |
+
export OMP_PROC_BIND=false
|
| 102 |
+
export OMP_NUM_THREADS=100
|
| 103 |
+
export VLLM_USE_V1=1
|
| 104 |
+
export HCCL_BUFFSIZE=1024
|
| 105 |
+
export VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP=1
|
| 106 |
+
export VLLM_ASCEND_ENABLE_TOP_N_SIGMA=1 # enable top-n-sigma sampling
|
| 107 |
+
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
|
| 108 |
+
|
| 109 |
+
MASTER_NODE_IP=xxx.xxx.xxx.xxx # master/head node ip
|
| 110 |
+
NODE_RANK=xxx # current node rank (0~7)
|
| 111 |
+
NUM_NODES=8 # number of nodes
|
| 112 |
+
NUM_NPUS_LOCAL=8 # number of NPUs per node
|
| 113 |
+
DATA_PARALLEL_SIZE_LOCAL=4 # DP size per node, can be set to 1, 2, or 4
|
| 114 |
+
LOCAL_CKPT_DIR=/root/.cache/pangu_ultra_moe # The pangu_ultra_moe bf16 weight
|
| 115 |
+
# Specifying HOST=127.0.0.1 (localhost) means the server can only be accessed from the master device.
|
| 116 |
+
# Specifying HOST=0.0.0.0 allows the vLLM server to be accessed from other devices on the same network or even from the internet, provided proper network configuration (e.g., firewall rules, port forwarding) is in place.
|
| 117 |
+
HOST=xxx.xxx.xxx.xxx
|
| 118 |
+
|
| 119 |
+
if [[ $NODE_RANK -ne 0 ]]; then
|
| 120 |
+
headless="--headless"
|
| 121 |
+
else
|
| 122 |
+
headless=""
|
| 123 |
+
fi
|
| 124 |
+
|
| 125 |
+
vllm serve $LOCAL_CKPT_DIR \
|
| 126 |
+
--host $HOST \
|
| 127 |
+
--port 8004 \
|
| 128 |
+
--data-parallel-size $((NUM_NODES*DATA_PARALLEL_SIZE_LOCAL)) \
|
| 129 |
+
--data-parallel-size-local $DATA_PARALLEL_SIZE_LOCAL \
|
| 130 |
+
--data-parallel-start-rank $((DATA_PARALLEL_SIZE_LOCAL*NODE_RANK)) \
|
| 131 |
+
--data-parallel-address $MASTER_NODE_IP \
|
| 132 |
+
--data-parallel-rpc-port 13389 \
|
| 133 |
+
--tensor-parallel-size $((NUM_NPUS_LOCAL/DATA_PARALLEL_SIZE_LOCAL)) \
|
| 134 |
+
--seed 1024 \
|
| 135 |
+
--served-model-name pangu_ultra_moe \
|
| 136 |
+
--enable-expert-parallel \
|
| 137 |
+
--max-num-seqs 8 \
|
| 138 |
+
--max-model-len 32768 \
|
| 139 |
+
--max-num-batched-tokens 4096 \
|
| 140 |
+
--trust-remote-code \
|
| 141 |
+
--no-enable-prefix-caching \
|
| 142 |
+
--gpu-memory-utilization 0.9 \
|
| 143 |
+
${headless} \
|
| 144 |
+
--additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
### Test Request
|
| 148 |
+
|
| 149 |
+
After server launched, send test request from master node or other nodes:
|
| 150 |
+
|
| 151 |
+
```bash
|
| 152 |
+
curl http://${MASTER_NODE_IP}:8004/v1/chat/completions \
|
| 153 |
+
-H "Content-Type: application/json" \
|
| 154 |
+
-d '{
|
| 155 |
+
"model": "pangu_ultra_moe",
|
| 156 |
+
"messages": [
|
| 157 |
+
{
|
| 158 |
+
"role": "user",
|
| 159 |
+
"content": "Who are you?"
|
| 160 |
+
}
|
| 161 |
+
],
|
| 162 |
+
"max_tokens": 512,
|
| 163 |
+
"temperature": 0.7,
|
| 164 |
+
"top_p": 1.0,
|
| 165 |
+
"top_k": -1,
|
| 166 |
+
"vllm_xargs": {"top_n_sigma": 0.05}
|
| 167 |
+
}'
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
### Int8 Inference
|
| 171 |
+
|
| 172 |
+
#### ModelSlim Quantization
|
| 173 |
+
|
| 174 |
+
The openPangu-Ultra-MoE-718B model supports quantization using the open source quantization framework [ModelSlim](https://gitcode.com/Ascend/msit/blob/br_noncom_pangu_ultra_moe_8.1.RC1_POC_20251231/msmodelslim/example/Pangu/README.md). The current model supports W8A8 weight activation quantization.
|
| 175 |
+
|
| 176 |
+
##### openPangu-Ultra-MoE-718B W8A8 Dynamic quantization
|
| 177 |
+
|
| 178 |
+
```bash
|
| 179 |
+
Python3 quant_pangu_ultra_moe_w8a8.py --model_path {bf16 weight path} --save_path {W8A8 weight path} --dynamic
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
##### openPangu-Ultra-MoE-718B W8A8 Hybrid quantization + MTP quantization
|
| 183 |
+
|
| 184 |
+
Generate the openPangu-Ultra-MoE-718B model W8A8 quantization weight (including MTP)
|
| 185 |
+
```bash
|
| 186 |
+
python3 quant_pangu_ultra_moe_w8a8.py --model_path {bf16 weight path} --save_path {W8A8 weight path} --dynamic --quant_mtp mix
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
Compared with the BF16 model, the following fields are added to the config.json file of the int8 quantization model:
|
| 190 |
+
```
|
| 191 |
+
"mla_quantize": "w8a8",
|
| 192 |
+
"quantize": "w8a8_dynamic",
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
If the MTP is included, the following fields are added:
|
| 196 |
+
```
|
| 197 |
+
"mtp_quantize": "w8a8_dynamic",
|
| 198 |
+
```
|
| 199 |
+
After the ModelSlim quantization script generates a quantization model, the preceding fields are automatically added to the config.json file.
|
| 200 |
+
|
| 201 |
+
#### Int8 Inference
|
| 202 |
+
|
| 203 |
+
Compared with the BF16 model inference, the int8 quantization model inference uses only four nodes (32 cards). Variables are modified.
|
| 204 |
+
```bash
|
| 205 |
+
NUM_NODES=4
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
The startup command needs to be changed to the corresponding quantization weight path, and add `--quantization ascend`:
|
| 209 |
+
```bash
|
| 210 |
+
LOCAL_CKPT_DIR=/root/.cache/pangu_ultra_moe_w8a8
|
| 211 |
+
|
| 212 |
+
vllm serve $LOCAL_CKPT_DIR \
|
| 213 |
+
...
|
| 214 |
+
--quantization ascend
|
| 215 |
+
...
|
| 216 |
+
```
|
generation_config.json
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 0,
|
| 4 |
+
"eos_token_id": 45892,
|
| 5 |
+
"do_sample": true,
|
| 6 |
+
"temperature": 0.7,
|
| 7 |
+
"top_p": 1.0,
|
| 8 |
+
"top_n_sigma": 0.05,
|
| 9 |
+
"top_k": -1,
|
| 10 |
+
"transformers_version": "4.48.2"
|
| 11 |
+
}
|
inference/generate.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import yaml
|
| 9 |
+
from model import PanguUltraMoEForCausalLM
|
| 10 |
+
from runner import ModelRunner
|
| 11 |
+
|
| 12 |
+
root_logger = logging.getLogger()
|
| 13 |
+
root_logger.handlers.clear()
|
| 14 |
+
logging.basicConfig(
|
| 15 |
+
format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
|
| 16 |
+
level=logging.INFO,
|
| 17 |
+
)
|
| 18 |
+
torch.manual_seed(42)
|
| 19 |
+
torch.npu.manual_seed_all(42)
|
| 20 |
+
|
| 21 |
+
"""
|
| 22 |
+
NOTE:
|
| 23 |
+
For enhancing model safety, we recommend the following system prompt.
|
| 24 |
+
It is suggested to be removed for other normal use cases and model evaluation.
|
| 25 |
+
"""
|
| 26 |
+
safe_word = "你必须严格遵守法律法规和社会道德规范。" \
|
| 27 |
+
"生成任何内容时,都应避免涉及暴力、色情、恐怖主义、种族歧视、性别歧视等不当内容。" \
|
| 28 |
+
"一旦检测到输入或输出有此类倾向,应拒绝回答并发出警告。例如,如果输入内容包含暴力威胁或色情描述," \
|
| 29 |
+
"应返回错误信息:“您的输入包含不当内容,无法处理。”"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# basic token generator
|
| 33 |
+
def generate_default_prompt(bs):
|
| 34 |
+
# prompt batch size define actual model forward batch size
|
| 35 |
+
fast_thinking_template = "[unused9]系统:%s[unused10][unused9]用户:{} /no_think[unused10][unused9]助手:" % (safe_word,)
|
| 36 |
+
slow_thinking_template = "[unused9]系统:%s[unused10][unused9]用户:{}[unused10][unused9]助手:" % (safe_word,)
|
| 37 |
+
preset_prompts = [slow_thinking_template.format(args.prompt)]
|
| 38 |
+
preset_prompts = preset_prompts * (bs // len(preset_prompts) + 1)
|
| 39 |
+
preset_prompts = preset_prompts[:bs]
|
| 40 |
+
logging.info(f"prompt batch size: {bs}")
|
| 41 |
+
return preset_prompts
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def generate_chat_prompt(bs):
|
| 45 |
+
preset_prompts = [
|
| 46 |
+
{"role": "system", "content": safe_word},
|
| 47 |
+
{"role": "user", "content": args.prompt}
|
| 48 |
+
]
|
| 49 |
+
preset_prompts = [preset_prompts] * (bs // len(preset_prompts) + 1)
|
| 50 |
+
preset_prompts = preset_prompts[:bs]
|
| 51 |
+
logging.info(f"chat prompt batch size: {bs}")
|
| 52 |
+
return preset_prompts
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def generate_prompt(bs, tokenizer_mode):
|
| 56 |
+
if tokenizer_mode == "default":
|
| 57 |
+
return generate_default_prompt(bs)
|
| 58 |
+
else:
|
| 59 |
+
return generate_chat_prompt(bs)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def parse_args():
|
| 63 |
+
parser = argparse.ArgumentParser(description="llm run parameters")
|
| 64 |
+
parser.add_argument("--yaml_file_path", type=str, help="inference configurations")
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
"--local_rank",
|
| 67 |
+
type=int,
|
| 68 |
+
default=0,
|
| 69 |
+
help="Local rank id for torch distributed launch",
|
| 70 |
+
)
|
| 71 |
+
parser.add_argument("--prompt", type=str, default="3*7=?", help="user prompts")
|
| 72 |
+
parser_args = parser.parse_args()
|
| 73 |
+
return parser_args
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def main(runner_config):
|
| 77 |
+
bs = runner_config.get("data_config").get("batch_size", 1)
|
| 78 |
+
tokenizer_mode = runner_config.get("model_config").get("tokenizer_mode", "default")
|
| 79 |
+
preset_prompts = generate_prompt(bs, tokenizer_mode)
|
| 80 |
+
logging.info(f"input prompts: {preset_prompts}")
|
| 81 |
+
model_runner = ModelRunner(runner_config)
|
| 82 |
+
torch.npu.set_compile_mode(jit_compile=False)
|
| 83 |
+
model_runner.init_model(PanguUltraMoEForCausalLM)
|
| 84 |
+
# warmup
|
| 85 |
+
model_runner.model_generate(preset_prompts, warm_up=True)
|
| 86 |
+
# generate
|
| 87 |
+
model_runner.model_generate(preset_prompts)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def read_yaml(yaml_file_path):
|
| 91 |
+
try:
|
| 92 |
+
with open(yaml_file_path, "r", encoding="utf-8") as file:
|
| 93 |
+
data = yaml.safe_load(file)
|
| 94 |
+
except FileNotFoundError:
|
| 95 |
+
logging.error(f"No such yaml file: {yaml_file_path}")
|
| 96 |
+
except yaml.YAMLERROR as e:
|
| 97 |
+
logging.error(f"Load yaml file failed: {e}")
|
| 98 |
+
return data
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
args = parse_args()
|
| 103 |
+
yaml_file_path = args.yaml_file_path
|
| 104 |
+
runner_config = read_yaml(yaml_file_path)
|
| 105 |
+
main(runner_config)
|
| 106 |
+
logging.info("model run success")
|
inference/generate.sh
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
#!/bin/bash
|
| 5 |
+
|
| 6 |
+
# use example:
|
| 7 |
+
# bash generate.sh ${NNODES} ${NODE_RANK} ${NPROC_PER_NODE} ${MASTER_ADDR} ${PROMPTS}
|
| 8 |
+
|
| 9 |
+
# input args
|
| 10 |
+
export NNODES=$1
|
| 11 |
+
export NODE_RANK=$2
|
| 12 |
+
export NPROC_PER_NODE=$3
|
| 13 |
+
export MASTER_ADDR=$4 # master node IP
|
| 14 |
+
export prompt=$5
|
| 15 |
+
export MASTER_PORT=6038 # master node port
|
| 16 |
+
export WORLD_SIZE=32
|
| 17 |
+
export YAML=runner_config/tp32.yaml
|
| 18 |
+
export RANK_OFFSET=`expr $NODE_RANK \* ${NPROC_PER_NODE}`
|
| 19 |
+
|
| 20 |
+
# setup env
|
| 21 |
+
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
|
| 22 |
+
export HCCL_SOCKET_IFNAME=enp # network card prefix. specify it according to the actual situation e.g enp/eth
|
| 23 |
+
export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'` # get current node IP
|
| 24 |
+
export HCCL_IF_BASE_PORT=23456
|
| 25 |
+
export HCCL_OP_EXPANSION_MODE=AIV
|
| 26 |
+
export HCCL_CONNECT_TIMEOUT=1200
|
| 27 |
+
export HCCL_EXEC_TIMEOUT=1200
|
| 28 |
+
if [[ -d "/usr/local/Ascend/ascend-toolkit/latest" ]]; then
|
| 29 |
+
export ASCEND_HOME_PATH=/usr/local/Ascend/ascend-toolkit/latest
|
| 30 |
+
else
|
| 31 |
+
export ASCEND_HOME_PATH=/usr/local/Ascend/latest
|
| 32 |
+
fi
|
| 33 |
+
export PYTHONPATH=${PYTHONPATH}:${ASCEND_HOME_PATH}/python/site-packages/
|
| 34 |
+
|
| 35 |
+
# set result path
|
| 36 |
+
DATE=`date +%Y%m%d`
|
| 37 |
+
export MODEL_NAME="pangu_ultra_moe"
|
| 38 |
+
NAME=${MODEL_NAME}_${WORLD_SIZE}p
|
| 39 |
+
export TASK_QUEUE_ENABLE=2 # eager mode:opt host perf
|
| 40 |
+
export RES_PATH="res/${DATE}/${NAME}"
|
| 41 |
+
WORK_DIR=`pwd`
|
| 42 |
+
DUMP_PRECISION_PATH=${WORK_DIR}'/'${RES_PATH}'/dump_data'
|
| 43 |
+
mkdir -p ${WORK_DIR}'/'${RES_PATH}
|
| 44 |
+
mkdir -p ${DUMP_PRECISION_PATH}
|
| 45 |
+
|
| 46 |
+
# launch multi proc
|
| 47 |
+
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
|
| 48 |
+
avg_core_per_rank=`expr $cores \/ $NPROC_PER_NODE`
|
| 49 |
+
core_gap=`expr $avg_core_per_rank \- 1`
|
| 50 |
+
for((i=0; i<${NPROC_PER_NODE}; i++))
|
| 51 |
+
do
|
| 52 |
+
echo $i
|
| 53 |
+
start=`expr $i \* $avg_core_per_rank`
|
| 54 |
+
end=`expr $start \+ $core_gap`
|
| 55 |
+
cmdopt=$start"-"$end
|
| 56 |
+
export LOCAL_RANK=$i
|
| 57 |
+
export RANK=$(expr $i + $RANK_OFFSET)
|
| 58 |
+
export RANK_ID=$RANK
|
| 59 |
+
if [ $i -eq 0 ];then
|
| 60 |
+
taskset -c $cmdopt python3 generate.py \
|
| 61 |
+
--prompt "$prompt" \
|
| 62 |
+
--yaml_file_path=${YAML} 2>&1 | tee ${WORK_DIR}/${RES_PATH}/log_${LOCAL_RANK}.log &
|
| 63 |
+
else
|
| 64 |
+
taskset -c $cmdopt python3 generate.py \
|
| 65 |
+
--prompt "$prompt" \
|
| 66 |
+
--yaml_file_path=${YAML} &> ${WORK_DIR}/${RES_PATH}/log_${LOCAL_RANK}.log &
|
| 67 |
+
fi
|
| 68 |
+
done
|
| 69 |
+
|
| 70 |
+
wait
|
inference/model.py
ADDED
|
@@ -0,0 +1,918 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
import warnings
|
| 20 |
+
from typing import Dict, List, Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.distributed as dist
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
import torch.utils.checkpoint
|
| 26 |
+
import torch_npu
|
| 27 |
+
from torch import nn
|
| 28 |
+
from torch.distributed.distributed_c10d import _world
|
| 29 |
+
from transformers.activations import ACT2FN
|
| 30 |
+
from transformers.cache_utils import Cache
|
| 31 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 32 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 33 |
+
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
|
| 34 |
+
from transformers.utils.import_utils import is_torch_fx_available
|
| 35 |
+
|
| 36 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
| 37 |
+
from configuration_openpangu_moe import PanguUltraMoEConfig
|
| 38 |
+
|
| 39 |
+
if is_torch_fx_available():
|
| 40 |
+
if not is_torch_greater_or_equal_than_1_13:
|
| 41 |
+
import torch.fx
|
| 42 |
+
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class PanguUltraMoERMSNorm(nn.Module):
|
| 46 |
+
def __init__(self, hidden_dim, epsilon=1e-5):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.weight = nn.Parameter(torch.empty(hidden_dim))
|
| 49 |
+
self.epsilon = epsilon
|
| 50 |
+
|
| 51 |
+
def forward(self, hidden_states, *args):
|
| 52 |
+
if len(args) == 0:
|
| 53 |
+
result = torch_npu.npu_rms_norm(hidden_states, self.weight, self.epsilon)[0]
|
| 54 |
+
return result
|
| 55 |
+
elif len(args) == 1 and args[0] is None:
|
| 56 |
+
result = torch_npu.npu_rms_norm(hidden_states, self.weight, self.epsilon)[0]
|
| 57 |
+
residual = hidden_states
|
| 58 |
+
return (result, residual)
|
| 59 |
+
elif len(args) == 1:
|
| 60 |
+
residual = args[0]
|
| 61 |
+
y, _, x = torch_npu.npu_add_rms_norm(
|
| 62 |
+
residual, hidden_states, self.weight, self.epsilon
|
| 63 |
+
)
|
| 64 |
+
return (y, x)
|
| 65 |
+
else:
|
| 66 |
+
raise NotImplementedError(f"PanguUltraMoERMSNorm inner error")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class PanguUltraMoERotaryEmbedding(nn.Module):
|
| 70 |
+
def __init__(
|
| 71 |
+
self, dim, max_position_embeddings=131072, base=25600000.0, device=None
|
| 72 |
+
):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.dim = dim
|
| 75 |
+
self.max_position_embeddings = max_position_embeddings
|
| 76 |
+
self.base = base
|
| 77 |
+
self._set_cache(
|
| 78 |
+
seq_len=max_position_embeddings,
|
| 79 |
+
device=device,
|
| 80 |
+
dtype=torch.get_default_dtype(),
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
def _set_cache(self, seq_len, device, dtype):
|
| 84 |
+
self.max_seq_len_cached = seq_len
|
| 85 |
+
dim = self.dim
|
| 86 |
+
|
| 87 |
+
inv_freq = 1.0 / (
|
| 88 |
+
self.base
|
| 89 |
+
** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
|
| 90 |
+
)
|
| 91 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 92 |
+
|
| 93 |
+
t = torch.arange(seq_len, device=device, dtype=torch.float32)
|
| 94 |
+
|
| 95 |
+
freqs = torch.outer(t, inv_freq)
|
| 96 |
+
|
| 97 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 98 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 99 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 100 |
+
|
| 101 |
+
def forward(self, x, kv_len, max_seq_len=None):
|
| 102 |
+
if max_seq_len is None:
|
| 103 |
+
self._set_cache(seq_len=kv_len, device=x.device, dtype=x.dtype)
|
| 104 |
+
elif max_seq_len > self.max_seq_len_cached:
|
| 105 |
+
self._set_cache(seq_len=max_seq_len, device=x.device, dtype=x.dtype)
|
| 106 |
+
|
| 107 |
+
batch_size = x.shape[0]
|
| 108 |
+
seq_len = x.shape[1]
|
| 109 |
+
if seq_len == 1:
|
| 110 |
+
cos = (
|
| 111 |
+
torch.index_select(self.cos_cached, dim=0, index=kv_len)
|
| 112 |
+
.unsqueeze(1)
|
| 113 |
+
.unsqueeze(1)
|
| 114 |
+
)
|
| 115 |
+
sin = (
|
| 116 |
+
torch.index_select(self.sin_cached, dim=0, index=kv_len)
|
| 117 |
+
.unsqueeze(1)
|
| 118 |
+
.unsqueeze(1)
|
| 119 |
+
)
|
| 120 |
+
else:
|
| 121 |
+
cos = (
|
| 122 |
+
self.cos_cached[:seq_len]
|
| 123 |
+
.unsqueeze(0)
|
| 124 |
+
.unsqueeze(2)
|
| 125 |
+
.repeat(batch_size, 1, 1, 1)
|
| 126 |
+
)
|
| 127 |
+
sin = (
|
| 128 |
+
self.sin_cached[:seq_len]
|
| 129 |
+
.unsqueeze(0)
|
| 130 |
+
.unsqueeze(2)
|
| 131 |
+
.repeat(batch_size, 1, 1, 1)
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
cos = cos[0, :, 0, :]
|
| 135 |
+
sin = sin[0, :, 0, :]
|
| 136 |
+
return (
|
| 137 |
+
cos.to(dtype=x.dtype),
|
| 138 |
+
sin.to(dtype=x.dtype),
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def rotate_half(x):
|
| 143 |
+
"""Rotates half the hidden dims of the input."""
|
| 144 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 145 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 146 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
| 150 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
q (`torch.Tensor`): The query tensor.
|
| 154 |
+
k (`torch.Tensor`): The key tensor.
|
| 155 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 156 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 157 |
+
position_ids (`torch.Tensor`):
|
| 158 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
| 159 |
+
used to pass offsetted position ids when working with a KV-cache.
|
| 160 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 161 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 162 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 163 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 164 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 165 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 166 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 167 |
+
Returns:
|
| 168 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 169 |
+
"""
|
| 170 |
+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
| 171 |
+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
| 172 |
+
|
| 173 |
+
b, h, s, d = q.shape
|
| 174 |
+
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
| 175 |
+
|
| 176 |
+
b, h, s, d = k.shape
|
| 177 |
+
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
| 178 |
+
|
| 179 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 180 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 181 |
+
return q_embed, k_embed
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class MLP(nn.Module):
|
| 185 |
+
def __init__(self, config, runner_config, hidden_size=None, intermediate_size=None):
|
| 186 |
+
super().__init__()
|
| 187 |
+
self.runner_config = runner_config
|
| 188 |
+
self.moe_tp_size = self.runner_config.get("parallel_config").get(
|
| 189 |
+
"moe_tp_size", 1
|
| 190 |
+
)
|
| 191 |
+
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
| 192 |
+
self.intermediate_size = (
|
| 193 |
+
config.intermediate_size if intermediate_size is None else intermediate_size
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
self.intermediate_size_per_rank = self.intermediate_size // self.moe_tp_size
|
| 197 |
+
self.merge_up_gate_proj = nn.Linear(
|
| 198 |
+
self.hidden_size, self.intermediate_size_per_rank * 2, bias=False
|
| 199 |
+
)
|
| 200 |
+
self.down_proj = nn.Linear(
|
| 201 |
+
self.intermediate_size_per_rank, self.hidden_size, bias=False
|
| 202 |
+
)
|
| 203 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 204 |
+
|
| 205 |
+
def forward(self, x):
|
| 206 |
+
merged_x = self.merge_up_gate_proj(x)
|
| 207 |
+
gate_state, up_state = merged_x.chunk(2, dim=-1)
|
| 208 |
+
intermediate_hidden_states = self.act_fn(gate_state) * up_state
|
| 209 |
+
down_proj = self.down_proj(intermediate_hidden_states)
|
| 210 |
+
if self.moe_tp_size > 1:
|
| 211 |
+
dist.all_reduce(down_proj)
|
| 212 |
+
return down_proj
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class MoE(nn.Module):
|
| 216 |
+
def __init__(self, config, runner_config, hidden_size=None, intermediate_size=None):
|
| 217 |
+
super().__init__()
|
| 218 |
+
self.runner_config = runner_config
|
| 219 |
+
self.moe_tp_size = self.runner_config.get("parallel_config").get(
|
| 220 |
+
"moe_tp_size", 1
|
| 221 |
+
)
|
| 222 |
+
self.num_experts = config.num_routed_experts
|
| 223 |
+
|
| 224 |
+
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
| 225 |
+
self.intermediate_size = (
|
| 226 |
+
config.intermediate_size if intermediate_size is None else intermediate_size
|
| 227 |
+
)
|
| 228 |
+
self.intermediate_size_per_rank = self.intermediate_size // self.moe_tp_size
|
| 229 |
+
|
| 230 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 231 |
+
|
| 232 |
+
self.group_w1_w3 = nn.Parameter(
|
| 233 |
+
torch.ones(
|
| 234 |
+
self.num_experts, self.intermediate_size_per_rank * 2, self.hidden_size
|
| 235 |
+
),
|
| 236 |
+
requires_grad=False,
|
| 237 |
+
)
|
| 238 |
+
self.group_w2 = nn.Parameter(
|
| 239 |
+
torch.ones(
|
| 240 |
+
self.num_experts, self.hidden_size, self.intermediate_size_per_rank
|
| 241 |
+
),
|
| 242 |
+
requires_grad=False,
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
def forward(self, hidden_states, expert_tokens, seq_len=None):
|
| 246 |
+
mm1_mm3 = torch_npu.npu_grouped_matmul(
|
| 247 |
+
[hidden_states],
|
| 248 |
+
[torch.transpose(self.group_w1_w3, 1, 2)],
|
| 249 |
+
group_list=expert_tokens,
|
| 250 |
+
group_type=0,
|
| 251 |
+
split_item=3,
|
| 252 |
+
)[0]
|
| 253 |
+
mm1, mm3 = mm1_mm3.chunk(2, dim=-1)
|
| 254 |
+
intermediate_hidden_states = self.act_fn(mm1) * mm3
|
| 255 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 256 |
+
[intermediate_hidden_states],
|
| 257 |
+
[torch.transpose(self.group_w2, 1, 2)],
|
| 258 |
+
group_list=expert_tokens,
|
| 259 |
+
group_type=0,
|
| 260 |
+
split_item=3,
|
| 261 |
+
)[0]
|
| 262 |
+
return hidden_states
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class MoEGate(nn.Module):
|
| 266 |
+
def __init__(self, config):
|
| 267 |
+
super().__init__()
|
| 268 |
+
self.top_k = config.num_experts_per_tok
|
| 269 |
+
self.routed_scaling_factor = config.routed_scaling_factor
|
| 270 |
+
|
| 271 |
+
self.norm_topk_prob = config.norm_topk_prob
|
| 272 |
+
self.weight = nn.Parameter(
|
| 273 |
+
torch.empty((config.num_routed_experts, config.hidden_size))
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
def forward(self, hidden_states):
|
| 277 |
+
bsz, seq_len, h = hidden_states.shape
|
| 278 |
+
hidden_states = hidden_states.view(-1, h)
|
| 279 |
+
logits = F.linear(
|
| 280 |
+
hidden_states.to(torch.float32), self.weight.to(torch.float32), None
|
| 281 |
+
)
|
| 282 |
+
scores = logits.sigmoid()
|
| 283 |
+
scores_for_choice = scores.view(bsz * seq_len, -1)
|
| 284 |
+
_, topk_idx = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
|
| 285 |
+
topk_weight = scores.gather(1, topk_idx)
|
| 286 |
+
|
| 287 |
+
if self.top_k > 1 and self.norm_topk_prob:
|
| 288 |
+
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
| 289 |
+
topk_weight = topk_weight / denominator
|
| 290 |
+
topk_weight = topk_weight * self.routed_scaling_factor
|
| 291 |
+
|
| 292 |
+
return topk_idx, topk_weight
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class PanguUltraMoE(nn.Module):
|
| 296 |
+
def __init__(self, config, runner_config):
|
| 297 |
+
super().__init__()
|
| 298 |
+
self.runner_config = runner_config
|
| 299 |
+
self.hidden_dim = config.hidden_size
|
| 300 |
+
self.moe_tp_size = self.runner_config.get("parallel_config").get(
|
| 301 |
+
"moe_tp_size", 1
|
| 302 |
+
)
|
| 303 |
+
self.batch_size_decode = self.runner_config.get("data_config").get(
|
| 304 |
+
"batch_size", 1
|
| 305 |
+
)
|
| 306 |
+
self.batch_size_prefill = self.batch_size_decode
|
| 307 |
+
self.num_experts_per_tok = config.num_experts_per_tok
|
| 308 |
+
self.num_experts = config.num_routed_experts
|
| 309 |
+
self.num_shared_experts = config.num_shared_experts
|
| 310 |
+
self.top_k = config.num_experts_per_tok
|
| 311 |
+
|
| 312 |
+
self.experts_per_rank = config.num_routed_experts
|
| 313 |
+
self.experts = MoE(
|
| 314 |
+
config, self.runner_config, intermediate_size=config.moe_intermediate_size
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
self.gate = MoEGate(config)
|
| 318 |
+
if self.num_shared_experts is not None:
|
| 319 |
+
intermediate_size = config.moe_intermediate_size * self.num_shared_experts
|
| 320 |
+
self.shared_experts = MLP(
|
| 321 |
+
config, self.runner_config, intermediate_size=intermediate_size
|
| 322 |
+
)
|
| 323 |
+
self.row_idx_decode_len = self.batch_size_decode * self.top_k
|
| 324 |
+
self.row_idx_decode = (
|
| 325 |
+
torch.arange(0, self.row_idx_decode_len, dtype=torch.int32)
|
| 326 |
+
.view(self.top_k, -1)
|
| 327 |
+
.permute(1, 0)
|
| 328 |
+
.int()
|
| 329 |
+
.contiguous()
|
| 330 |
+
.npu()
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
def forward(self, hidden_states):
|
| 334 |
+
identity = hidden_states
|
| 335 |
+
topk_idx, topk_weight = self.gate(hidden_states)
|
| 336 |
+
y = self.moe_npu(hidden_states, topk_idx, topk_weight)
|
| 337 |
+
if self.num_shared_experts is not None:
|
| 338 |
+
y = y + self.shared_experts(identity)
|
| 339 |
+
return y
|
| 340 |
+
|
| 341 |
+
def moe_npu(self, x, topk_ids, topk_weight):
|
| 342 |
+
batch_size, sequence_length, h = x.shape
|
| 343 |
+
hidden_states = x.view(-1, x.shape[-1])
|
| 344 |
+
|
| 345 |
+
routing_weights = topk_weight.to(x.dtype)
|
| 346 |
+
expert_idx = topk_ids.int()
|
| 347 |
+
if sequence_length == 1:
|
| 348 |
+
row_idx = self.row_idx_decode
|
| 349 |
+
else:
|
| 350 |
+
row_idx_prefill_len = self.batch_size_prefill * sequence_length * self.top_k
|
| 351 |
+
row_idx = (
|
| 352 |
+
torch.arange(
|
| 353 |
+
0, row_idx_prefill_len, dtype=torch.int32, device=topk_weight.device
|
| 354 |
+
)
|
| 355 |
+
.view(self.top_k, -1)
|
| 356 |
+
.permute(1, 0)
|
| 357 |
+
.int()
|
| 358 |
+
.contiguous()
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
active_num = batch_size * sequence_length
|
| 362 |
+
expanded_x, expanded_row_idx, expanded_expert_idx = (
|
| 363 |
+
torch_npu.npu_moe_init_routing(
|
| 364 |
+
hidden_states,
|
| 365 |
+
row_idx=row_idx,
|
| 366 |
+
expert_idx=expert_idx,
|
| 367 |
+
active_num=active_num,
|
| 368 |
+
)
|
| 369 |
+
)
|
| 370 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 371 |
+
expanded_expert_idx, self.num_experts
|
| 372 |
+
)
|
| 373 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 374 |
+
|
| 375 |
+
hidden_states_ordered_by_experts = self.experts(
|
| 376 |
+
expanded_x, expert_tokens, seq_len=sequence_length
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 380 |
+
hidden_states_ordered_by_experts,
|
| 381 |
+
skip1=None,
|
| 382 |
+
skip2=None,
|
| 383 |
+
bias=None,
|
| 384 |
+
scales=routing_weights,
|
| 385 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 386 |
+
export_for_source_row=expert_idx,
|
| 387 |
+
)
|
| 388 |
+
if self.moe_tp_size > 1:
|
| 389 |
+
dist.all_reduce(hidden_states)
|
| 390 |
+
hidden_states = hidden_states.view(batch_size, -1, self.hidden_dim)
|
| 391 |
+
return hidden_states
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class PanguUltraMoEAttention(nn.Module):
|
| 395 |
+
def __init__(
|
| 396 |
+
self,
|
| 397 |
+
config: PanguUltraMoEConfig,
|
| 398 |
+
layer_idx: Optional[int] = None,
|
| 399 |
+
runner_config: Optional[Dict] = None,
|
| 400 |
+
):
|
| 401 |
+
super().__init__()
|
| 402 |
+
if runner_config is not None:
|
| 403 |
+
self.attn_tp_size = runner_config.get("parallel_config").get(
|
| 404 |
+
"attn_tp_size", 1
|
| 405 |
+
)
|
| 406 |
+
else:
|
| 407 |
+
self.attn_tp_size = 1
|
| 408 |
+
self.layer_idx = layer_idx
|
| 409 |
+
|
| 410 |
+
self.hidden_size = config.hidden_size
|
| 411 |
+
self.num_heads = config.num_attention_heads
|
| 412 |
+
self.num_heads_per_rank = self.num_heads // self.attn_tp_size
|
| 413 |
+
self.num_key_value_heads_per_rank = self.num_heads_per_rank
|
| 414 |
+
|
| 415 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 416 |
+
self.rope_theta = config.rope_theta
|
| 417 |
+
self.attention_q_lora_dim = config.attention_q_lora_dim
|
| 418 |
+
self.attention_qk_rope_dim = config.attention_qk_rope_dim
|
| 419 |
+
self.attention_kv_lora_dim = config.attention_kv_lora_dim
|
| 420 |
+
self.attention_v_dim = config.attention_v_dim
|
| 421 |
+
self.attention_qk_dim = config.attention_qk_dim
|
| 422 |
+
self.q_head_dim = config.attention_qk_dim + config.attention_qk_rope_dim
|
| 423 |
+
|
| 424 |
+
if self.attention_q_lora_dim is None:
|
| 425 |
+
self.q_proj = nn.Linear(
|
| 426 |
+
self.hidden_size, self.num_heads_per_rank * self.q_head_dim, bias=False
|
| 427 |
+
)
|
| 428 |
+
else:
|
| 429 |
+
self.q_a_proj = nn.Linear(
|
| 430 |
+
self.hidden_size, config.attention_q_lora_dim, bias=False
|
| 431 |
+
)
|
| 432 |
+
self.q_a_layernorm = PanguUltraMoERMSNorm(config.attention_q_lora_dim)
|
| 433 |
+
self.q_b_proj = nn.Linear(
|
| 434 |
+
config.attention_q_lora_dim,
|
| 435 |
+
self.num_heads_per_rank * self.q_head_dim,
|
| 436 |
+
bias=False,
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
self.kv_a_proj_with_mqa = nn.Linear(
|
| 440 |
+
self.hidden_size,
|
| 441 |
+
config.attention_kv_lora_dim + config.attention_qk_rope_dim,
|
| 442 |
+
bias=False,
|
| 443 |
+
)
|
| 444 |
+
self.kv_a_layernorm = PanguUltraMoERMSNorm(config.attention_kv_lora_dim)
|
| 445 |
+
|
| 446 |
+
self.kv_b_proj_w_k = nn.Parameter(
|
| 447 |
+
torch.zeros(
|
| 448 |
+
self.num_heads_per_rank,
|
| 449 |
+
self.attention_qk_dim,
|
| 450 |
+
self.attention_kv_lora_dim,
|
| 451 |
+
)
|
| 452 |
+
)
|
| 453 |
+
self.kv_b_proj_w_v = nn.Parameter(
|
| 454 |
+
torch.zeros(
|
| 455 |
+
self.num_heads_per_rank,
|
| 456 |
+
self.attention_kv_lora_dim,
|
| 457 |
+
self.attention_v_dim,
|
| 458 |
+
)
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
self.o_proj = nn.Linear(
|
| 462 |
+
self.num_heads_per_rank * self.attention_v_dim,
|
| 463 |
+
self.hidden_size,
|
| 464 |
+
bias=False,
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
self.softmax_scale = self.q_head_dim ** (-0.5)
|
| 468 |
+
|
| 469 |
+
def bmm_5d(self, x, y):
|
| 470 |
+
b, s, n, _, d = x.shape
|
| 471 |
+
x = x.view(b * s, n, d).transpose(0, 1)
|
| 472 |
+
output = torch.matmul(x, y)
|
| 473 |
+
output = output.transpose(1, 0).view(b, s, n, -1)
|
| 474 |
+
return output
|
| 475 |
+
|
| 476 |
+
def prepare_qkv(
|
| 477 |
+
self,
|
| 478 |
+
hidden_states: torch.Tensor,
|
| 479 |
+
cos_sin: torch.Tensor = None,
|
| 480 |
+
kv_len: torch.IntTensor = None,
|
| 481 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 482 |
+
past_key_value: Optional[Cache] = None,
|
| 483 |
+
**kwargs,
|
| 484 |
+
):
|
| 485 |
+
bsz, q_len, _ = hidden_states.size()
|
| 486 |
+
|
| 487 |
+
if self.attention_q_lora_dim is None:
|
| 488 |
+
q = self.q_proj(hidden_states)
|
| 489 |
+
else:
|
| 490 |
+
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
| 491 |
+
|
| 492 |
+
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
| 493 |
+
compressed_kv, k_pe = torch.split(
|
| 494 |
+
compressed_kv,
|
| 495 |
+
[self.attention_kv_lora_dim, self.attention_qk_rope_dim],
|
| 496 |
+
dim=-1,
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
q = q.view(bsz, q_len, self.num_heads_per_rank, self.q_head_dim)
|
| 500 |
+
q_nope, q_pe = torch.split(
|
| 501 |
+
q, [self.attention_qk_dim, self.attention_qk_rope_dim], dim=-1
|
| 502 |
+
)
|
| 503 |
+
q_pe = q_pe.transpose(1, 2)
|
| 504 |
+
q_nope = self.bmm_5d(
|
| 505 |
+
q_nope.view(bsz, q_len, self.num_heads_per_rank, 1, self.attention_qk_dim),
|
| 506 |
+
self.kv_b_proj_w_k,
|
| 507 |
+
)
|
| 508 |
+
q_nope = q_nope.view(
|
| 509 |
+
bsz, q_len, self.num_heads_per_rank, self.attention_kv_lora_dim
|
| 510 |
+
)
|
| 511 |
+
q_nope = q_nope.transpose(1, 2)
|
| 512 |
+
|
| 513 |
+
k_pe = k_pe.view(bsz, q_len, 1, self.attention_qk_rope_dim).transpose(1, 2)
|
| 514 |
+
k_nope = (
|
| 515 |
+
self.kv_a_layernorm(compressed_kv)
|
| 516 |
+
.view(bsz, -1, 1, self.attention_kv_lora_dim)
|
| 517 |
+
.transpose(1, 2)
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
cos, sin = cos_sin
|
| 521 |
+
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
| 522 |
+
|
| 523 |
+
query_states = torch.cat([q_nope, q_pe], dim=-1)
|
| 524 |
+
key_states = torch.cat([k_nope, k_pe], dim=-1)
|
| 525 |
+
|
| 526 |
+
kv_seq_len = k_nope.shape[-2]
|
| 527 |
+
if past_key_value is not None:
|
| 528 |
+
past_key_states = past_key_value[self.layer_idx][0]
|
| 529 |
+
torch_npu.scatter_update_(past_key_states, kv_len, key_states, -2)
|
| 530 |
+
if q_len == 1:
|
| 531 |
+
key_states = past_key_states
|
| 532 |
+
kv_seq_len = past_key_value[0][0].size()[-2]
|
| 533 |
+
value_states = key_states
|
| 534 |
+
return query_states, key_states, value_states, kv_seq_len
|
| 535 |
+
|
| 536 |
+
def apply_attention_npu(
|
| 537 |
+
self,
|
| 538 |
+
query_states,
|
| 539 |
+
key_states,
|
| 540 |
+
value_states,
|
| 541 |
+
kv_seq_len,
|
| 542 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 543 |
+
actual_seq_lengths_kv: list = None,
|
| 544 |
+
output_attentions: bool = False,
|
| 545 |
+
past_key_value: Optional[Cache] = None,
|
| 546 |
+
):
|
| 547 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
| 548 |
+
bsz, _, q_len, _ = query_states.size()
|
| 549 |
+
attn_weights = (
|
| 550 |
+
torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
|
| 551 |
+
)
|
| 552 |
+
if attention_mask is not None:
|
| 553 |
+
attn_weights = attn_weights + attention_mask
|
| 554 |
+
else:
|
| 555 |
+
raise ValueError("attention mask must not be None")
|
| 556 |
+
|
| 557 |
+
attn_weights = nn.functional.softmax(
|
| 558 |
+
attn_weights, dim=-1, dtype=torch.float32
|
| 559 |
+
).to(query_states.dtype)
|
| 560 |
+
value_states = value_states[..., : self.attention_kv_lora_dim]
|
| 561 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 562 |
+
|
| 563 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 564 |
+
attn_output = self.bmm_5d(attn_output.unsqueeze(3), self.kv_b_proj_w_v)
|
| 565 |
+
attn_output = self.o_proj(attn_output.reshape(bsz, q_len, -1))
|
| 566 |
+
if self.attn_tp_size > 1:
|
| 567 |
+
dist.all_reduce(attn_output)
|
| 568 |
+
return attn_output
|
| 569 |
+
|
| 570 |
+
def forward(
|
| 571 |
+
self,
|
| 572 |
+
hidden_states: torch.Tensor,
|
| 573 |
+
kv_len: torch.IntTensor = None,
|
| 574 |
+
actual_seq_lengths_kv: list = None,
|
| 575 |
+
cos_sin: torch.Tensor = None,
|
| 576 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 577 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 578 |
+
past_key_value: Optional[Cache] = None,
|
| 579 |
+
output_attentions: bool = False,
|
| 580 |
+
use_cache: bool = False,
|
| 581 |
+
**kwargs,
|
| 582 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 583 |
+
if "padding_mask" in kwargs:
|
| 584 |
+
warnings.warn(
|
| 585 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
| 586 |
+
)
|
| 587 |
+
query_states, key_states, value_states, kv_seq_len = self.prepare_qkv(
|
| 588 |
+
hidden_states=hidden_states,
|
| 589 |
+
cos_sin=cos_sin,
|
| 590 |
+
kv_len=kv_len,
|
| 591 |
+
position_ids=position_ids,
|
| 592 |
+
past_key_value=past_key_value,
|
| 593 |
+
)
|
| 594 |
+
output = self.apply_attention_npu(
|
| 595 |
+
query_states=query_states,
|
| 596 |
+
key_states=key_states,
|
| 597 |
+
value_states=value_states,
|
| 598 |
+
kv_seq_len=kv_seq_len,
|
| 599 |
+
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
| 600 |
+
attention_mask=attention_mask,
|
| 601 |
+
output_attentions=output_attentions,
|
| 602 |
+
past_key_value=past_key_value,
|
| 603 |
+
)
|
| 604 |
+
return output
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
class PanguUltraMoEDecoderLayer(nn.Module):
|
| 608 |
+
def __init__(
|
| 609 |
+
self, config: PanguUltraMoEConfig, runner_config: Dict, layer_idx: int
|
| 610 |
+
):
|
| 611 |
+
super().__init__()
|
| 612 |
+
self.runner_config = runner_config
|
| 613 |
+
self.hidden_size = config.hidden_size
|
| 614 |
+
|
| 615 |
+
self.self_attn = PanguUltraMoEAttention(
|
| 616 |
+
config=config, runner_config=self.runner_config, layer_idx=layer_idx
|
| 617 |
+
)
|
| 618 |
+
|
| 619 |
+
self.mlp = (
|
| 620 |
+
PanguUltraMoE(config, self.runner_config)
|
| 621 |
+
if (
|
| 622 |
+
config.num_routed_experts is not None
|
| 623 |
+
and layer_idx >= config.num_dense_layers
|
| 624 |
+
)
|
| 625 |
+
else MLP(config, self.runner_config)
|
| 626 |
+
)
|
| 627 |
+
self.input_layernorm = PanguUltraMoERMSNorm(
|
| 628 |
+
config.hidden_size, epsilon=config.rms_norm_eps
|
| 629 |
+
)
|
| 630 |
+
self.post_attention_layernorm = PanguUltraMoERMSNorm(
|
| 631 |
+
config.hidden_size, epsilon=config.rms_norm_eps
|
| 632 |
+
)
|
| 633 |
+
if getattr(config, "sandwich_norm", False):
|
| 634 |
+
self.sandwich_norm = True
|
| 635 |
+
self.pre_mlp_layernorm = PanguUltraMoERMSNorm(
|
| 636 |
+
config.hidden_size, epsilon=config.rms_norm_eps
|
| 637 |
+
)
|
| 638 |
+
self.post_mlp_layernorm = PanguUltraMoERMSNorm(
|
| 639 |
+
config.hidden_size, epsilon=config.rms_norm_eps
|
| 640 |
+
)
|
| 641 |
+
else:
|
| 642 |
+
self.sandwich_norm = False
|
| 643 |
+
|
| 644 |
+
def forward(
|
| 645 |
+
self,
|
| 646 |
+
hidden_states: torch.Tensor,
|
| 647 |
+
kv_len: torch.IntTensor,
|
| 648 |
+
actual_seq_lengths_kv: list,
|
| 649 |
+
cos_sin: torch.Tensor,
|
| 650 |
+
past_residual: Optional[torch.Tensor] = None,
|
| 651 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 652 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 653 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 654 |
+
**kwargs,
|
| 655 |
+
) -> Tuple[torch.FloatTensor]:
|
| 656 |
+
hidden_states, residual = self.input_layernorm(hidden_states, past_residual)
|
| 657 |
+
|
| 658 |
+
# Self Attention
|
| 659 |
+
hidden_states = self.self_attn(
|
| 660 |
+
hidden_states=hidden_states,
|
| 661 |
+
kv_len=kv_len,
|
| 662 |
+
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
| 663 |
+
cos_sin=cos_sin,
|
| 664 |
+
attention_mask=attention_mask,
|
| 665 |
+
position_ids=position_ids,
|
| 666 |
+
past_key_value=past_key_value,
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
+
if self.sandwich_norm:
|
| 670 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 671 |
+
hidden_states, residual = self.pre_mlp_layernorm(hidden_states, residual)
|
| 672 |
+
else:
|
| 673 |
+
hidden_states, residual = self.post_attention_layernorm(
|
| 674 |
+
hidden_states, residual
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
hidden_states = self.mlp(hidden_states)
|
| 678 |
+
|
| 679 |
+
if self.sandwich_norm:
|
| 680 |
+
hidden_states = self.post_mlp_layernorm(hidden_states)
|
| 681 |
+
|
| 682 |
+
outputs = (residual, hidden_states)
|
| 683 |
+
return outputs
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
class PanguUltraMoEPreTrainedModel(PreTrainedModel):
|
| 687 |
+
config_class = PanguUltraMoEConfig
|
| 688 |
+
base_model_prefix = "model"
|
| 689 |
+
supports_gradient_checkpointing = True
|
| 690 |
+
_no_split_modules = ["PanguUltraMoEDecoderLayer"]
|
| 691 |
+
_skip_keys_device_placement = "past_key_values"
|
| 692 |
+
_supports_cache_class = True
|
| 693 |
+
|
| 694 |
+
def _init_weights(self, module):
|
| 695 |
+
pass
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
class PanguUltraMoEModel(PanguUltraMoEPreTrainedModel):
|
| 699 |
+
def __init__(self, config: PanguUltraMoEConfig, runner_config: Dict):
|
| 700 |
+
super().__init__(config)
|
| 701 |
+
self.config = config
|
| 702 |
+
self.runner_config = runner_config
|
| 703 |
+
self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
| 704 |
+
self.rank_offset = int(os.getenv("RANK_OFFSET", "0"))
|
| 705 |
+
self.global_rank = self.local_rank + self.rank_offset
|
| 706 |
+
self.embed_tp_size = self.runner_config.get("parallel_config").get(
|
| 707 |
+
"embed_tp_size", 1
|
| 708 |
+
)
|
| 709 |
+
self.padding_idx = config.pad_token_id
|
| 710 |
+
self.vocab_size = config.vocab_size
|
| 711 |
+
self.vocab_size_per_rank = self.vocab_size // self.embed_tp_size
|
| 712 |
+
|
| 713 |
+
self.embed_tokens = nn.Embedding(
|
| 714 |
+
self.vocab_size_per_rank, config.hidden_size, self.padding_idx
|
| 715 |
+
)
|
| 716 |
+
self.layers = nn.ModuleList(
|
| 717 |
+
[
|
| 718 |
+
PanguUltraMoEDecoderLayer(config, self.runner_config, layer_idx)
|
| 719 |
+
for layer_idx in range(config.num_hidden_layers)
|
| 720 |
+
]
|
| 721 |
+
)
|
| 722 |
+
self.norm = PanguUltraMoERMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
|
| 723 |
+
self.gradient_checkpointing = False
|
| 724 |
+
# Initialize weights and apply final processing
|
| 725 |
+
self.post_init()
|
| 726 |
+
self.rotary_emb = PanguUltraMoERotaryEmbedding(
|
| 727 |
+
self.config.attention_qk_rope_dim,
|
| 728 |
+
max_position_embeddings=self.config.max_position_embeddings,
|
| 729 |
+
base=self.config.rope_theta,
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
def forward(
|
| 733 |
+
self,
|
| 734 |
+
input_ids: torch.LongTensor,
|
| 735 |
+
kv_len: torch.IntTensor = None,
|
| 736 |
+
actual_seq_lengths_kv: list = None,
|
| 737 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 738 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 739 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 740 |
+
):
|
| 741 |
+
|
| 742 |
+
batch_size, seq_length = input_ids.shape
|
| 743 |
+
past_key_values_length = past_key_values[0][0].size()[-2]
|
| 744 |
+
|
| 745 |
+
if position_ids is None:
|
| 746 |
+
device = input_ids.device
|
| 747 |
+
position_ids = torch.arange(
|
| 748 |
+
past_key_values_length,
|
| 749 |
+
seq_length + past_key_values_length,
|
| 750 |
+
dtype=torch.long,
|
| 751 |
+
device=device,
|
| 752 |
+
)
|
| 753 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
| 754 |
+
else:
|
| 755 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
| 756 |
+
|
| 757 |
+
if self.embed_tp_size > 1:
|
| 758 |
+
new_input_ids = input_ids - self.global_rank * self.vocab_size_per_rank
|
| 759 |
+
mask = (new_input_ids >= 0) & (
|
| 760 |
+
new_input_ids < self.vocab_size_per_rank
|
| 761 |
+
) # (bs, qlen)
|
| 762 |
+
new_input_ids_per_rank = new_input_ids * mask
|
| 763 |
+
inputs_embeds = self.embed_tokens(new_input_ids_per_rank) * mask.unsqueeze(
|
| 764 |
+
-1
|
| 765 |
+
)
|
| 766 |
+
dist.all_reduce(inputs_embeds)
|
| 767 |
+
else:
|
| 768 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 769 |
+
hidden_states = inputs_embeds
|
| 770 |
+
|
| 771 |
+
cos_sin = self.rotary_emb(
|
| 772 |
+
hidden_states, kv_len, self.config.max_position_embeddings
|
| 773 |
+
)
|
| 774 |
+
residual = None
|
| 775 |
+
|
| 776 |
+
for decoder_layer in self.layers:
|
| 777 |
+
residual, hidden_states = decoder_layer(
|
| 778 |
+
hidden_states,
|
| 779 |
+
kv_len,
|
| 780 |
+
actual_seq_lengths_kv,
|
| 781 |
+
cos_sin=cos_sin,
|
| 782 |
+
past_residual=residual,
|
| 783 |
+
attention_mask=attention_mask,
|
| 784 |
+
position_ids=position_ids,
|
| 785 |
+
past_key_value=past_key_values,
|
| 786 |
+
)
|
| 787 |
+
|
| 788 |
+
hidden_states, _ = self.norm(hidden_states, residual)
|
| 789 |
+
|
| 790 |
+
return hidden_states
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
class PanguUltraMoEForCausalLM(PanguUltraMoEPreTrainedModel):
|
| 794 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 795 |
+
|
| 796 |
+
def __init__(self, config, runner_config):
|
| 797 |
+
super().__init__(config)
|
| 798 |
+
self.config = config
|
| 799 |
+
self.runner_config = runner_config
|
| 800 |
+
self.embed_tp_size = self.runner_config.get("parallel_config").get(
|
| 801 |
+
"embed_tp_size", 1
|
| 802 |
+
)
|
| 803 |
+
self.model = PanguUltraMoEModel(config, self.runner_config)
|
| 804 |
+
self.vocab_size = config.vocab_size
|
| 805 |
+
self.lm_head = nn.Linear(
|
| 806 |
+
config.hidden_size, config.vocab_size // self.embed_tp_size, bias=False
|
| 807 |
+
)
|
| 808 |
+
|
| 809 |
+
def forward(
|
| 810 |
+
self,
|
| 811 |
+
input_ids: torch.LongTensor = None,
|
| 812 |
+
kv_len: torch.IntTensor = None,
|
| 813 |
+
actual_seq_lengths_kv: list = None,
|
| 814 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 815 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 816 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 817 |
+
):
|
| 818 |
+
outputs = self.model(
|
| 819 |
+
input_ids=input_ids,
|
| 820 |
+
kv_len=kv_len,
|
| 821 |
+
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
| 822 |
+
attention_mask=attention_mask,
|
| 823 |
+
position_ids=position_ids,
|
| 824 |
+
past_key_values=past_key_values,
|
| 825 |
+
)
|
| 826 |
+
|
| 827 |
+
hidden_states = outputs
|
| 828 |
+
|
| 829 |
+
if hidden_states.size()[1] > 1:
|
| 830 |
+
gather_index, _ = torch.max(position_ids, dim=-1)
|
| 831 |
+
gather_index = (
|
| 832 |
+
gather_index.unsqueeze(1)
|
| 833 |
+
.unsqueeze(2)
|
| 834 |
+
.repeat(1, 1, hidden_states.shape[-1])
|
| 835 |
+
)
|
| 836 |
+
hidden_states = torch.gather(hidden_states, 1, gather_index)
|
| 837 |
+
|
| 838 |
+
logits = self.lm_head(hidden_states)
|
| 839 |
+
if self.embed_tp_size > 1:
|
| 840 |
+
new_logits = torch.zeros_like(logits).repeat(self.embed_tp_size, 1, 1)
|
| 841 |
+
dist.all_gather_into_tensor(new_logits, logits, group=_world._default_pg)
|
| 842 |
+
new_logits = new_logits.reshape(
|
| 843 |
+
self.embed_tp_size, logits.shape[0], logits.shape[1], -1
|
| 844 |
+
).permute(1, 2, 0, 3)
|
| 845 |
+
logits = new_logits.reshape(logits.shape[0], logits.shape[1], -1)
|
| 846 |
+
logits = logits.float()
|
| 847 |
+
|
| 848 |
+
return logits
|
| 849 |
+
|
| 850 |
+
def init_cache(self, input_ids):
|
| 851 |
+
batch_size, seq_len = input_ids.size()
|
| 852 |
+
|
| 853 |
+
cache_seq_len = self.config.max_position_embeddings
|
| 854 |
+
|
| 855 |
+
past_key_values = ()
|
| 856 |
+
cache_key_shape = (
|
| 857 |
+
batch_size,
|
| 858 |
+
1,
|
| 859 |
+
cache_seq_len,
|
| 860 |
+
self.config.attention_kv_lora_dim + self.config.attention_qk_rope_dim,
|
| 861 |
+
)
|
| 862 |
+
dtype = self.config.torch_dtype
|
| 863 |
+
|
| 864 |
+
for _ in range(self.config.num_hidden_layers):
|
| 865 |
+
key_cache = torch.zeros(
|
| 866 |
+
cache_key_shape, dtype=dtype, device=input_ids.device
|
| 867 |
+
)
|
| 868 |
+
past_key_values += ((key_cache,),)
|
| 869 |
+
|
| 870 |
+
return past_key_values
|
| 871 |
+
|
| 872 |
+
def prepare_inputs_for_generation(
|
| 873 |
+
self,
|
| 874 |
+
input_ids,
|
| 875 |
+
past_key_values=None,
|
| 876 |
+
attention_mask=None,
|
| 877 |
+
inputs_embeds=None,
|
| 878 |
+
is_prefill=None,
|
| 879 |
+
kv_len=None,
|
| 880 |
+
share_mask_tril=None,
|
| 881 |
+
**kwargs,
|
| 882 |
+
):
|
| 883 |
+
batch_size, seq_len = input_ids.size()
|
| 884 |
+
if past_key_values is None:
|
| 885 |
+
past_key_values = self.init_cache(input_ids)
|
| 886 |
+
if is_prefill:
|
| 887 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 888 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 889 |
+
attention_mask = share_mask_tril
|
| 890 |
+
kv_len = torch.zeros(
|
| 891 |
+
(position_ids.size()[0]), dtype=torch.int32, device=input_ids.device
|
| 892 |
+
)
|
| 893 |
+
actual_seq_lengths_kv = None
|
| 894 |
+
past_key_values_length = 0
|
| 895 |
+
input_mask = None
|
| 896 |
+
else:
|
| 897 |
+
attention_mask = None
|
| 898 |
+
position_ids = kv_len.unsqueeze(1)
|
| 899 |
+
actual_seq_lengths_kv = (kv_len + 1).cpu().detach().numpy().tolist()
|
| 900 |
+
past_key_values_length = self.config.max_position_embeddings - seq_len
|
| 901 |
+
input_mask = share_mask_tril
|
| 902 |
+
|
| 903 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
| 904 |
+
input_mask, (batch_size, seq_len), input_ids.float(), past_key_values_length
|
| 905 |
+
)
|
| 906 |
+
|
| 907 |
+
model_inputs = {}
|
| 908 |
+
model_inputs.update(
|
| 909 |
+
{
|
| 910 |
+
"input_ids": input_ids,
|
| 911 |
+
"position_ids": position_ids,
|
| 912 |
+
"past_key_values": past_key_values,
|
| 913 |
+
"attention_mask": attention_mask,
|
| 914 |
+
"kv_len": kv_len,
|
| 915 |
+
"actual_seq_lengths_kv": actual_seq_lengths_kv,
|
| 916 |
+
}
|
| 917 |
+
)
|
| 918 |
+
return model_inputs
|
inference/runner.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
|
| 4 |
+
import copy
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch.distributed.distributed_c10d import _world
|
| 11 |
+
from transformers import AutoTokenizer
|
| 12 |
+
|
| 13 |
+
root_logger = logging.getLogger()
|
| 14 |
+
root_logger.handlers.clear()
|
| 15 |
+
logging.basicConfig(
|
| 16 |
+
format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
|
| 17 |
+
level=logging.INFO,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
torch.manual_seed(42)
|
| 21 |
+
torch.npu.manual_seed_all(42)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_init_attn_mask(mask_length, device, valid_len=None):
|
| 25 |
+
share_mask_tril = ~torch.tril(
|
| 26 |
+
torch.ones((mask_length, mask_length), dtype=torch.bool, device=device)
|
| 27 |
+
)
|
| 28 |
+
if valid_len is not None:
|
| 29 |
+
share_mask_tril[-valid_len:, :] = torch.zeros(valid_len, mask_length)
|
| 30 |
+
return share_mask_tril
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_decode_mask(mask_length, device, position):
|
| 34 |
+
decode_mask = torch.zeros((1, mask_length), device=device)
|
| 35 |
+
decode_mask[0, :position] = 1
|
| 36 |
+
return decode_mask
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def sample(input_logits: torch.Tensor, temperature=1.0, top_p=0.0, top_k=0, top_n_sigma=-1.0, **kwargs):
|
| 40 |
+
# shape of input_logits: [batch_size, 1, vocab_size]
|
| 41 |
+
# greedy
|
| 42 |
+
if temperature <= 0.0 or top_k == 1 or top_p == 0.0 or top_n_sigma == 0.0:
|
| 43 |
+
return torch.argmax(input_logits, dim=-1)
|
| 44 |
+
|
| 45 |
+
logits = input_logits / temperature
|
| 46 |
+
|
| 47 |
+
filter_value = -3.4028e+38
|
| 48 |
+
|
| 49 |
+
# top_n_sigma truncation
|
| 50 |
+
if top_n_sigma > 0.0:
|
| 51 |
+
max_vals, _ = logits.max(dim=-1, keepdim=True)
|
| 52 |
+
std_vals = logits.std(dim=-1, keepdim=True)
|
| 53 |
+
threshold = max_vals - top_n_sigma * std_vals
|
| 54 |
+
mask = logits < threshold
|
| 55 |
+
logits = torch.where(mask, filter_value, logits)
|
| 56 |
+
|
| 57 |
+
# top_k truncation
|
| 58 |
+
if top_k > 0:
|
| 59 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 60 |
+
logits[indices_to_remove] = filter_value
|
| 61 |
+
|
| 62 |
+
# top_p truncation
|
| 63 |
+
if 0.0 < top_p < 1.0:
|
| 64 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
| 65 |
+
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
| 66 |
+
|
| 67 |
+
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
| 68 |
+
# keep at least 1 token
|
| 69 |
+
sorted_indices_to_remove[..., -1:] = 0
|
| 70 |
+
|
| 71 |
+
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
|
| 72 |
+
logits = logits.masked_fill(indices_to_remove, filter_value)
|
| 73 |
+
|
| 74 |
+
probs = logits.softmax(dim=-1)
|
| 75 |
+
outputs = torch.multinomial(probs.squeeze(1), num_samples=1)
|
| 76 |
+
return outputs
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class ModelRunner:
|
| 80 |
+
def __init__(self, runner_config):
|
| 81 |
+
self.runner_config = runner_config
|
| 82 |
+
self.model_name = runner_config.get("model_name", "default_model_name")
|
| 83 |
+
model_path = self.runner_config.get("model_path")
|
| 84 |
+
self.dtype = runner_config.get("model_config").get("dtype", torch.bfloat16)
|
| 85 |
+
self.max_position_embeddings = runner_config.get("data_config").get(
|
| 86 |
+
"max_position_embeddings", 131072
|
| 87 |
+
)
|
| 88 |
+
self.input_max_len = runner_config.get("data_config").get("input_max_len", 1024)
|
| 89 |
+
self.max_new_tokens = runner_config.get("data_config").get("max_new_tokens", 32)
|
| 90 |
+
self.batch_size = runner_config.get("data_config").get("batch_size", 16)
|
| 91 |
+
self.sampling_params = runner_config.get("sampling_config", {})
|
| 92 |
+
self.tokenizer = None
|
| 93 |
+
self.model = None
|
| 94 |
+
self.device = None
|
| 95 |
+
self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
| 96 |
+
self.rank_offset = int(os.getenv("RANK_OFFSET", "0"))
|
| 97 |
+
self.global_rank = self.local_rank + self.rank_offset
|
| 98 |
+
self.world_size = int(os.getenv("WORLD_SIZE", "1"))
|
| 99 |
+
if self.world_size == 1:
|
| 100 |
+
self.model_path = model_path
|
| 101 |
+
else:
|
| 102 |
+
self.model_path = os.path.join(model_path, f"rank_{self.global_rank}")
|
| 103 |
+
|
| 104 |
+
self.res_path = os.getenv("RES_PATH", "./")
|
| 105 |
+
self.enable_profiler = runner_config.get("model_config").get(
|
| 106 |
+
"enable_profiler", 0
|
| 107 |
+
)
|
| 108 |
+
self.use_pretrained_model = True
|
| 109 |
+
self.execute_mode = runner_config.get("exe_mode", "dynamo")
|
| 110 |
+
self.tokenizer_mode = runner_config.get("model_config").get(
|
| 111 |
+
"tokenizer_mode", "default"
|
| 112 |
+
)
|
| 113 |
+
self.init_device()
|
| 114 |
+
self.start_time = None
|
| 115 |
+
self.end_time = None
|
| 116 |
+
self.with_ckpt = runner_config.get("model_config").get("with_ckpt", 1)
|
| 117 |
+
|
| 118 |
+
@staticmethod
|
| 119 |
+
def repeat_batch(tensor, repeat_num):
|
| 120 |
+
if repeat_num == 1:
|
| 121 |
+
return tensor
|
| 122 |
+
return tensor.repeat(repeat_num, *[1] * (tensor.dim() - 1))
|
| 123 |
+
|
| 124 |
+
def init_device(self):
|
| 125 |
+
logging.info(
|
| 126 |
+
"Set execution using npu index: %s, global: %s",
|
| 127 |
+
self.local_rank,
|
| 128 |
+
self.global_rank,
|
| 129 |
+
)
|
| 130 |
+
self.device = torch.device("%s:%s" % ("npu", self.local_rank))
|
| 131 |
+
torch.npu.set_device(self.device)
|
| 132 |
+
if torch.npu.is_available() and self.world_size > 1:
|
| 133 |
+
if _world._default_pg is None:
|
| 134 |
+
torch.distributed.init_process_group(
|
| 135 |
+
backend="hccl", world_size=self.world_size, rank=self.global_rank
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def init_model(self, model, config=None):
|
| 139 |
+
if self.with_ckpt:
|
| 140 |
+
self.use_pretrained_model = True
|
| 141 |
+
config = None
|
| 142 |
+
else:
|
| 143 |
+
self.use_pretrained_model = False
|
| 144 |
+
from configuration_openpangu_moe import PanguUltraMoEConfig as config
|
| 145 |
+
logging.info(f"use_pretrained_model: {self.use_pretrained_model}")
|
| 146 |
+
|
| 147 |
+
if self.use_pretrained_model:
|
| 148 |
+
self.load_model(model)
|
| 149 |
+
else:
|
| 150 |
+
self.init_model_from_config(model, config=config)
|
| 151 |
+
self.to_device()
|
| 152 |
+
self.compile_model()
|
| 153 |
+
self.init_tokenizer()
|
| 154 |
+
|
| 155 |
+
def init_model_from_config(self, model, config):
|
| 156 |
+
if config is None:
|
| 157 |
+
raise Exception("config cannot be None")
|
| 158 |
+
config_file = f"{self.model_path}/config.json"
|
| 159 |
+
model_config = config.from_pretrained(
|
| 160 |
+
config_file,
|
| 161 |
+
torch_dtype=self.dtype,
|
| 162 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 163 |
+
)
|
| 164 |
+
self.model = model(model_config, runner_config=self.runner_config).to(
|
| 165 |
+
self.dtype
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
def load_model(self, model):
|
| 169 |
+
logging.info("Try to load pretrained model in path: %s", self.model_path)
|
| 170 |
+
self.model = model.from_pretrained(
|
| 171 |
+
self.model_path,
|
| 172 |
+
low_cpu_mem_usage=True,
|
| 173 |
+
ignore_mismatched_sizes=True,
|
| 174 |
+
torch_dtype=self.dtype,
|
| 175 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 176 |
+
runner_config=self.runner_config,
|
| 177 |
+
)
|
| 178 |
+
for name, params in self.model.named_parameters():
|
| 179 |
+
logging.info(
|
| 180 |
+
"Param of %s: %s, %s, %s",
|
| 181 |
+
self.model_name,
|
| 182 |
+
name,
|
| 183 |
+
params.size(),
|
| 184 |
+
params.dtype,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
def to_device(self):
|
| 188 |
+
self.model.to(self.device)
|
| 189 |
+
logging.info("Model weights H2D finished.")
|
| 190 |
+
|
| 191 |
+
def init_tokenizer(self):
|
| 192 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 193 |
+
self.model_path,
|
| 194 |
+
trust_remote_code=True,
|
| 195 |
+
local_files_only=True,
|
| 196 |
+
padding_side="right",
|
| 197 |
+
truncation_side="right",
|
| 198 |
+
)
|
| 199 |
+
if self.tokenizer.pad_token is None:
|
| 200 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 201 |
+
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
| 202 |
+
|
| 203 |
+
def compile_model(self):
|
| 204 |
+
logging.info("The final model structure is: \n %s", self.model)
|
| 205 |
+
if self.execute_mode == "dynamo":
|
| 206 |
+
logging.info("Try to compile model")
|
| 207 |
+
self.graph_compile()
|
| 208 |
+
|
| 209 |
+
def graph_compile(self):
|
| 210 |
+
import torchair as tng
|
| 211 |
+
import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
|
| 212 |
+
from torchair.configs.compiler_config import CompilerConfig
|
| 213 |
+
|
| 214 |
+
compiler_config = CompilerConfig()
|
| 215 |
+
compiler_config.experimental_config.frozen_parameter = True
|
| 216 |
+
compiler_config.experimental_config.tiling_schedule_optimize = True
|
| 217 |
+
npu_backend = tng.get_npu_backend(compiler_config=compiler_config)
|
| 218 |
+
self.model = torch.compile(
|
| 219 |
+
self.model, dynamic=True, fullgraph=True, backend=npu_backend
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
def mark_inputs(self, model_inputs):
|
| 223 |
+
if self.execute_mode == "dynamo":
|
| 224 |
+
input_ids = model_inputs.get("input_ids")
|
| 225 |
+
kv_len = model_inputs.get("kv_len")
|
| 226 |
+
attention_mask = model_inputs.get("attention_mask")
|
| 227 |
+
position_ids = model_inputs.get("position_ids")
|
| 228 |
+
past_key_values = model_inputs.get("past_key_values")
|
| 229 |
+
|
| 230 |
+
# prefill with dynamic sequence length, decode with static sequence length
|
| 231 |
+
torch._dynamo.mark_static(kv_len)
|
| 232 |
+
for item in past_key_values:
|
| 233 |
+
for sub_item in item:
|
| 234 |
+
torch._dynamo.mark_static(sub_item)
|
| 235 |
+
|
| 236 |
+
torch._dynamo.mark_static(input_ids)
|
| 237 |
+
if attention_mask is not None:
|
| 238 |
+
torch._dynamo.mark_static(attention_mask)
|
| 239 |
+
torch._dynamo.mark_static(position_ids)
|
| 240 |
+
|
| 241 |
+
def model_input_prepare(self, input_dict):
|
| 242 |
+
input_ids = input_dict.get("input_ids")
|
| 243 |
+
attention_mask = input_dict.get("attention_mask")
|
| 244 |
+
past_key_values = input_dict.get("past_key_values")
|
| 245 |
+
is_prefill = input_dict.get("is_prefill")
|
| 246 |
+
kv_len = input_dict.get("kv_len")
|
| 247 |
+
share_mask_tril = input_dict.get("share_mask_tril")
|
| 248 |
+
model_inputs = self.model.prepare_inputs_for_generation(
|
| 249 |
+
input_ids=input_ids,
|
| 250 |
+
attention_mask=attention_mask,
|
| 251 |
+
past_key_values=past_key_values,
|
| 252 |
+
is_prefill=is_prefill,
|
| 253 |
+
kv_len=kv_len,
|
| 254 |
+
input_lens=input_dict.get("input_lens"),
|
| 255 |
+
share_mask_tril=share_mask_tril,
|
| 256 |
+
)
|
| 257 |
+
return model_inputs
|
| 258 |
+
|
| 259 |
+
def model_inference(self, model_inputs, warm_up=False):
|
| 260 |
+
torch.npu.synchronize()
|
| 261 |
+
if warm_up:
|
| 262 |
+
self.mark_inputs(model_inputs)
|
| 263 |
+
if self.start_time is None:
|
| 264 |
+
self.start_time = time.time()
|
| 265 |
+
with torch.no_grad():
|
| 266 |
+
logits = self.model(**model_inputs)
|
| 267 |
+
torch.npu.synchronize()
|
| 268 |
+
self.end_time = time.time()
|
| 269 |
+
if torch.distributed.get_rank() != 0:
|
| 270 |
+
logging.info(
|
| 271 |
+
f"{self.model_name} inference time cost {(self.end_time - self.start_time)*1000:.2f} ms"
|
| 272 |
+
)
|
| 273 |
+
self.start_time = time.time()
|
| 274 |
+
return logits
|
| 275 |
+
|
| 276 |
+
def model_generate(self, prompts, warm_up=False, **kwargs):
|
| 277 |
+
calling_func = {
|
| 278 |
+
"default": self.tokenizer,
|
| 279 |
+
"chat": self.tokenizer.apply_chat_template,
|
| 280 |
+
}
|
| 281 |
+
kwargs = {
|
| 282 |
+
"return_tensors": "pt",
|
| 283 |
+
"truncation": True,
|
| 284 |
+
"padding": "max_length",
|
| 285 |
+
"max_length": self.input_max_len,
|
| 286 |
+
}
|
| 287 |
+
if self.tokenizer_mode == "chat":
|
| 288 |
+
chat_kwargs = {"add_generation_prompt": True, "return_dict": True}
|
| 289 |
+
kwargs.update(chat_kwargs)
|
| 290 |
+
tokenizer = calling_func.get(self.tokenizer_mode, self.tokenizer)
|
| 291 |
+
inputs = tokenizer(prompts, **kwargs).to(self.device)
|
| 292 |
+
|
| 293 |
+
# get init input_dict
|
| 294 |
+
share_mask_tril = get_init_attn_mask(
|
| 295 |
+
self.max_position_embeddings, self.device, valid_len=self.input_max_len
|
| 296 |
+
)
|
| 297 |
+
share_mask_tril = share_mask_tril[None, None, ...]
|
| 298 |
+
|
| 299 |
+
input_lens = copy.deepcopy(inputs.input_ids.size()[1])
|
| 300 |
+
logging.info("Padding max prompts lens is : %d", input_lens)
|
| 301 |
+
input_dict = {
|
| 302 |
+
"input_ids": inputs.input_ids,
|
| 303 |
+
"generate_ids": inputs.input_ids,
|
| 304 |
+
"input_lens": input_lens,
|
| 305 |
+
"kv_len": None,
|
| 306 |
+
"past_key_values": None,
|
| 307 |
+
"attention_mask": inputs.attention_mask,
|
| 308 |
+
"share_mask_tril": share_mask_tril,
|
| 309 |
+
"is_prefill": True,
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
if torch.distributed.get_rank() == 0:
|
| 313 |
+
logging.info(
|
| 314 |
+
f"inputs.input_ids {inputs.input_ids[:,:30]} eod id {self.tokenizer.eos_token_id}"
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
generate_tokens = 0
|
| 318 |
+
cnt = 0
|
| 319 |
+
all_done = [False for _ in range(input_dict["input_ids"].size(0))]
|
| 320 |
+
done_len = [-1 for _ in range(input_dict["input_ids"].size(0))]
|
| 321 |
+
while True:
|
| 322 |
+
jump_flag = self.get_jump_flag(cnt, warm_up, generate_tokens)
|
| 323 |
+
if jump_flag:
|
| 324 |
+
break
|
| 325 |
+
|
| 326 |
+
# exit until all reach eod
|
| 327 |
+
if input_dict["input_ids"].size(1) == 1:
|
| 328 |
+
for bi in range(input_dict["input_ids"].size(0)):
|
| 329 |
+
if (
|
| 330 |
+
input_dict["input_ids"][bi, 0].item()
|
| 331 |
+
== self.tokenizer.eos_token_id
|
| 332 |
+
):
|
| 333 |
+
all_done[bi] = True
|
| 334 |
+
done_len[bi] = generate_tokens
|
| 335 |
+
if all(all_done):
|
| 336 |
+
break
|
| 337 |
+
model_inputs = self.model_input_prepare(input_dict)
|
| 338 |
+
# fix decode mask
|
| 339 |
+
if model_inputs["position_ids"].shape[1] == 1:
|
| 340 |
+
model_inputs["attention_mask"].fill_(-3.4028e38)
|
| 341 |
+
for bi in range(model_inputs["position_ids"].size(0)):
|
| 342 |
+
max_l = model_inputs["position_ids"][bi].max().item()
|
| 343 |
+
model_inputs["attention_mask"][bi, :, :, : max_l + 1] = 0
|
| 344 |
+
outputs = self.model_inference(model_inputs, warm_up=warm_up)
|
| 345 |
+
self.model_output_process(model_inputs, outputs, input_dict)
|
| 346 |
+
# prof.step()
|
| 347 |
+
generate_tokens += 1
|
| 348 |
+
cnt += 1
|
| 349 |
+
|
| 350 |
+
generate_ids = input_dict["generate_ids"][:, input_lens:].clip(
|
| 351 |
+
0, self.model.config.vocab_size - 1
|
| 352 |
+
)
|
| 353 |
+
for bi in range(generate_ids.size(0)):
|
| 354 |
+
if done_len[bi] != -1:
|
| 355 |
+
generate_ids[bi, done_len[bi] :] = self.tokenizer.eos_token_id
|
| 356 |
+
res = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True)
|
| 357 |
+
|
| 358 |
+
if isinstance(res, list):
|
| 359 |
+
for answer in res:
|
| 360 |
+
logging.info("Inference decode result: \n%s", answer)
|
| 361 |
+
else:
|
| 362 |
+
logging.info("Inference decode result: \n%s", res)
|
| 363 |
+
return res
|
| 364 |
+
|
| 365 |
+
def get_jump_flag(self, cnt, warm_up, generate_tokens):
|
| 366 |
+
default_decode_dump = 2
|
| 367 |
+
# warm up only perform for 5 times(decode)
|
| 368 |
+
jump_flag_warm = warm_up and cnt >= default_decode_dump
|
| 369 |
+
# do not generate after max_token
|
| 370 |
+
jump_flag_oversize = generate_tokens >= self.max_new_tokens
|
| 371 |
+
jump_flag = jump_flag_oversize or jump_flag_warm
|
| 372 |
+
return jump_flag
|
| 373 |
+
|
| 374 |
+
def model_output_process(self, model_inputs, outputs, input_dict):
|
| 375 |
+
next_batch = self.batch_size
|
| 376 |
+
attn_tp_size = self.runner_config.get("parallel_config").get("attn_tp_size", 1)
|
| 377 |
+
if self.world_size % attn_tp_size != 0:
|
| 378 |
+
raise Exception(
|
| 379 |
+
f"world_size ({self.world_siz}) not divisible by attn_tp_size ({attn_tp_size})!"
|
| 380 |
+
)
|
| 381 |
+
attn_dp_size = self.world_size // attn_tp_size
|
| 382 |
+
input_dict["is_prefill"] = False
|
| 383 |
+
input_dict["input_lens"] = input_dict["input_lens"] + 1
|
| 384 |
+
|
| 385 |
+
kv_len = torch.max(model_inputs.get("position_ids"), axis=1)[0] + 1
|
| 386 |
+
input_dict["kv_len"] = kv_len
|
| 387 |
+
|
| 388 |
+
logits = outputs
|
| 389 |
+
past_key_values = model_inputs.get("past_key_values")
|
| 390 |
+
input_dict["past_key_values"] = past_key_values
|
| 391 |
+
|
| 392 |
+
attention_mask = None
|
| 393 |
+
|
| 394 |
+
share_mask_tril = get_decode_mask(
|
| 395 |
+
mask_length=self.max_position_embeddings,
|
| 396 |
+
device=self.device,
|
| 397 |
+
position=input_dict["input_lens"],
|
| 398 |
+
)
|
| 399 |
+
share_mask_tril = share_mask_tril[None, None, ...]
|
| 400 |
+
|
| 401 |
+
input_dict["attention_mask"] = attention_mask
|
| 402 |
+
input_dict["share_mask_tril"] = ModelRunner.repeat_batch(
|
| 403 |
+
share_mask_tril, self.batch_size
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
next_tokens = sample(logits, **self.sampling_params)
|
| 407 |
+
torch.distributed.broadcast(next_tokens, src=0)
|
| 408 |
+
input_dict["input_ids"] = next_tokens
|
| 409 |
+
input_dict["generate_ids"] = torch.cat(
|
| 410 |
+
[input_dict["generate_ids"], next_tokens], dim=-1
|
| 411 |
+
)
|
inference/runner_config/tp1.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
|
| 4 |
+
model_name: "pangu_ultra_moe"
|
| 5 |
+
model_path: "./model"
|
| 6 |
+
exe_mode: "eager" # ["dynamo", "eager"]
|
| 7 |
+
|
| 8 |
+
model_config:
|
| 9 |
+
tokenizer_mode: default # ["default", "chat"]
|
| 10 |
+
mm_quant_mode: None
|
| 11 |
+
mla_backend: absorb # [native, absorb]
|
| 12 |
+
with_ckpt: 1 # [0, 1]
|
| 13 |
+
enable_profiler: 0 # [0, 1]
|
| 14 |
+
|
| 15 |
+
data_config:
|
| 16 |
+
input_max_len: 4096
|
| 17 |
+
max_new_tokens: 28000
|
| 18 |
+
batch_size: 1
|
| 19 |
+
max_position_embeddings: 32768
|
| 20 |
+
|
| 21 |
+
parallel_config:
|
| 22 |
+
attn_tp_size: 1
|
| 23 |
+
moe_tp_size: 1
|
| 24 |
+
embed_tp_size: 1
|
| 25 |
+
|
| 26 |
+
sampling_config:
|
| 27 |
+
top_n_sigma: 0.05
|
| 28 |
+
top_p: 1.0
|
| 29 |
+
temperature: 0.7
|
| 30 |
+
top_k: -1
|
inference/runner_config/tp32.yaml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
|
| 4 |
+
model_name: "pangu_ultra_moe"
|
| 5 |
+
model_path: "./model"
|
| 6 |
+
exe_mode: "eager" # ["dynamo", "eager"]
|
| 7 |
+
|
| 8 |
+
model_config:
|
| 9 |
+
tokenizer_mode: default # ["default", "chat"]
|
| 10 |
+
mm_quant_mode: None
|
| 11 |
+
mla_backend: absorb # [native, absorb]
|
| 12 |
+
with_ckpt: 1 # [0, 1]
|
| 13 |
+
enable_profiler: 0 # [0, 1]
|
| 14 |
+
|
| 15 |
+
data_config:
|
| 16 |
+
input_max_len: 4096
|
| 17 |
+
max_new_tokens: 28000
|
| 18 |
+
batch_size: 1
|
| 19 |
+
max_position_embeddings: 32768
|
| 20 |
+
|
| 21 |
+
parallel_config:
|
| 22 |
+
attn_tp_size: 32
|
| 23 |
+
moe_tp_size: 32
|
| 24 |
+
embed_tp_size: 32
|
| 25 |
+
|
| 26 |
+
sampling_config:
|
| 27 |
+
top_n_sigma: 0.05
|
| 28 |
+
top_p: 1.0
|
| 29 |
+
temperature: 0.7
|
| 30 |
+
top_k: -1
|
inference/split_weight.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import shutil
|
| 8 |
+
from threading import Thread
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import yaml
|
| 13 |
+
from torch import nn
|
| 14 |
+
from transformers import AutoModelForCausalLM
|
| 15 |
+
|
| 16 |
+
from model import PanguUltraMoEForCausalLM
|
| 17 |
+
|
| 18 |
+
root_logger = logging.getLogger()
|
| 19 |
+
root_logger.handlers.clear()
|
| 20 |
+
logging.basicConfig(
|
| 21 |
+
format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
|
| 22 |
+
level=logging.INFO,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _to_parameter(data):
|
| 27 |
+
return nn.Parameter(data, requires_grad=False)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def split_w_dense(block, dst_model, i, local_rank):
|
| 31 |
+
up_weight_list = []
|
| 32 |
+
ffn_dim = dst_model.model.layers[i].mlp.intermediate_size_per_rank
|
| 33 |
+
gate_weight = block.mlp.gate_proj.weight[
|
| 34 |
+
local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
|
| 35 |
+
].contiguous()
|
| 36 |
+
up_weight = block.mlp.up_proj.weight[
|
| 37 |
+
local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
|
| 38 |
+
].contiguous()
|
| 39 |
+
up_weight_list.append(_to_parameter(torch.cat([gate_weight, up_weight], axis=0)))
|
| 40 |
+
|
| 41 |
+
if len(up_weight_list) == 1:
|
| 42 |
+
dst_model.model.layers[i].mlp.merge_up_gate_proj.weight = up_weight_list[0]
|
| 43 |
+
else:
|
| 44 |
+
dst_model.model.layers[i].mlp.merge_up_gate_proj.weight = _to_parameter(
|
| 45 |
+
torch.cat(up_weight_list, axis=0)
|
| 46 |
+
)
|
| 47 |
+
dst_model.model.layers[i].mlp.down_proj.weight.data = (
|
| 48 |
+
block.mlp.down_proj.weight.data[
|
| 49 |
+
:, local_rank * ffn_dim : (local_rank + 1) * ffn_dim
|
| 50 |
+
].contiguous()
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def split_w_moe(block, dst_model, i, local_rank):
|
| 55 |
+
shared_up_weight_list = []
|
| 56 |
+
ffn_dim = dst_model.model.layers[i].mlp.shared_experts.intermediate_size_per_rank
|
| 57 |
+
gate_weight = block.mlp.shared_experts.gate_proj.weight[
|
| 58 |
+
local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
|
| 59 |
+
].contiguous()
|
| 60 |
+
up_weight = block.mlp.shared_experts.up_proj.weight[
|
| 61 |
+
local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
|
| 62 |
+
].contiguous()
|
| 63 |
+
shared_up_weight_list.append(
|
| 64 |
+
_to_parameter(torch.cat([gate_weight, up_weight], axis=0))
|
| 65 |
+
)
|
| 66 |
+
if len(shared_up_weight_list) == 1:
|
| 67 |
+
dst_model.model.layers[i].mlp.shared_experts.merge_up_gate_proj.weight = (
|
| 68 |
+
shared_up_weight_list[0]
|
| 69 |
+
)
|
| 70 |
+
else:
|
| 71 |
+
dst_model.model.layers[i].mlp.shared_experts.merge_up_gate_proj.weight = (
|
| 72 |
+
_to_parameter(torch.cat(shared_up_weight_list, axis=0))
|
| 73 |
+
)
|
| 74 |
+
dst_model.model.layers[i].mlp.shared_experts.down_proj.weight.data = (
|
| 75 |
+
block.mlp.shared_experts.down_proj.weight.data[
|
| 76 |
+
:, local_rank * ffn_dim : (local_rank + 1) * ffn_dim
|
| 77 |
+
].contiguous()
|
| 78 |
+
)
|
| 79 |
+
dst_model.model.layers[i].mlp.gate.weight.data = block.mlp.gate.weight.data
|
| 80 |
+
|
| 81 |
+
expert_num = block.mlp.num_routed_experts
|
| 82 |
+
gate_proj_list, down_proj_list, up_proj_list = [], [], []
|
| 83 |
+
for _, src_expert in enumerate(block.mlp.experts):
|
| 84 |
+
ffn_dim = dst_model.model.layers[i].mlp.experts.intermediate_size_per_rank
|
| 85 |
+
gate_proj_list.append(
|
| 86 |
+
src_expert.gate_proj.weight.data[
|
| 87 |
+
local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
|
| 88 |
+
].contiguous()
|
| 89 |
+
)
|
| 90 |
+
up_proj_list.append(
|
| 91 |
+
src_expert.up_proj.weight.data[
|
| 92 |
+
local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
|
| 93 |
+
].contiguous()
|
| 94 |
+
)
|
| 95 |
+
down_proj_list.append(
|
| 96 |
+
src_expert.down_proj.weight.data[
|
| 97 |
+
:, local_rank * ffn_dim : (local_rank + 1) * ffn_dim
|
| 98 |
+
].contiguous()
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
dst_model.model.layers[i].mlp.experts.group_w2.data = (
|
| 102 |
+
torch.cat(down_proj_list, dim=0).view(expert_num, -1, ffn_dim).contiguous()
|
| 103 |
+
)
|
| 104 |
+
group_gate_proj = (
|
| 105 |
+
torch.cat(gate_proj_list, dim=0).view(expert_num, ffn_dim, -1).contiguous()
|
| 106 |
+
)
|
| 107 |
+
group_up_proj = (
|
| 108 |
+
torch.cat(up_proj_list, dim=0).view(expert_num, ffn_dim, -1).contiguous()
|
| 109 |
+
)
|
| 110 |
+
dst_model.model.layers[i].mlp.experts.group_w1_w3.data = torch.cat(
|
| 111 |
+
[group_gate_proj, group_up_proj], dim=1
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def split_w_attn(block, dst_model, i, local_rank):
|
| 116 |
+
q_dim = (
|
| 117 |
+
dst_model.model.layers[0].self_attn.num_heads_per_rank
|
| 118 |
+
* dst_model.model.layers[0].self_attn.q_head_dim
|
| 119 |
+
)
|
| 120 |
+
o_dim = (
|
| 121 |
+
dst_model.model.layers[0].self_attn.num_heads_per_rank
|
| 122 |
+
* dst_model.model.layers[0].self_attn.attention_v_dim
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
if dst_model.model.layers[i].self_attn.attention_q_lora_dim is None:
|
| 126 |
+
dst_model.model.layers[i].self_attn.q_proj.weight.data = (
|
| 127 |
+
block.self_attn.q_proj.weight.data[
|
| 128 |
+
local_rank * q_dim : (local_rank + 1) * q_dim, :
|
| 129 |
+
].contiguous()
|
| 130 |
+
)
|
| 131 |
+
else:
|
| 132 |
+
dst_model.model.layers[i].self_attn.q_a_proj.weight.data = (
|
| 133 |
+
block.self_attn.q_a_proj.weight.data
|
| 134 |
+
)
|
| 135 |
+
dst_model.model.layers[i].self_attn.q_a_layernorm.weight.data = (
|
| 136 |
+
block.self_attn.q_a_layernorm.weight.data
|
| 137 |
+
)
|
| 138 |
+
dst_model.model.layers[i].self_attn.q_b_proj.weight.data = (
|
| 139 |
+
block.self_attn.q_b_proj.weight.data[
|
| 140 |
+
local_rank * q_dim : (local_rank + 1) * q_dim, :
|
| 141 |
+
].contiguous()
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
dst_model.model.layers[i].self_attn.kv_a_proj_with_mqa.weight.data = (
|
| 145 |
+
block.self_attn.kv_a_proj_with_mqa.weight.data
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
dst_model.model.layers[i].self_attn.kv_a_layernorm.weight.data = (
|
| 149 |
+
block.self_attn.kv_a_layernorm.weight.data
|
| 150 |
+
)
|
| 151 |
+
dst_model.model.layers[i].self_attn.o_proj.weight.data = (
|
| 152 |
+
block.self_attn.o_proj.weight.data[
|
| 153 |
+
:, local_rank * o_dim : (local_rank + 1) * o_dim
|
| 154 |
+
].contiguous()
|
| 155 |
+
)
|
| 156 |
+
dst_model.model.layers[i].input_layernorm.weight.data = (
|
| 157 |
+
block.input_layernorm.weight.data
|
| 158 |
+
)
|
| 159 |
+
dst_model.model.layers[i].post_attention_layernorm.weight.data = (
|
| 160 |
+
block.post_attention_layernorm.weight.data
|
| 161 |
+
)
|
| 162 |
+
dst_model.model.layers[i].pre_mlp_layernorm.weight.data = (
|
| 163 |
+
block.pre_mlp_layernorm.weight.data
|
| 164 |
+
)
|
| 165 |
+
dst_model.model.layers[i].post_mlp_layernorm.weight.data = (
|
| 166 |
+
block.post_mlp_layernorm.weight.data
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def kv_low_rank_split(block, dst_model, i, local_rank):
|
| 171 |
+
k_dim = dst_model.model.layers[0].self_attn.num_heads_per_rank * (
|
| 172 |
+
dst_model.model.layers[0].self_attn.attention_qk_dim
|
| 173 |
+
+ dst_model.model.layers[0].self_attn.attention_v_dim
|
| 174 |
+
)
|
| 175 |
+
kv_b_proj_weight_data = block.self_attn.kv_b_proj.weight.data[
|
| 176 |
+
local_rank * k_dim : (local_rank + 1) * k_dim, :
|
| 177 |
+
].contiguous()
|
| 178 |
+
attention_qk_dim = dst_model.model.layers[i].self_attn.attention_qk_dim
|
| 179 |
+
num_heads_per_rank = dst_model.model.layers[i].self_attn.num_heads_per_rank
|
| 180 |
+
attention_kv_lora_dim = dst_model.model.layers[i].self_attn.attention_kv_lora_dim
|
| 181 |
+
attention_v_dim = dst_model.model.layers[i].self_attn.attention_v_dim
|
| 182 |
+
|
| 183 |
+
index_tensor = torch.arange(attention_qk_dim).repeat(
|
| 184 |
+
num_heads_per_rank
|
| 185 |
+
) + torch.arange(num_heads_per_rank).repeat_interleave(attention_qk_dim) * (
|
| 186 |
+
attention_qk_dim + attention_v_dim
|
| 187 |
+
)
|
| 188 |
+
kv_b_proj_w_k = torch.index_select(kv_b_proj_weight_data, dim=0, index=index_tensor)
|
| 189 |
+
dst_model.model.layers[i].self_attn.kv_b_proj_w_k.data = kv_b_proj_w_k.view(
|
| 190 |
+
num_heads_per_rank, attention_qk_dim, attention_kv_lora_dim
|
| 191 |
+
).contiguous()
|
| 192 |
+
index_tensor = torch.arange(
|
| 193 |
+
attention_qk_dim, attention_qk_dim + attention_v_dim
|
| 194 |
+
).repeat(num_heads_per_rank) + torch.arange(num_heads_per_rank).repeat_interleave(
|
| 195 |
+
attention_v_dim
|
| 196 |
+
) * (
|
| 197 |
+
attention_qk_dim + attention_v_dim
|
| 198 |
+
)
|
| 199 |
+
kv_b_proj_w_v = torch.index_select(kv_b_proj_weight_data, dim=0, index=index_tensor)
|
| 200 |
+
dst_model.model.layers[i].self_attn.kv_b_proj_w_v.data = (
|
| 201 |
+
kv_b_proj_w_v.view(num_heads_per_rank, attention_v_dim, attention_kv_lora_dim)
|
| 202 |
+
.transpose(1, 2)
|
| 203 |
+
.contiguous()
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def split_layer(block, dst_model, i, local_rank, attn_tp_size, moe_tp_size):
|
| 208 |
+
# attn weights
|
| 209 |
+
local_rank_tp_attn = local_rank % attn_tp_size
|
| 210 |
+
split_w_attn(block, dst_model, i, local_rank_tp_attn)
|
| 211 |
+
kv_low_rank_split(block, dst_model, i, local_rank_tp_attn)
|
| 212 |
+
|
| 213 |
+
# moe experts weights
|
| 214 |
+
local_rank_tp_moe = local_rank % moe_tp_size
|
| 215 |
+
if i >= dst_model.config.num_dense_layers:
|
| 216 |
+
split_w_moe(block, dst_model, i, local_rank_tp_moe)
|
| 217 |
+
else:
|
| 218 |
+
split_w_dense(block, dst_model, i, local_rank_tp_moe)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def split_w(src_model, dst_model, local_rank, runner_config):
|
| 222 |
+
attn_tp_size = runner_config.get("parallel_config").get("attn_tp_size")
|
| 223 |
+
moe_tp_size = runner_config.get("parallel_config").get("moe_tp_size")
|
| 224 |
+
embed_tp_size = runner_config.get("parallel_config").get("embed_tp_size")
|
| 225 |
+
|
| 226 |
+
vocab_size = src_model.model.vocab_size // embed_tp_size
|
| 227 |
+
embed_tp_rank = local_rank % embed_tp_size
|
| 228 |
+
|
| 229 |
+
dst_model.lm_head.weight.data = src_model.lm_head.weight.data[
|
| 230 |
+
embed_tp_rank * vocab_size : (embed_tp_rank + 1) * vocab_size, :
|
| 231 |
+
]
|
| 232 |
+
dst_model.model.embed_tokens.weight.data = src_model.model.embed_tokens.weight.data[
|
| 233 |
+
embed_tp_rank * vocab_size : (embed_tp_rank + 1) * vocab_size, :
|
| 234 |
+
]
|
| 235 |
+
|
| 236 |
+
dst_model.model.norm.weight.data = src_model.model.norm.weight.data
|
| 237 |
+
|
| 238 |
+
layer_num = len(src_model.model.layers)
|
| 239 |
+
|
| 240 |
+
all_threads = []
|
| 241 |
+
for i in range(0, layer_num):
|
| 242 |
+
block = src_model.model.layers[i]
|
| 243 |
+
thread = Thread(
|
| 244 |
+
target=split_layer,
|
| 245 |
+
args=(block, dst_model, i, local_rank, attn_tp_size, moe_tp_size),
|
| 246 |
+
)
|
| 247 |
+
all_threads.append(thread)
|
| 248 |
+
thread.start()
|
| 249 |
+
for thread in all_threads:
|
| 250 |
+
thread.join()
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def copy_files_with_prefix(src_dir, dst_dir, prefix):
|
| 254 |
+
for file in os.listdir(src_dir):
|
| 255 |
+
if file.startswith(prefix):
|
| 256 |
+
src_file = os.path.join(src_dir, file)
|
| 257 |
+
dst_file = os.path.join(dst_dir, file)
|
| 258 |
+
shutil.copy2(src_file, dst_file)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def parse_args():
|
| 262 |
+
parser = argparse.ArgumentParser(
|
| 263 |
+
description="Split weight parameters with tensor parallel"
|
| 264 |
+
)
|
| 265 |
+
parser.add_argument("--model_path", type=str, help="Path of model weights")
|
| 266 |
+
parser.add_argument(
|
| 267 |
+
"--output_path",
|
| 268 |
+
type=str,
|
| 269 |
+
help="The output directory where the results are saved",
|
| 270 |
+
)
|
| 271 |
+
parser.add_argument(
|
| 272 |
+
"--origin_yaml_file_path", type=str, help="inference configurations"
|
| 273 |
+
)
|
| 274 |
+
parser.add_argument(
|
| 275 |
+
"--new_yaml_file_path", type=str, help="inference configurations"
|
| 276 |
+
)
|
| 277 |
+
parser.add_argument(
|
| 278 |
+
"--world_size", type=int, default=8, help="The parallel rank size of model"
|
| 279 |
+
)
|
| 280 |
+
parser.add_argument("--node_num", type=int, default=1, help="The parallel node num")
|
| 281 |
+
parser.add_argument(
|
| 282 |
+
"--node_rank", type=int, default=0, help="The parallel node rank"
|
| 283 |
+
)
|
| 284 |
+
parser_args = parser.parse_args()
|
| 285 |
+
return parser_args
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def show_model_states(origin_model, model_name="src_model"):
|
| 289 |
+
src_param_size = 0
|
| 290 |
+
for name, params in origin_model.named_parameters():
|
| 291 |
+
size_per_param = np.prod(params.size())
|
| 292 |
+
src_param_size += size_per_param
|
| 293 |
+
logging.info(
|
| 294 |
+
"Param of %s tensor parallel: %s, %s, %s",
|
| 295 |
+
model_name,
|
| 296 |
+
name,
|
| 297 |
+
params.size(),
|
| 298 |
+
params.dtype,
|
| 299 |
+
)
|
| 300 |
+
logging.info(
|
| 301 |
+
"Total param size of %s tensor parallel: %s", model_name, src_param_size
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def read_yaml(yaml_file_path):
|
| 306 |
+
try:
|
| 307 |
+
with open(yaml_file_path, "r", encoding="utf-8") as file:
|
| 308 |
+
data = yaml.safe_load(file)
|
| 309 |
+
except FileNotFoundError:
|
| 310 |
+
logging.error(f"No such yaml file: {yaml_file_path}")
|
| 311 |
+
except yaml.YAMLERROR as e:
|
| 312 |
+
logging.error(f"Load yaml file failed: {e}")
|
| 313 |
+
return data
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def check_vars(world_size, runner_config):
|
| 317 |
+
attn_tp_size = runner_config.get("parallel_config").get("attn_tp_size")
|
| 318 |
+
moe_tp_size = runner_config.get("parallel_config").get("moe_tp_size")
|
| 319 |
+
embed_tp_size = runner_config.get("parallel_config").get("embed_tp_size")
|
| 320 |
+
if world_size % attn_tp_size != 0:
|
| 321 |
+
logging.error(
|
| 322 |
+
"world_size %s mod attn_tp_size %s must be 0", world_size, attn_tp_size
|
| 323 |
+
)
|
| 324 |
+
exit(1)
|
| 325 |
+
if world_size % moe_tp_size != 0:
|
| 326 |
+
logging.error(
|
| 327 |
+
"world_size %s mod moe_tp_size %s must be 0", world_size, moe_tp_size
|
| 328 |
+
)
|
| 329 |
+
exit(1)
|
| 330 |
+
if world_size % embed_tp_size != 0:
|
| 331 |
+
logging.error(
|
| 332 |
+
"world_size %s mod embed_tp_size %s must be 0", world_size, embed_tp_size
|
| 333 |
+
)
|
| 334 |
+
exit(1)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
if __name__ == "__main__":
|
| 338 |
+
logging.info("Start to split weight...")
|
| 339 |
+
args = parse_args()
|
| 340 |
+
output_path = args.output_path
|
| 341 |
+
|
| 342 |
+
old_runner_config = read_yaml(args.origin_yaml_file_path)
|
| 343 |
+
new_runner_config = read_yaml(args.new_yaml_file_path)
|
| 344 |
+
world_size = args.world_size
|
| 345 |
+
|
| 346 |
+
if not os.path.exists(output_path):
|
| 347 |
+
os.makedirs(output_path)
|
| 348 |
+
origin_model = AutoModelForCausalLM.from_pretrained(
|
| 349 |
+
args.model_path,
|
| 350 |
+
trust_remote_code=True,
|
| 351 |
+
local_files_only=True,
|
| 352 |
+
ignore_mismatched_sizes=True,
|
| 353 |
+
low_cpu_mem_usage=True,
|
| 354 |
+
torch_dtype=torch.bfloat16,
|
| 355 |
+
attn_implementation="eager",
|
| 356 |
+
)
|
| 357 |
+
show_model_states(origin_model, "origin_model")
|
| 358 |
+
|
| 359 |
+
node_rank_id = args.node_rank
|
| 360 |
+
rank_num_per_node = world_size // args.node_num
|
| 361 |
+
start_rank = rank_num_per_node * node_rank_id
|
| 362 |
+
end_rank = rank_num_per_node * (node_rank_id + 1)
|
| 363 |
+
|
| 364 |
+
for rank_id in range(start_rank, end_rank):
|
| 365 |
+
logging.info("rank_id={} / rank_size={}".format(rank_id, world_size))
|
| 366 |
+
os.environ["LOCAL_RANK"] = str(rank_id)
|
| 367 |
+
|
| 368 |
+
save_path = os.path.join(output_path, f"rank_{rank_id}")
|
| 369 |
+
logging.info(
|
| 370 |
+
"Split weight for rank %s start, save path is: %s", rank_id, save_path
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
config = origin_model.config
|
| 374 |
+
part_model = PanguUltraMoEForCausalLM(config, new_runner_config)
|
| 375 |
+
|
| 376 |
+
split_w(origin_model, part_model, rank_id, new_runner_config)
|
| 377 |
+
|
| 378 |
+
show_model_states(part_model, "dst_model")
|
| 379 |
+
|
| 380 |
+
part_model.save_pretrained(save_path)
|
| 381 |
+
copy_files_with_prefix(args.model_path, save_path, "tokenizer")
|
| 382 |
+
copy_files_with_prefix(args.model_path, save_path, "tokenization")
|
| 383 |
+
logging.info(
|
| 384 |
+
"Split weight for rank %s finished, save path is: %s", rank_id, save_path
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
del part_model
|
inference/split_weight.sh
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
|
| 4 |
+
rm -rf ./model
|
| 5 |
+
mkdir ./model
|
| 6 |
+
python split_weight.py \
|
| 7 |
+
--model_path=../ \
|
| 8 |
+
--output_path=./model \
|
| 9 |
+
--origin_yaml_file_path=./runner_config/tp1.yaml \
|
| 10 |
+
--new_yaml_file_path=./runner_config/tp32.yaml \
|
| 11 |
+
--world_size=32 \
|
| 12 |
+
--node_num=1 \
|
| 13 |
+
--node_rank=0
|
inference/vllm_ascend/_build_info.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Auto-generated file
|
| 2 |
+
__soc_version__ = 'ASCEND910B1'
|
| 3 |
+
__sleep_mode_enabled__ = True
|
inference/vllm_ascend/attention/attention.py
ADDED
|
@@ -0,0 +1,1220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# This file is a part of the vllm-ascend project.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch_npu
|
| 24 |
+
import torchair._contrib.custom_torch_ops # type: ignore # noqa: F401
|
| 25 |
+
from torch.nn.functional import scaled_dot_product_attention
|
| 26 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
| 27 |
+
AttentionLayer,
|
| 28 |
+
AttentionMetadata, AttentionType,
|
| 29 |
+
MLAAttentionImpl)
|
| 30 |
+
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
| 31 |
+
CommonMetadataBuilder,
|
| 32 |
+
compute_slot_mapping,
|
| 33 |
+
compute_slot_mapping_start_idx,
|
| 34 |
+
is_block_tables_empty)
|
| 35 |
+
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
| 36 |
+
|
| 37 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 38 |
+
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
| 39 |
+
from vllm_ascend.ops.cache import concat_and_cache_mla
|
| 40 |
+
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
|
| 41 |
+
enable_custom_op, is_310p, nd_to_nz_2d)
|
| 42 |
+
from vllm_ascend.worker.model_runner import (
|
| 43 |
+
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
| 44 |
+
|
| 45 |
+
_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class AscendAttentionBackend(AttentionBackend):
|
| 49 |
+
|
| 50 |
+
@staticmethod
|
| 51 |
+
def get_name() -> str:
|
| 52 |
+
return "ASCEND"
|
| 53 |
+
|
| 54 |
+
@staticmethod
|
| 55 |
+
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
|
| 56 |
+
return AscendAttentionBackendImpl
|
| 57 |
+
|
| 58 |
+
@staticmethod
|
| 59 |
+
def get_metadata_cls() -> Type["AscendMetadata"]:
|
| 60 |
+
return AscendMetadata
|
| 61 |
+
|
| 62 |
+
@staticmethod
|
| 63 |
+
def get_state_cls() -> Type["CommonAttentionState"]:
|
| 64 |
+
return CommonAttentionState
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def get_kv_cache_shape(
|
| 68 |
+
num_blocks: int,
|
| 69 |
+
block_size: int,
|
| 70 |
+
num_kv_heads: int,
|
| 71 |
+
head_size: int,
|
| 72 |
+
) -> Tuple[int, ...]:
|
| 73 |
+
if is_310p():
|
| 74 |
+
return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
|
| 75 |
+
16)
|
| 76 |
+
else:
|
| 77 |
+
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
| 78 |
+
|
| 79 |
+
@staticmethod
|
| 80 |
+
def swap_blocks(
|
| 81 |
+
src_kv_cache: List[torch.Tensor],
|
| 82 |
+
dst_kv_cache: List[torch.Tensor],
|
| 83 |
+
src_to_dst: torch.Tensor,
|
| 84 |
+
) -> None:
|
| 85 |
+
src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
|
| 86 |
+
dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
|
| 87 |
+
src_indices = src_to_dst[:, 0]
|
| 88 |
+
dst_indices = src_to_dst[:, 1]
|
| 89 |
+
|
| 90 |
+
dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
|
| 91 |
+
dst_key_cache.device)
|
| 92 |
+
dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
|
| 93 |
+
dst_key_cache.device)
|
| 94 |
+
|
| 95 |
+
@staticmethod
|
| 96 |
+
def copy_blocks(
|
| 97 |
+
kv_caches: List[torch.Tensor],
|
| 98 |
+
src_to_dists: torch.Tensor,
|
| 99 |
+
) -> None:
|
| 100 |
+
src_indices = src_to_dists[:, 0]
|
| 101 |
+
dst_indices = src_to_dists[:, 1]
|
| 102 |
+
|
| 103 |
+
for kv_cache in kv_caches:
|
| 104 |
+
key_caches = kv_cache[0]
|
| 105 |
+
value_caches = kv_cache[1]
|
| 106 |
+
key_caches[dst_indices] = key_caches[src_indices]
|
| 107 |
+
value_caches[dst_indices] = value_caches[src_indices]
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def get_builder_cls() -> Type["AscendMetadataBuilder"]:
|
| 111 |
+
return AscendMetadataBuilder
|
| 112 |
+
|
| 113 |
+
@classmethod
|
| 114 |
+
def make_metadata_builder(cls, *args, **kwargs) -> "AscendMetadataBuilder":
|
| 115 |
+
return cls.get_builder_cls()(*args, **kwargs)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class AscendMLAAttentionBackend(AscendAttentionBackend):
|
| 119 |
+
|
| 120 |
+
@staticmethod
|
| 121 |
+
def get_impl_cls() -> Type["AscendMLAAttentionBackendImpl"]:
|
| 122 |
+
return AscendMLAAttentionBackendImpl
|
| 123 |
+
|
| 124 |
+
@staticmethod
|
| 125 |
+
def get_kv_cache_shape(
|
| 126 |
+
num_blocks: int,
|
| 127 |
+
block_size: int,
|
| 128 |
+
num_kv_heads: int,
|
| 129 |
+
head_size: int,
|
| 130 |
+
) -> Tuple[int, ...]:
|
| 131 |
+
return (num_blocks, block_size, num_kv_heads, head_size)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@dataclass
|
| 135 |
+
class AscendMetadata(AttentionMetadata):
|
| 136 |
+
"""Metadata for Ascendbackend.
|
| 137 |
+
* modified from XFormersbackend
|
| 138 |
+
NOTE: Any python object stored here is not updated when it is
|
| 139 |
+
cuda-graph replayed. If you have values that need to be changed
|
| 140 |
+
dynamically, it should be stored in tensor. The tensor has to be
|
| 141 |
+
updated from `CUDAGraphRunner.forward` API.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
# |---------- N-1 iteration --------|
|
| 145 |
+
# |---------------- N iteration ---------------------|
|
| 146 |
+
# |- tokenA -|......................|-- newTokens ---|
|
| 147 |
+
# |---------- context_len ----------|
|
| 148 |
+
# |-------------------- seq_len ----------------------|
|
| 149 |
+
# |-- query_len ---|
|
| 150 |
+
|
| 151 |
+
# FIXME: It is for flash attn.
|
| 152 |
+
# Maximum sequence length among prefill batch. 0 if there are decoding
|
| 153 |
+
# Avoid mypy error
|
| 154 |
+
# Total number of prefill requests.
|
| 155 |
+
num_prefills: int
|
| 156 |
+
# Number of prefill tokens.
|
| 157 |
+
num_prefill_tokens: int
|
| 158 |
+
# (num_tokens,). The indices of the token slots that input tokens will be
|
| 159 |
+
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
| 160 |
+
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
| 161 |
+
# in block 0, and 1st slot in block 1, respectively.
|
| 162 |
+
slot_mapping: torch.Tensor
|
| 163 |
+
|
| 164 |
+
# requests only.
|
| 165 |
+
max_prefill_seq_len: int
|
| 166 |
+
# Maximum sequence length among decode batch. 0 if there are prefill
|
| 167 |
+
# requests only.
|
| 168 |
+
max_decode_seq_len: int
|
| 169 |
+
|
| 170 |
+
chunked_prefill_enabled: bool
|
| 171 |
+
|
| 172 |
+
# (batch_size, max_blocks_per_seq).
|
| 173 |
+
# Block addresses per sequence. (Seq id -> list of physical block)
|
| 174 |
+
block_tables: Optional[torch.Tensor]
|
| 175 |
+
|
| 176 |
+
# seq_lens stored as a tensor.
|
| 177 |
+
seq_lens_tensor: Optional[torch.Tensor]
|
| 178 |
+
|
| 179 |
+
# (batch_size,). The sequence length per sequence. Sequence length means
|
| 180 |
+
# the computed tokens + new tokens None if it is a decoding.
|
| 181 |
+
seq_lens: Optional[List[int]] = None
|
| 182 |
+
|
| 183 |
+
# The query lengths of the input sequences
|
| 184 |
+
query_lens: Optional[List[int]] = None
|
| 185 |
+
|
| 186 |
+
# Maximum query length in the batch. None for decoding.
|
| 187 |
+
max_query_len: Optional[int] = None
|
| 188 |
+
|
| 189 |
+
# Self-attention prefill/decode metadata cache
|
| 190 |
+
_cached_prefill_metadata: Optional["AscendMetadata"] = None
|
| 191 |
+
_cached_decode_metadata: Optional["AscendMetadata"] = None
|
| 192 |
+
|
| 193 |
+
# Begin encoder attn & enc/dec cross-attn fields...
|
| 194 |
+
|
| 195 |
+
# Encoder sequence lengths representation
|
| 196 |
+
encoder_seq_lens: Optional[List[int]] = None
|
| 197 |
+
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
| 198 |
+
|
| 199 |
+
# Maximum sequence length among encoder sequences
|
| 200 |
+
max_encoder_seq_len: Optional[int] = None
|
| 201 |
+
|
| 202 |
+
# Number of tokens input to encoder
|
| 203 |
+
num_encoder_tokens: Optional[int] = None
|
| 204 |
+
|
| 205 |
+
# Mask for normal situation
|
| 206 |
+
attn_mask: Optional[torch.Tensor] = None
|
| 207 |
+
|
| 208 |
+
# Mask for prefix caching
|
| 209 |
+
compress_mask: Optional[torch.Tensor] = None
|
| 210 |
+
|
| 211 |
+
# Mask for chunked prefill
|
| 212 |
+
chunk_mask: Optional[torch.Tensor] = None
|
| 213 |
+
|
| 214 |
+
# Cross-attention memory-mapping data structures: slot mapping
|
| 215 |
+
# and block tables
|
| 216 |
+
cross_slot_mapping: Optional[torch.Tensor] = None
|
| 217 |
+
cross_block_tables: Optional[torch.Tensor] = None
|
| 218 |
+
|
| 219 |
+
@property
|
| 220 |
+
def prefill_metadata(self) -> Optional["AscendMetadata"]:
|
| 221 |
+
if self.num_prefills == 0:
|
| 222 |
+
return None
|
| 223 |
+
|
| 224 |
+
if self._cached_prefill_metadata is not None:
|
| 225 |
+
# Recover cached prefill-phase attention
|
| 226 |
+
# metadata structure.
|
| 227 |
+
return self._cached_prefill_metadata
|
| 228 |
+
|
| 229 |
+
assert ((self.seq_lens is not None)
|
| 230 |
+
or (self.encoder_seq_lens is not None))
|
| 231 |
+
|
| 232 |
+
# Compute some attn_metadata fields which default to None.
|
| 233 |
+
slot_mapping = (None if self.slot_mapping is None else
|
| 234 |
+
self.slot_mapping[:self.num_prefill_tokens])
|
| 235 |
+
seq_lens = (None if self.seq_lens is None else
|
| 236 |
+
self.seq_lens[:self.num_prefills])
|
| 237 |
+
query_lens = (None if self.query_lens is None else
|
| 238 |
+
self.query_lens[:self.num_prefills])
|
| 239 |
+
block_tables = (None if self.block_tables is None else
|
| 240 |
+
self.block_tables[:self.num_prefills])
|
| 241 |
+
|
| 242 |
+
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
| 243 |
+
self.seq_lens_tensor[:self.num_prefills])
|
| 244 |
+
|
| 245 |
+
# Construct & cache prefill-phase attention metadata structure.
|
| 246 |
+
self._cached_prefill_metadata = AscendMetadata(
|
| 247 |
+
num_prefills=self.num_prefills,
|
| 248 |
+
num_prefill_tokens=self.num_prefill_tokens,
|
| 249 |
+
num_decode_tokens=0,
|
| 250 |
+
slot_mapping=slot_mapping,
|
| 251 |
+
seq_lens=seq_lens,
|
| 252 |
+
seq_lens_tensor=seq_lens_tensor,
|
| 253 |
+
query_lens=query_lens,
|
| 254 |
+
max_query_len=self.max_query_len,
|
| 255 |
+
max_prefill_seq_len=self.max_prefill_seq_len,
|
| 256 |
+
max_decode_seq_len=0,
|
| 257 |
+
chunked_prefill_enabled=self.chunked_prefill_enabled,
|
| 258 |
+
block_tables=block_tables,
|
| 259 |
+
# Begin encoder & cross attn fields below...
|
| 260 |
+
encoder_seq_lens=self.encoder_seq_lens,
|
| 261 |
+
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
| 262 |
+
max_encoder_seq_len=self.max_encoder_seq_len,
|
| 263 |
+
multi_modal_placeholder_index_maps=self.
|
| 264 |
+
multi_modal_placeholder_index_maps,
|
| 265 |
+
cross_slot_mapping=self.cross_slot_mapping,
|
| 266 |
+
cross_block_tables=self.cross_block_tables,
|
| 267 |
+
enable_kv_scales_calculation=False)
|
| 268 |
+
return self._cached_prefill_metadata
|
| 269 |
+
|
| 270 |
+
@property
|
| 271 |
+
def decode_metadata(self) -> Optional["AscendMetadata"]:
|
| 272 |
+
if self.num_decode_tokens == 0:
|
| 273 |
+
return None
|
| 274 |
+
|
| 275 |
+
if self._cached_decode_metadata is not None:
|
| 276 |
+
# Recover cached decode-phase attention
|
| 277 |
+
# metadata structure.
|
| 278 |
+
return self._cached_decode_metadata
|
| 279 |
+
|
| 280 |
+
# Compute some attn_metadata fields which default to None.
|
| 281 |
+
slot_mapping = (None if self.slot_mapping is None else
|
| 282 |
+
self.slot_mapping[self.num_prefill_tokens:])
|
| 283 |
+
seq_lens = (None if self.seq_lens is None else
|
| 284 |
+
self.seq_lens[self.num_prefills:])
|
| 285 |
+
query_lens = (None if self.query_lens is None else
|
| 286 |
+
self.query_lens[self.num_prefills:])
|
| 287 |
+
block_tables = (None if self.block_tables is None else
|
| 288 |
+
self.block_tables[self.num_prefills:])
|
| 289 |
+
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
| 290 |
+
self.seq_lens_tensor[self.num_prefills:])
|
| 291 |
+
# Construct & cache decode-phase attention metadata structure.
|
| 292 |
+
self._cached_decode_metadata = AscendMetadata(
|
| 293 |
+
num_prefills=0,
|
| 294 |
+
num_prefill_tokens=0,
|
| 295 |
+
num_decode_tokens=self.num_decode_tokens,
|
| 296 |
+
slot_mapping=slot_mapping,
|
| 297 |
+
seq_lens=seq_lens,
|
| 298 |
+
seq_lens_tensor=seq_lens_tensor,
|
| 299 |
+
query_lens=query_lens,
|
| 300 |
+
max_query_len=self.max_query_len,
|
| 301 |
+
max_prefill_seq_len=0,
|
| 302 |
+
max_decode_seq_len=self.max_decode_seq_len,
|
| 303 |
+
chunked_prefill_enabled=self.chunked_prefill_enabled,
|
| 304 |
+
block_tables=block_tables,
|
| 305 |
+
# Begin encoder & cross attn fields below...
|
| 306 |
+
encoder_seq_lens=self.encoder_seq_lens,
|
| 307 |
+
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
| 308 |
+
max_encoder_seq_len=self.max_encoder_seq_len,
|
| 309 |
+
multi_modal_placeholder_index_maps=self.
|
| 310 |
+
multi_modal_placeholder_index_maps,
|
| 311 |
+
cross_slot_mapping=self.cross_slot_mapping,
|
| 312 |
+
cross_block_tables=self.cross_block_tables,
|
| 313 |
+
enable_kv_scales_calculation=False)
|
| 314 |
+
return self._cached_decode_metadata
|
| 315 |
+
|
| 316 |
+
def advance_step(self,
|
| 317 |
+
model_input: "ModelInputForNPUWithSamplingMetadata",
|
| 318 |
+
sampled_token_ids: Optional[torch.Tensor],
|
| 319 |
+
block_size: int,
|
| 320 |
+
num_seqs: int,
|
| 321 |
+
num_queries: int,
|
| 322 |
+
turn_prefills_into_decodes: bool = False):
|
| 323 |
+
"""
|
| 324 |
+
Update metadata in-place to advance one decode step.
|
| 325 |
+
"""
|
| 326 |
+
# When using cudagraph, the num_seqs is padded to the next captured
|
| 327 |
+
# batch sized, but num_queries tracks the actual number of requests in
|
| 328 |
+
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
| 329 |
+
if num_seqs != num_queries:
|
| 330 |
+
assert num_seqs > num_queries
|
| 331 |
+
|
| 332 |
+
if turn_prefills_into_decodes:
|
| 333 |
+
# When Mutli-Step is enabled with Chunked-Prefill, prefills and
|
| 334 |
+
# decodes are scheduled together. In the first step, all the
|
| 335 |
+
# prefills turn into decodes. This update reflects that
|
| 336 |
+
# conversion.
|
| 337 |
+
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
| 338 |
+
self.num_decode_tokens += self.num_prefills
|
| 339 |
+
self.num_prefills = 0
|
| 340 |
+
self.num_prefill_tokens = 0
|
| 341 |
+
self.max_prefill_seq_len = 0
|
| 342 |
+
self.max_query_len = 1
|
| 343 |
+
|
| 344 |
+
self.slot_mapping = self.slot_mapping[:num_seqs]
|
| 345 |
+
else:
|
| 346 |
+
assert self.seq_lens is not None
|
| 347 |
+
assert self.max_decode_seq_len == max(self.seq_lens)
|
| 348 |
+
|
| 349 |
+
assert self.num_prefills == 0
|
| 350 |
+
assert self.num_prefill_tokens == 0
|
| 351 |
+
assert self.num_decode_tokens == num_seqs
|
| 352 |
+
assert self.slot_mapping.shape == (num_seqs, )
|
| 353 |
+
|
| 354 |
+
assert self.seq_lens is not None
|
| 355 |
+
assert len(self.seq_lens) == num_seqs
|
| 356 |
+
assert self.seq_lens_tensor is not None
|
| 357 |
+
assert self.seq_lens_tensor.shape == (num_seqs, )
|
| 358 |
+
assert self.max_query_len == 1
|
| 359 |
+
assert self.max_prefill_seq_len == 0
|
| 360 |
+
|
| 361 |
+
assert self.block_tables is not None
|
| 362 |
+
assert self.block_tables.shape[0] == num_seqs
|
| 363 |
+
|
| 364 |
+
# Update query lengths. Note that we update only queries and not seqs,
|
| 365 |
+
# since tensors may be padded due to captured cuda graph batch size
|
| 366 |
+
for i in range(num_queries):
|
| 367 |
+
self.seq_lens[i] += 1
|
| 368 |
+
self.max_decode_seq_len = max(self.seq_lens)
|
| 369 |
+
if enable_custom_op():
|
| 370 |
+
#advance a step on NPU for existing inputs for a multi-step runner if custom ops is enabled
|
| 371 |
+
torch.ops._C.advance_step_flashattn_ascendc(
|
| 372 |
+
num_seqs=num_seqs,
|
| 373 |
+
num_queries=num_queries,
|
| 374 |
+
block_size=block_size,
|
| 375 |
+
input_tokens=model_input.input_tokens,
|
| 376 |
+
sampled_token_ids=sampled_token_ids,
|
| 377 |
+
input_positions=model_input.input_positions,
|
| 378 |
+
seq_lens=self.seq_lens_tensor,
|
| 379 |
+
slot_mapping=self.slot_mapping,
|
| 380 |
+
block_tables=self.block_tables)
|
| 381 |
+
else:
|
| 382 |
+
# use traditional Pytorch method for updating these tensors.
|
| 383 |
+
# update input_tokens
|
| 384 |
+
sampled_token_ids_list = sampled_token_ids[:
|
| 385 |
+
num_queries].squeeze( # type: ignore
|
| 386 |
+
-1)
|
| 387 |
+
model_input.input_tokens[:
|
| 388 |
+
num_queries] = sampled_token_ids_list # type: ignore
|
| 389 |
+
|
| 390 |
+
# get seq_lens and input_positions
|
| 391 |
+
seq_lens = self.seq_lens_tensor[:num_queries]
|
| 392 |
+
next_seq_lens = seq_lens + 1
|
| 393 |
+
next_input_pos = next_seq_lens - 1
|
| 394 |
+
|
| 395 |
+
# update seq_lens and input_positions
|
| 396 |
+
self.seq_lens_tensor[:num_queries] = next_seq_lens
|
| 397 |
+
model_input.input_positions[:
|
| 398 |
+
num_queries] = next_input_pos # type: ignore
|
| 399 |
+
|
| 400 |
+
# 计算 block index 和 offset
|
| 401 |
+
block_idx = next_input_pos // block_size
|
| 402 |
+
block_offset = next_input_pos % block_size
|
| 403 |
+
|
| 404 |
+
current_block_table = self.block_tables.gather(
|
| 405 |
+
1, block_idx.unsqueeze(-1)).squeeze(-1)
|
| 406 |
+
slot_num = current_block_table * block_size + block_offset
|
| 407 |
+
|
| 408 |
+
# update slot_mapping
|
| 409 |
+
self.slot_mapping[:num_queries] = slot_num
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
| 413 |
+
|
| 414 |
+
_attn_mask_builder = None # noqa
|
| 415 |
+
|
| 416 |
+
def __init__(self, input_builder: "ModelInputForNPUBuilder"):
|
| 417 |
+
self.input_builder = input_builder
|
| 418 |
+
self.runner = input_builder.runner
|
| 419 |
+
self.sliding_window = input_builder.sliding_window
|
| 420 |
+
self.block_size = input_builder.block_size
|
| 421 |
+
|
| 422 |
+
self.attn_mask = None
|
| 423 |
+
self.compress_mask = None
|
| 424 |
+
self.chunk_mask = None
|
| 425 |
+
if AscendMetadataBuilder._attn_mask_builder is None:
|
| 426 |
+
AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder(
|
| 427 |
+
128, self.input_builder.runner.model_config.dtype)
|
| 428 |
+
|
| 429 |
+
def _add_seq_group(
|
| 430 |
+
self, inter_data: ModelInputForNPUBuilder.InterDataForSeqGroup,
|
| 431 |
+
chunked_prefill_enabled: bool):
|
| 432 |
+
"""Add a sequence group to the metadata. Specifically update/append
|
| 433 |
+
1. context length.
|
| 434 |
+
2. block table.
|
| 435 |
+
3. slot mapping.
|
| 436 |
+
"""
|
| 437 |
+
is_prompt = inter_data.is_prompt
|
| 438 |
+
block_tables = inter_data.block_tables
|
| 439 |
+
|
| 440 |
+
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
| 441 |
+
curr_sliding_window_block) in zip(
|
| 442 |
+
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
| 443 |
+
inter_data.orig_seq_lens, inter_data.seq_lens,
|
| 444 |
+
inter_data.query_lens, inter_data.context_lens,
|
| 445 |
+
inter_data.curr_sliding_window_blocks):
|
| 446 |
+
self.context_lens.append(context_len)
|
| 447 |
+
if is_prompt:
|
| 448 |
+
self.num_prefills += 1
|
| 449 |
+
self.num_prefill_tokens += token_len
|
| 450 |
+
self.prefill_seq_lens.append(seq_len)
|
| 451 |
+
else:
|
| 452 |
+
self.num_decode_tokens += query_len
|
| 453 |
+
self.curr_seq_lens.append(curr_seq_len)
|
| 454 |
+
|
| 455 |
+
# Compute block table.
|
| 456 |
+
# TODO(sang): Combine chunked prefill and prefix caching by
|
| 457 |
+
# only allowing multiple of block_size chunk size.
|
| 458 |
+
# NOTE: This only works for oooooooxxx style attention.
|
| 459 |
+
block_table: List[int] = []
|
| 460 |
+
prefix_cache_hit = any([
|
| 461 |
+
inter_data.prefix_cache_hit
|
| 462 |
+
for inter_data in self.input_builder.inter_data_list
|
| 463 |
+
])
|
| 464 |
+
if prefix_cache_hit:
|
| 465 |
+
# NOTE(woosuk): For flash-attn, the block table should
|
| 466 |
+
# include the entries for the incoming prefill tokens.
|
| 467 |
+
if block_tables is not None:
|
| 468 |
+
block_table = block_tables[seq_id]
|
| 469 |
+
elif ((chunked_prefill_enabled or not is_prompt)
|
| 470 |
+
and block_tables is not None):
|
| 471 |
+
if curr_sliding_window_block == 0:
|
| 472 |
+
block_table = block_tables[seq_id]
|
| 473 |
+
else:
|
| 474 |
+
block_table = block_tables[seq_id][
|
| 475 |
+
-curr_sliding_window_block:]
|
| 476 |
+
self.block_tables.append(block_table)
|
| 477 |
+
|
| 478 |
+
# Compute slot mapping.
|
| 479 |
+
is_profile_run = is_block_tables_empty(block_tables)
|
| 480 |
+
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
| 481 |
+
context_len,
|
| 482 |
+
self.sliding_window)
|
| 483 |
+
compute_slot_mapping(
|
| 484 |
+
is_profile_run,
|
| 485 |
+
self.slot_mapping,
|
| 486 |
+
seq_id,
|
| 487 |
+
seq_len,
|
| 488 |
+
context_len,
|
| 489 |
+
start_idx,
|
| 490 |
+
self.block_size,
|
| 491 |
+
inter_data.block_tables,
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
def _get_graph_runner_block_tables(
|
| 495 |
+
self, num_seqs: int,
|
| 496 |
+
block_tables: List[List[int]]) -> torch.Tensor:
|
| 497 |
+
# The shape of graph_block_tables is
|
| 498 |
+
# [max batch size, max context len // block size].
|
| 499 |
+
|
| 500 |
+
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
| 501 |
+
assert max_batch_size >= num_seqs
|
| 502 |
+
|
| 503 |
+
graph_block_tables = self.runner.graph_block_tables # [:num_seqs]
|
| 504 |
+
for i, block_table in enumerate(block_tables):
|
| 505 |
+
if block_table:
|
| 506 |
+
num_blocks = len(block_table)
|
| 507 |
+
if num_blocks <= max_blocks:
|
| 508 |
+
graph_block_tables[i, :num_blocks] = block_table
|
| 509 |
+
else:
|
| 510 |
+
graph_block_tables[
|
| 511 |
+
i, :max_blocks] = block_table[:max_blocks]
|
| 512 |
+
|
| 513 |
+
return torch.from_numpy(graph_block_tables).to(
|
| 514 |
+
device=self.runner.device, non_blocking=True)
|
| 515 |
+
|
| 516 |
+
def build(
|
| 517 |
+
self,
|
| 518 |
+
seq_lens: List[int],
|
| 519 |
+
query_lens: List[int],
|
| 520 |
+
graph_pad_size: int,
|
| 521 |
+
):
|
| 522 |
+
"""Build attention metadata with on-device tensors.
|
| 523 |
+
|
| 524 |
+
Args:
|
| 525 |
+
seq_lens: The maybe padded sequence lengths of the input sequences.
|
| 526 |
+
query_lens: The query lengths of the input sequences.
|
| 527 |
+
"""
|
| 528 |
+
for inter_data in self.input_builder.inter_data_list:
|
| 529 |
+
self._add_seq_group(inter_data,
|
| 530 |
+
self.input_builder.chunked_prefill_enabled)
|
| 531 |
+
|
| 532 |
+
device = self.runner.device
|
| 533 |
+
dtype = self.runner.model_config.dtype
|
| 534 |
+
use_npu_graph = graph_pad_size != -1
|
| 535 |
+
|
| 536 |
+
max_query_len = max(query_lens)
|
| 537 |
+
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
| 538 |
+
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
| 539 |
+
max_seq_len = max(max_prefill_seq_len, max_decode_seq_len)
|
| 540 |
+
num_decode_tokens = self.num_decode_tokens
|
| 541 |
+
|
| 542 |
+
if self.num_prefills == 0 and use_npu_graph:
|
| 543 |
+
num_seqs = len(seq_lens)
|
| 544 |
+
self.slot_mapping.extend([PAD_SLOT_ID] * graph_pad_size)
|
| 545 |
+
self.block_tables.extend([[]] * graph_pad_size)
|
| 546 |
+
block_tables = self._get_graph_runner_block_tables(
|
| 547 |
+
num_seqs, self.block_tables)
|
| 548 |
+
else:
|
| 549 |
+
block_tables = make_tensor_with_pad(
|
| 550 |
+
self.block_tables,
|
| 551 |
+
pad=0,
|
| 552 |
+
dtype=torch.int32,
|
| 553 |
+
device=device,
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
if self.num_prefills > 0:
|
| 557 |
+
if block_tables is None or block_tables.numel() == 0:
|
| 558 |
+
# normal mask
|
| 559 |
+
self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
| 560 |
+
max_prefill_seq_len, dtype, device)
|
| 561 |
+
if is_310p():
|
| 562 |
+
mask_nz = nd_to_nz_2d(self.attn_mask)
|
| 563 |
+
mask_nz = torch_npu.npu_format_cast(
|
| 564 |
+
mask_nz.contiguous(), ACL_FORMAT_FRACTAL_NZ)
|
| 565 |
+
self.attn_mask = mask_nz
|
| 566 |
+
elif self.num_decode_tokens == 0 and not self.input_builder.chunked_prefill_enabled:
|
| 567 |
+
# compress mask for prefix cache
|
| 568 |
+
self.compress_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
| 569 |
+
128, dtype, device)
|
| 570 |
+
else:
|
| 571 |
+
# chunk_mask for chunk prefill
|
| 572 |
+
attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
|
| 573 |
+
max_seq_len, dtype, device)
|
| 574 |
+
if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
|
| 575 |
+
# Do not use in-place multiplication to avoid modifying `attn_mask_cache`!
|
| 576 |
+
attn_mask = attn_mask * -10000
|
| 577 |
+
chunk_mask_list = []
|
| 578 |
+
for i, seq_len in enumerate(seq_lens):
|
| 579 |
+
context_len = self.context_lens[i]
|
| 580 |
+
chunk_mask_list.append(attn_mask[context_len:seq_len])
|
| 581 |
+
self.chunk_mask = torch.cat(chunk_mask_list, 0)
|
| 582 |
+
else:
|
| 583 |
+
self.attn_mask = None
|
| 584 |
+
self.compress_mask = None
|
| 585 |
+
self.chunk_mask = None
|
| 586 |
+
|
| 587 |
+
assert max_query_len > 0, "query_lens: {}".format(query_lens)
|
| 588 |
+
|
| 589 |
+
assert device is not None
|
| 590 |
+
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
|
| 591 |
+
device, self.runner.pin_memory)
|
| 592 |
+
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
| 593 |
+
self.runner.pin_memory)
|
| 594 |
+
placeholder_index_maps = {
|
| 595 |
+
modality: placeholder_map.index_map()
|
| 596 |
+
for modality, placeholder_map in
|
| 597 |
+
self.multimodal_placeholder_maps.items()
|
| 598 |
+
}
|
| 599 |
+
|
| 600 |
+
return AscendMetadata(
|
| 601 |
+
num_prefills=self.num_prefills,
|
| 602 |
+
slot_mapping=slot_mapping_tensor,
|
| 603 |
+
num_prefill_tokens=self.num_prefill_tokens,
|
| 604 |
+
num_decode_tokens=num_decode_tokens,
|
| 605 |
+
seq_lens=seq_lens,
|
| 606 |
+
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
| 607 |
+
enable_kv_scales_calculation=True,
|
| 608 |
+
seq_lens_tensor=seq_lens_tensor,
|
| 609 |
+
query_lens=query_lens,
|
| 610 |
+
max_query_len=max_query_len,
|
| 611 |
+
max_prefill_seq_len=max_prefill_seq_len,
|
| 612 |
+
max_decode_seq_len=max_decode_seq_len,
|
| 613 |
+
block_tables=block_tables,
|
| 614 |
+
attn_mask=self.attn_mask,
|
| 615 |
+
compress_mask=self.compress_mask,
|
| 616 |
+
chunk_mask=self.chunk_mask,
|
| 617 |
+
chunked_prefill_enabled=self.input_builder.chunked_prefill_enabled,
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class AscendAttentionBackendImpl(AttentionImpl):
|
| 622 |
+
|
| 623 |
+
def __init__(
|
| 624 |
+
self,
|
| 625 |
+
num_heads: int,
|
| 626 |
+
head_size: int,
|
| 627 |
+
scale: float,
|
| 628 |
+
num_kv_heads: int,
|
| 629 |
+
alibi_slopes: Optional[List[float]],
|
| 630 |
+
sliding_window: Optional[int],
|
| 631 |
+
kv_cache_dtype: str,
|
| 632 |
+
blocksparse_params: Optional[Dict[str, Any]] = None,
|
| 633 |
+
logits_soft_cap: Optional[float] = None,
|
| 634 |
+
attn_type: str = AttentionType.DECODER,
|
| 635 |
+
kv_sharing_target_layer_name: Optional[str] = None,
|
| 636 |
+
use_irope: bool = False,
|
| 637 |
+
) -> None:
|
| 638 |
+
self.num_heads = num_heads
|
| 639 |
+
self.head_size = head_size
|
| 640 |
+
self.scale = float(scale)
|
| 641 |
+
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
| 642 |
+
self.hidden_size = self.num_heads * self.head_size
|
| 643 |
+
self.kv_cache_dtype = kv_cache_dtype
|
| 644 |
+
self.sliding_window = sliding_window
|
| 645 |
+
if alibi_slopes is not None:
|
| 646 |
+
alibi_slopes = torch.tensor(alibi_slopes,
|
| 647 |
+
dtype=torch.float32,
|
| 648 |
+
device="npu")
|
| 649 |
+
self.alibi_slopes = alibi_slopes
|
| 650 |
+
self.attn_type = attn_type
|
| 651 |
+
|
| 652 |
+
assert self.num_heads % self.num_kv_heads == 0
|
| 653 |
+
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
| 654 |
+
self.seq_len_cpu_tensor = None
|
| 655 |
+
self.query_len_cpu_tensor = None
|
| 656 |
+
self.key_cache = None
|
| 657 |
+
self.value_cache = None
|
| 658 |
+
|
| 659 |
+
def forward(
|
| 660 |
+
self,
|
| 661 |
+
layer: AttentionLayer,
|
| 662 |
+
query: torch.Tensor,
|
| 663 |
+
key: torch.Tensor,
|
| 664 |
+
value: torch.Tensor,
|
| 665 |
+
kv_cache: torch.Tensor,
|
| 666 |
+
attn_metadata: AscendMetadata,
|
| 667 |
+
attn_type: str = AttentionType.DECODER,
|
| 668 |
+
output: Optional[torch.Tensor] = None,
|
| 669 |
+
) -> torch.Tensor:
|
| 670 |
+
"""Forward pass with Ascend attention.
|
| 671 |
+
Args:
|
| 672 |
+
query: shape = [num_tokens, num_heads * head_size]
|
| 673 |
+
num_tokens = batch_size * seq_len
|
| 674 |
+
key: shape = [num_tokens, num_kv_heads * head_size]
|
| 675 |
+
value: shape = [num_tokens, num_kv_heads * head_size]
|
| 676 |
+
kv_cache: shape = [2, num_blocks, block_size,
|
| 677 |
+
num_kv_heads, head_size]
|
| 678 |
+
key_cache = [num_blocks, block_size,
|
| 679 |
+
num_kv_heads, head_size]
|
| 680 |
+
value_cache = [num_blocks, block_size,
|
| 681 |
+
num_kv_heads, head_size]
|
| 682 |
+
attn_metadata: Metadata for attention.
|
| 683 |
+
Returns:
|
| 684 |
+
shape = [batch_size, seq_len * num_heads * head_size]
|
| 685 |
+
"""
|
| 686 |
+
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
| 687 |
+
# View q k v to BSH.
|
| 688 |
+
num_tokens = query.shape[0]
|
| 689 |
+
query = query.view(-1, self.num_heads, self.head_size)
|
| 690 |
+
key = key.view(-1, self.num_kv_heads, self.head_size)
|
| 691 |
+
value = value.view(-1, self.num_kv_heads, self.head_size)
|
| 692 |
+
# TODO: Remove this contiguous in the future.
|
| 693 |
+
value = value.contiguous()
|
| 694 |
+
attn_type = self.attn_type
|
| 695 |
+
|
| 696 |
+
output = torch.empty(num_tokens,
|
| 697 |
+
self.num_heads,
|
| 698 |
+
self.head_size,
|
| 699 |
+
dtype=query.dtype,
|
| 700 |
+
device=query.device)
|
| 701 |
+
|
| 702 |
+
if kv_cache.numel() > 0:
|
| 703 |
+
if self.key_cache is None:
|
| 704 |
+
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
| 705 |
+
slots = attn_metadata.slot_mapping
|
| 706 |
+
|
| 707 |
+
if hasattr(layer, 'quant_method'):
|
| 708 |
+
isPrefill = True if attn_metadata.num_prefills > 0 else False
|
| 709 |
+
if isPrefill:
|
| 710 |
+
assert attn_metadata.prefill_metadata is not None
|
| 711 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 712 |
+
np.array(attn_metadata.prefill_metadata.seq_lens).astype(
|
| 713 |
+
np.int32))
|
| 714 |
+
else:
|
| 715 |
+
assert attn_metadata.decode_metadata is not None
|
| 716 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 717 |
+
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
| 718 |
+
np.int32))
|
| 719 |
+
block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
|
| 720 |
+
# Details of kv_cache arrangement in attention quantization
|
| 721 |
+
# are implemented by quant_method.
|
| 722 |
+
layer.quant_method.apply(
|
| 723 |
+
layer,
|
| 724 |
+
query,
|
| 725 |
+
key,
|
| 726 |
+
value,
|
| 727 |
+
self.key_cache,
|
| 728 |
+
self.value_cache,
|
| 729 |
+
self.scale,
|
| 730 |
+
block_tables,
|
| 731 |
+
isPrefill,
|
| 732 |
+
attn_metadata,
|
| 733 |
+
output,
|
| 734 |
+
seq_lens_tensor_cpu=self.seq_lens_tensor_cpu)
|
| 735 |
+
else:
|
| 736 |
+
if self.key_cache is not None:
|
| 737 |
+
torch_npu._npu_reshape_and_cache(key=key,
|
| 738 |
+
value=value,
|
| 739 |
+
key_cache=self.key_cache,
|
| 740 |
+
value_cache=self.value_cache,
|
| 741 |
+
slot_indices=slots)
|
| 742 |
+
|
| 743 |
+
if attn_metadata.num_prefills > 0:
|
| 744 |
+
# Prefix cache disabled and chunk prefill disabled or no prefix cache hit
|
| 745 |
+
if (attn_metadata.block_tables is None
|
| 746 |
+
or attn_metadata.block_tables.numel() == 0):
|
| 747 |
+
if attn_type == AttentionType.ENCODER_ONLY:
|
| 748 |
+
# TODO: change to use torch_npu encoder attention op, instead
|
| 749 |
+
# of torch sdpa
|
| 750 |
+
query = query.movedim(0, query.dim() - 2)
|
| 751 |
+
key = key.movedim(0, key.dim() - 2)
|
| 752 |
+
value = value.movedim(0, value.dim() - 2)
|
| 753 |
+
|
| 754 |
+
causal_attn = (attn_type == AttentionType.DECODER)
|
| 755 |
+
if attn_metadata.seq_lens is not None:
|
| 756 |
+
seq_lens_q = seq_lens_kv = attn_metadata.seq_lens
|
| 757 |
+
attn_masks = [None] * len(seq_lens_q)
|
| 758 |
+
start_q, start_kv = 0, 0
|
| 759 |
+
for seq_len_q, seq_len_kv, mask in zip(
|
| 760 |
+
seq_lens_q, seq_lens_kv, attn_masks):
|
| 761 |
+
end_q = start_q + seq_len_q
|
| 762 |
+
end_kv = start_kv + seq_len_kv
|
| 763 |
+
sub_out = scaled_dot_product_attention(
|
| 764 |
+
query[None, :, start_q:end_q, :],
|
| 765 |
+
key[None, :, start_kv:end_kv, :],
|
| 766 |
+
value[None, :, start_kv:end_kv, :],
|
| 767 |
+
attn_mask=mask,
|
| 768 |
+
dropout_p=0.0,
|
| 769 |
+
is_causal=causal_attn and mask is None,
|
| 770 |
+
scale=self.scale).squeeze(0).movedim(
|
| 771 |
+
query.dim() - 2, 0)
|
| 772 |
+
output[start_q:end_q, :, :] = sub_out
|
| 773 |
+
start_q, start_kv = end_q, end_kv
|
| 774 |
+
else:
|
| 775 |
+
assert attn_metadata.attn_mask is not None
|
| 776 |
+
mask = attn_metadata.attn_mask
|
| 777 |
+
assert attn_metadata.prefill_metadata is not None
|
| 778 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 779 |
+
np.array(attn_metadata.prefill_metadata.seq_lens).
|
| 780 |
+
astype(np.int32))
|
| 781 |
+
if is_310p():
|
| 782 |
+
# align q k v output tensors
|
| 783 |
+
query = aligned_16(query)
|
| 784 |
+
key = aligned_16(key)
|
| 785 |
+
value = aligned_16(value)
|
| 786 |
+
output = aligned_16(output)
|
| 787 |
+
|
| 788 |
+
# do reformat in case of broadcasted tensors
|
| 789 |
+
mask = mask.repeat(
|
| 790 |
+
self.seq_lens_tensor_cpu.size(0), 1, 1, 1)
|
| 791 |
+
mask = torch_npu.npu_format_cast(
|
| 792 |
+
mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
|
| 793 |
+
torch_npu._npu_flash_attention(
|
| 794 |
+
query=query,
|
| 795 |
+
key=key,
|
| 796 |
+
value=value,
|
| 797 |
+
mask=mask,
|
| 798 |
+
seq_len=self.seq_lens_tensor_cpu,
|
| 799 |
+
scale_value=self.scale,
|
| 800 |
+
num_heads=self.num_heads,
|
| 801 |
+
num_kv_heads=self.num_kv_heads,
|
| 802 |
+
out=output)
|
| 803 |
+
output = output[:num_tokens, :, :]
|
| 804 |
+
# Prefix cache only and cache hit
|
| 805 |
+
elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled:
|
| 806 |
+
assert kv_cache is not None
|
| 807 |
+
assert attn_metadata.prefill_metadata is not None
|
| 808 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 809 |
+
np.array(
|
| 810 |
+
attn_metadata.prefill_metadata.seq_lens).astype(
|
| 811 |
+
np.int32))
|
| 812 |
+
self.query_lens_tensor_cpu = torch.from_numpy(
|
| 813 |
+
np.array(
|
| 814 |
+
attn_metadata.prefill_metadata.query_lens).astype(
|
| 815 |
+
np.int32))
|
| 816 |
+
block_tables = attn_metadata.prefill_metadata.block_tables
|
| 817 |
+
assert attn_metadata.compress_mask is not None
|
| 818 |
+
compress_mask = attn_metadata.compress_mask
|
| 819 |
+
torch_npu._npu_flash_attention_qlens(
|
| 820 |
+
query=query,
|
| 821 |
+
key_cache=self.key_cache,
|
| 822 |
+
value_cache=self.value_cache,
|
| 823 |
+
block_table=block_tables,
|
| 824 |
+
mask=compress_mask,
|
| 825 |
+
seq_len=self.query_lens_tensor_cpu,
|
| 826 |
+
context_lens=self.seq_lens_tensor_cpu,
|
| 827 |
+
num_kv_heads=self.num_kv_heads,
|
| 828 |
+
num_heads=self.num_heads,
|
| 829 |
+
scale_value=self.scale,
|
| 830 |
+
out=output)
|
| 831 |
+
# Splitfuse
|
| 832 |
+
else:
|
| 833 |
+
assert kv_cache is not None
|
| 834 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 835 |
+
np.array(attn_metadata.seq_lens).astype(np.int32))
|
| 836 |
+
self.query_lens_tensor_cpu = torch.from_numpy(
|
| 837 |
+
np.array(attn_metadata.query_lens).astype(np.int32))
|
| 838 |
+
block_tables = attn_metadata.block_tables
|
| 839 |
+
assert attn_metadata.chunk_mask is not None
|
| 840 |
+
chunk_mask = attn_metadata.chunk_mask
|
| 841 |
+
torch_npu._npu_paged_attention_splitfuse(
|
| 842 |
+
query=query,
|
| 843 |
+
key_cache=self.key_cache,
|
| 844 |
+
value_cache=self.value_cache,
|
| 845 |
+
block_table=block_tables,
|
| 846 |
+
context_lens=self.seq_lens_tensor_cpu,
|
| 847 |
+
mask=chunk_mask,
|
| 848 |
+
seq_len=self.query_lens_tensor_cpu,
|
| 849 |
+
num_kv_heads=self.num_kv_heads,
|
| 850 |
+
num_heads=self.num_heads,
|
| 851 |
+
scale_value=self.scale,
|
| 852 |
+
out=output)
|
| 853 |
+
# Decode only
|
| 854 |
+
else:
|
| 855 |
+
assert self.key_cache is not None
|
| 856 |
+
assert self.value_cache is not None
|
| 857 |
+
assert attn_metadata.decode_metadata is not None
|
| 858 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 859 |
+
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
| 860 |
+
np.int32))
|
| 861 |
+
if is_310p():
|
| 862 |
+
# # seq_lens_tensor needs to be transferred to the device for 310P
|
| 863 |
+
self.seq_lens_tensor_cpu = self.seq_lens_tensor_cpu.to(
|
| 864 |
+
device=self.key_cache.device)
|
| 865 |
+
block_tables = attn_metadata.decode_metadata.block_tables
|
| 866 |
+
torch_npu._npu_paged_attention(
|
| 867 |
+
query=query,
|
| 868 |
+
key_cache=self.key_cache,
|
| 869 |
+
value_cache=self.value_cache,
|
| 870 |
+
num_kv_heads=self.num_kv_heads,
|
| 871 |
+
num_heads=self.num_heads,
|
| 872 |
+
scale_value=self.scale,
|
| 873 |
+
block_table=block_tables,
|
| 874 |
+
context_lens=self.seq_lens_tensor_cpu,
|
| 875 |
+
out=output)
|
| 876 |
+
|
| 877 |
+
return output.view(num_tokens, self.hidden_size)
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
|
| 881 |
+
|
| 882 |
+
def __init__(
|
| 883 |
+
self,
|
| 884 |
+
num_heads: int,
|
| 885 |
+
head_size: int,
|
| 886 |
+
scale: float,
|
| 887 |
+
num_kv_heads: int,
|
| 888 |
+
alibi_slopes: Optional[List[float]],
|
| 889 |
+
sliding_window: Optional[int],
|
| 890 |
+
kv_cache_dtype: str,
|
| 891 |
+
blocksparse_params: Optional[Dict[str, Any]] = None,
|
| 892 |
+
logits_soft_cap: Optional[float] = None,
|
| 893 |
+
attn_type: str = AttentionType.DECODER,
|
| 894 |
+
kv_sharing_target_layer_name: Optional[str] = None,
|
| 895 |
+
**extra_impl_args,
|
| 896 |
+
) -> None:
|
| 897 |
+
self.num_heads = num_heads
|
| 898 |
+
self.head_size = head_size
|
| 899 |
+
self.scale = float(scale)
|
| 900 |
+
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
| 901 |
+
self.hidden_size = self.num_heads * self.head_size
|
| 902 |
+
self.kv_cache_dtype = kv_cache_dtype
|
| 903 |
+
self.sliding_window = sliding_window
|
| 904 |
+
if alibi_slopes is not None:
|
| 905 |
+
alibi_slopes = torch.tensor(alibi_slopes,
|
| 906 |
+
dtype=torch.float32,
|
| 907 |
+
device="npu")
|
| 908 |
+
self.alibi_slopes = alibi_slopes
|
| 909 |
+
self.attn_type = attn_type
|
| 910 |
+
|
| 911 |
+
assert self.num_heads % self.num_kv_heads == 0
|
| 912 |
+
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
| 913 |
+
self.seq_len_cpu_tensor = None
|
| 914 |
+
|
| 915 |
+
# MLA Args
|
| 916 |
+
self.q_lora_rank = extra_impl_args['q_lora_rank']
|
| 917 |
+
self.kv_lora_rank = extra_impl_args['kv_lora_rank']
|
| 918 |
+
self.qk_nope_head_dim = extra_impl_args['qk_nope_head_dim']
|
| 919 |
+
self.qk_rope_head_dim = extra_impl_args['qk_rope_head_dim']
|
| 920 |
+
self.qk_head_dim = extra_impl_args['qk_head_dim']
|
| 921 |
+
self.v_head_dim = extra_impl_args['v_head_dim']
|
| 922 |
+
self.rotary_emb = extra_impl_args['rotary_emb']
|
| 923 |
+
self.q_proj = extra_impl_args['q_proj']
|
| 924 |
+
self.kv_b_proj = extra_impl_args['kv_b_proj']
|
| 925 |
+
self.o_proj = extra_impl_args['o_proj']
|
| 926 |
+
self.kv_a_proj_with_mqa = extra_impl_args.get('kv_a_proj_with_mqa',
|
| 927 |
+
None)
|
| 928 |
+
self.kv_a_layernorm = extra_impl_args.get('kv_a_layernorm', None)
|
| 929 |
+
self.k_pe_cache = None
|
| 930 |
+
self.k_nope_cache = None
|
| 931 |
+
self.w_kc = None
|
| 932 |
+
self.w_vc = None
|
| 933 |
+
|
| 934 |
+
ascend_config = get_ascend_config()
|
| 935 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 936 |
+
|
| 937 |
+
|
| 938 |
+
def exec_kv(
|
| 939 |
+
self,
|
| 940 |
+
hidden_states: torch.Tensor,
|
| 941 |
+
cos: torch.Tensor,
|
| 942 |
+
sin: torch.Tensor,
|
| 943 |
+
kv_cache: Tuple,
|
| 944 |
+
slots: torch.Tensor,
|
| 945 |
+
):
|
| 946 |
+
B = hidden_states.shape[0]
|
| 947 |
+
N = self.num_kv_heads
|
| 948 |
+
S = 1
|
| 949 |
+
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
| 950 |
+
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
| 951 |
+
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
| 952 |
+
|
| 953 |
+
k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
|
| 954 |
+
kv,
|
| 955 |
+
self.kv_a_layernorm.weight,
|
| 956 |
+
cos,
|
| 957 |
+
sin,
|
| 958 |
+
slots.to(torch.int64),
|
| 959 |
+
kv_cache[1],
|
| 960 |
+
kv_cache[0],
|
| 961 |
+
epsilon=self.kv_a_layernorm.variance_epsilon,
|
| 962 |
+
cache_mode="PA",
|
| 963 |
+
)
|
| 964 |
+
|
| 965 |
+
return k_pe, k_nope
|
| 966 |
+
|
| 967 |
+
def apply_rotary_emb(
|
| 968 |
+
self,
|
| 969 |
+
x: torch.Tensor,
|
| 970 |
+
cos: torch.Tensor,
|
| 971 |
+
sin: torch.Tensor,
|
| 972 |
+
is_neox_style: bool,
|
| 973 |
+
) -> torch.Tensor:
|
| 974 |
+
"""
|
| 975 |
+
Args:
|
| 976 |
+
x: [num_tokens, num_heads, head_size]
|
| 977 |
+
cos: [num_tokens, head_size // 2]
|
| 978 |
+
sin: [num_tokens, head_size // 2]
|
| 979 |
+
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
| 980 |
+
positional embeddings.
|
| 981 |
+
"""
|
| 982 |
+
cos = cos.unsqueeze(-2).to(x.dtype)
|
| 983 |
+
sin = sin.unsqueeze(-2).to(x.dtype)
|
| 984 |
+
if is_neox_style:
|
| 985 |
+
x1, x2 = torch.chunk(x, 2, dim=-1)
|
| 986 |
+
else:
|
| 987 |
+
x1 = x[..., ::2]
|
| 988 |
+
x2 = x[..., 1::2]
|
| 989 |
+
o1 = x1 * cos - x2 * sin
|
| 990 |
+
o2 = x2 * cos + x1 * sin
|
| 991 |
+
if is_neox_style:
|
| 992 |
+
return torch.cat((o1, o2), dim=-1)
|
| 993 |
+
else:
|
| 994 |
+
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
| 995 |
+
|
| 996 |
+
def rope_single(
|
| 997 |
+
self,
|
| 998 |
+
x: torch.Tensor,
|
| 999 |
+
cos: torch.Tensor,
|
| 1000 |
+
sin: torch.Tensor,
|
| 1001 |
+
) -> torch.Tensor:
|
| 1002 |
+
B, N, D = x.shape
|
| 1003 |
+
S = 1
|
| 1004 |
+
x = x.view(B, N, S, D)
|
| 1005 |
+
x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
|
| 1006 |
+
return x.view(B, N, D)
|
| 1007 |
+
|
| 1008 |
+
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
| 1009 |
+
if self.w_kc is None or self.w_vc is None:
|
| 1010 |
+
kv_b_proj_weight = self.kv_b_proj.weight.reshape(
|
| 1011 |
+
self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
|
| 1012 |
+
self.kv_lora_rank)
|
| 1013 |
+
self.w_kc = kv_b_proj_weight[:, :self.
|
| 1014 |
+
qk_nope_head_dim, :].contiguous()
|
| 1015 |
+
self.w_vc = kv_b_proj_weight[:,
|
| 1016 |
+
self.qk_nope_head_dim:, :].transpose(
|
| 1017 |
+
1, 2).contiguous()
|
| 1018 |
+
|
| 1019 |
+
def forward(
|
| 1020 |
+
self,
|
| 1021 |
+
layer: AttentionLayer,
|
| 1022 |
+
hidden_states_or_q_c: torch.Tensor,
|
| 1023 |
+
hidden_states_or_kv_c_normed: torch.Tensor,
|
| 1024 |
+
k_pe: torch.Tensor,
|
| 1025 |
+
kv_cache: torch.Tensor,
|
| 1026 |
+
attn_metadata: AscendMetadata,
|
| 1027 |
+
attn_type: str = AttentionType.DECODER,
|
| 1028 |
+
output: Optional[torch.Tensor] = None,
|
| 1029 |
+
) -> torch.Tensor:
|
| 1030 |
+
"""Forward pass with Ascend attention.
|
| 1031 |
+
Args:
|
| 1032 |
+
hidden_states_or_q_c: shape = [num_tokens, num_heads * head_size]
|
| 1033 |
+
num_tokens = batch_size * seq_len
|
| 1034 |
+
hidden_states_or_kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
|
| 1035 |
+
k_pe: shape = [num_tokens, num_kv_heads * head_size]
|
| 1036 |
+
kv_cache: shape = [1, num_blocks, block_size,
|
| 1037 |
+
num_kv_heads * head_size]
|
| 1038 |
+
attn_metadata: Metadata for attention.
|
| 1039 |
+
Returns:
|
| 1040 |
+
shape = [batch_size, seq_len * num_heads * head_size]
|
| 1041 |
+
"""
|
| 1042 |
+
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
| 1043 |
+
attn_type = self.attn_type
|
| 1044 |
+
if attn_type != AttentionType.DECODER:
|
| 1045 |
+
raise NotImplementedError("Encoder self-attention and "
|
| 1046 |
+
"encoder/decoder cross-attention "
|
| 1047 |
+
"are not implemented for "
|
| 1048 |
+
"PallasAttentionBackendImpl")
|
| 1049 |
+
|
| 1050 |
+
if attn_metadata is None:
|
| 1051 |
+
# for profile run
|
| 1052 |
+
return hidden_states_or_q_c
|
| 1053 |
+
|
| 1054 |
+
num_tokens = hidden_states_or_q_c.shape[0]
|
| 1055 |
+
q = self.q_proj(hidden_states_or_q_c)[0].view(-1, self.num_heads,
|
| 1056 |
+
self.qk_head_dim)
|
| 1057 |
+
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
|
| 1058 |
+
dim=-1)
|
| 1059 |
+
if k_pe is None and attn_metadata.decode_metadata:
|
| 1060 |
+
seq_len = self.rotary_emb.max_position_embeddings
|
| 1061 |
+
|
| 1062 |
+
cos = self.rotary_emb.cos_cached[:seq_len].to(dtype=q_pe.dtype)
|
| 1063 |
+
sin = self.rotary_emb.sin_cached[:seq_len].to(dtype=q_pe.dtype)
|
| 1064 |
+
cos = cos[attn_metadata.input_positions]
|
| 1065 |
+
sin = sin[attn_metadata.input_positions]
|
| 1066 |
+
cos = cos[:, None, None, :]
|
| 1067 |
+
sin = sin[:, None, None, :]
|
| 1068 |
+
|
| 1069 |
+
q_pe = self.rope_single(q_pe, cos, sin)
|
| 1070 |
+
k_pe, k_nope = self.exec_kv(hidden_states_or_kv_c_normed, cos, sin,
|
| 1071 |
+
kv_cache, attn_metadata.slot_mapping)
|
| 1072 |
+
else:
|
| 1073 |
+
if k_pe is None:
|
| 1074 |
+
# NOTE: k_pe is None when graph mode enabled
|
| 1075 |
+
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
| 1076 |
+
hidden_states_or_kv_c_normed)[0].split(
|
| 1077 |
+
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
| 1078 |
+
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
| 1079 |
+
else:
|
| 1080 |
+
kv_c_normed = hidden_states_or_kv_c_normed
|
| 1081 |
+
k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
|
| 1082 |
+
if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
|
| 1083 |
+
# NOTE: When scaling not specified
|
| 1084 |
+
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
|
| 1085 |
+
q_pe = q_pe.reshape(num_tokens, -1)
|
| 1086 |
+
k_pe = k_pe.reshape(num_tokens, -1)
|
| 1087 |
+
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
|
| 1088 |
+
q_pe, k_pe)
|
| 1089 |
+
q_pe = q_pe.view(ori_q_pe_shape)
|
| 1090 |
+
k_pe = k_pe.view(ori_k_pe_shape)
|
| 1091 |
+
else:
|
| 1092 |
+
q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
|
| 1093 |
+
q_pe, k_pe)
|
| 1094 |
+
|
| 1095 |
+
if attn_metadata.num_prefills > 0:
|
| 1096 |
+
kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens,
|
| 1097 |
+
self.num_heads, -1)
|
| 1098 |
+
k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim],
|
| 1099 |
+
dim=-1)
|
| 1100 |
+
else:
|
| 1101 |
+
q_nope_t = torch.transpose(q_nope, 0, 1)
|
| 1102 |
+
q_nope_out = torch.bmm(q_nope_t, self.w_kc)
|
| 1103 |
+
q_nope = torch.transpose(q_nope_out, 0, 1)
|
| 1104 |
+
|
| 1105 |
+
query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens,
|
| 1106 |
+
self.num_heads, -1)
|
| 1107 |
+
|
| 1108 |
+
# TODO: Replace the env with more flexible expressions
|
| 1109 |
+
if self.torchair_graph_enabled:
|
| 1110 |
+
if len(kv_cache) > 0 and kv_cache[0].numel(
|
| 1111 |
+
) > 0 and attn_metadata.num_prefills > 0:
|
| 1112 |
+
slots = attn_metadata.slot_mapping
|
| 1113 |
+
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
|
| 1114 |
+
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
|
| 1115 |
+
num_tokens, self.num_kv_heads, -1),
|
| 1116 |
+
value=k_pe,
|
| 1117 |
+
key_cache=kv_cache[0],
|
| 1118 |
+
value_cache=kv_cache[1],
|
| 1119 |
+
slot_indices=slots)
|
| 1120 |
+
elif kv_cache.numel() > 0:
|
| 1121 |
+
# TODO replace this naive implement with fusion kernel
|
| 1122 |
+
concat_and_cache_mla(kv_c_normed, k_pe, kv_cache,
|
| 1123 |
+
attn_metadata.slot_mapping)
|
| 1124 |
+
|
| 1125 |
+
if attn_metadata.num_prefills > 0:
|
| 1126 |
+
attn_output = torch.empty(num_tokens,
|
| 1127 |
+
self.num_heads,
|
| 1128 |
+
self.v_head_dim,
|
| 1129 |
+
dtype=query.dtype,
|
| 1130 |
+
device=query.device)
|
| 1131 |
+
if (attn_metadata.block_tables is None
|
| 1132 |
+
or attn_metadata.block_tables.numel() == 0):
|
| 1133 |
+
assert attn_metadata.attn_mask is not None
|
| 1134 |
+
assert attn_metadata.prefill_metadata is not None
|
| 1135 |
+
assert attn_metadata.prefill_metadata.seq_lens is not None
|
| 1136 |
+
mask = attn_metadata.attn_mask
|
| 1137 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 1138 |
+
np.array(attn_metadata.prefill_metadata.seq_lens).astype(
|
| 1139 |
+
np.int32))
|
| 1140 |
+
k_pe = k_pe.repeat(1, self.num_heads, 1)
|
| 1141 |
+
key = torch.cat(
|
| 1142 |
+
[k_nope.view(num_tokens, self.num_heads, -1), k_pe], dim=2)
|
| 1143 |
+
torch_npu._npu_flash_attention(
|
| 1144 |
+
query=query,
|
| 1145 |
+
key=key,
|
| 1146 |
+
value=value,
|
| 1147 |
+
mask=mask,
|
| 1148 |
+
seq_len=self.seq_lens_tensor_cpu,
|
| 1149 |
+
scale_value=self.scale,
|
| 1150 |
+
num_heads=self.num_heads,
|
| 1151 |
+
num_kv_heads=self.num_heads,
|
| 1152 |
+
out=attn_output)
|
| 1153 |
+
else:
|
| 1154 |
+
# TODO: Will support prefix cache and chunked prefill soon.
|
| 1155 |
+
raise RuntimeError(
|
| 1156 |
+
"Prefix cache and chunked prefill are currently not supported."
|
| 1157 |
+
)
|
| 1158 |
+
elif attn_metadata.decode_metadata:
|
| 1159 |
+
assert kv_cache is not None
|
| 1160 |
+
if self.torchair_graph_enabled:
|
| 1161 |
+
# shape of query for npu graph mode should be:
|
| 1162 |
+
# [bs, num_heads_per_rank, seq_len, dim]
|
| 1163 |
+
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
| 1164 |
+
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
| 1165 |
+
# shape of knope/k_pe for npu graph mode should be:
|
| 1166 |
+
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
| 1167 |
+
block_size = kv_cache[0].shape[1]
|
| 1168 |
+
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
| 1169 |
+
self.kv_lora_rank)
|
| 1170 |
+
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
| 1171 |
+
self.qk_rope_head_dim)
|
| 1172 |
+
attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
|
| 1173 |
+
q_nope,
|
| 1174 |
+
k_nope,
|
| 1175 |
+
k_nope,
|
| 1176 |
+
query_rope=q_pe,
|
| 1177 |
+
key_rope=k_pe,
|
| 1178 |
+
num_heads=self.num_heads,
|
| 1179 |
+
num_key_value_heads=self.num_kv_heads,
|
| 1180 |
+
input_layout="BNSD",
|
| 1181 |
+
atten_mask=attn_metadata.attn_mask,
|
| 1182 |
+
scale=self.scale,
|
| 1183 |
+
antiquant_mode=0,
|
| 1184 |
+
antiquant_scale=None,
|
| 1185 |
+
block_table=attn_metadata.block_tables,
|
| 1186 |
+
block_size=block_size,
|
| 1187 |
+
actual_seq_lengths_kv=attn_metadata.seq_lens,
|
| 1188 |
+
)
|
| 1189 |
+
attn_output = attn_output.view(num_tokens, -1,
|
| 1190 |
+
self.kv_lora_rank).transpose(
|
| 1191 |
+
0, 1)
|
| 1192 |
+
attn_output = torch.bmm(attn_output, self.w_vc).transpose(0, 1)
|
| 1193 |
+
else:
|
| 1194 |
+
# if torch.empty is used here, the preemptive scheduling case of
|
| 1195 |
+
# test_mtp_correctness.py will fail to run.
|
| 1196 |
+
attn_output = torch.randn(
|
| 1197 |
+
[num_tokens, self.num_heads, self.kv_lora_rank],
|
| 1198 |
+
dtype=query.dtype,
|
| 1199 |
+
device=query.device)
|
| 1200 |
+
self.seq_lens_tensor_cpu = torch.from_numpy(
|
| 1201 |
+
np.array(attn_metadata.decode_metadata.seq_lens).astype(
|
| 1202 |
+
np.int32))
|
| 1203 |
+
block_tables = attn_metadata.decode_metadata.block_tables
|
| 1204 |
+
torch_npu._npu_paged_attention_mla(
|
| 1205 |
+
query=query,
|
| 1206 |
+
key_cache=kv_cache,
|
| 1207 |
+
num_kv_heads=self.num_kv_heads,
|
| 1208 |
+
num_heads=self.num_heads,
|
| 1209 |
+
scale_value=self.scale,
|
| 1210 |
+
block_table=block_tables,
|
| 1211 |
+
context_lens=self.seq_lens_tensor_cpu,
|
| 1212 |
+
mla_vheadsize=self.kv_lora_rank,
|
| 1213 |
+
out=attn_output)
|
| 1214 |
+
attn_output_t = torch.transpose(attn_output, 0, 1)
|
| 1215 |
+
attn_output_t = torch.bmm(attn_output_t, self.w_vc)
|
| 1216 |
+
attn_output = torch.transpose(attn_output_t, 0, 1)
|
| 1217 |
+
|
| 1218 |
+
output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))
|
| 1219 |
+
|
| 1220 |
+
return output
|
inference/vllm_ascend/attention/mla_v1.py
ADDED
|
@@ -0,0 +1,1224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch_npu
|
| 7 |
+
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
| 8 |
+
AttentionMetadata,
|
| 9 |
+
MLAAttentionImpl)
|
| 10 |
+
from vllm.attention.backends.utils import PAD_SLOT_ID
|
| 11 |
+
from vllm.config import get_current_vllm_config
|
| 12 |
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
| 13 |
+
from vllm.model_executor.layers.linear import (LinearBase,
|
| 14 |
+
UnquantizedLinearMethod)
|
| 15 |
+
from vllm.utils import cdiv, round_down
|
| 16 |
+
|
| 17 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 18 |
+
from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
|
| 19 |
+
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
| 20 |
+
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
| 21 |
+
from vllm_ascend.multistream.context import get_multistream_comm_context
|
| 22 |
+
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
| 23 |
+
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
| 24 |
+
from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
|
| 25 |
+
from vllm_ascend.worker.npu_input_batch import InputBatch
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from vllm.v1.core.sched.output import SchedulerOutput
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class CommonAttentionMetadata:
|
| 33 |
+
"""
|
| 34 |
+
Attention metadata attributes that can be shared by layers in different KV
|
| 35 |
+
cache groups and thus having different block table.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
query_start_loc: torch.Tensor
|
| 39 |
+
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
| 40 |
+
seq_lens: torch.Tensor
|
| 41 |
+
"""(batch_size,), the length of each request including both computed tokens
|
| 42 |
+
and newly scheduled tokens"""
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class AscendMLABackend(AttentionBackend):
|
| 46 |
+
|
| 47 |
+
accept_output_buffer: bool = True
|
| 48 |
+
|
| 49 |
+
@staticmethod
|
| 50 |
+
def get_name() -> str:
|
| 51 |
+
return "VLLM_ASCEND_MLA"
|
| 52 |
+
|
| 53 |
+
@staticmethod
|
| 54 |
+
def get_metadata_cls() -> type["AttentionMetadata"]:
|
| 55 |
+
return AscendMLAMetadata
|
| 56 |
+
|
| 57 |
+
@staticmethod
|
| 58 |
+
def get_builder_cls():
|
| 59 |
+
return AscendMLAMetadataBuilder
|
| 60 |
+
|
| 61 |
+
@staticmethod
|
| 62 |
+
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
|
| 63 |
+
head_size: int) -> tuple[int, ...]:
|
| 64 |
+
return (num_blocks, block_size, num_kv_heads, head_size)
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def get_impl_cls() -> Type["MLAAttentionImpl"]:
|
| 68 |
+
return AscendMLAImpl
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
|
| 72 |
+
class AscendMLAPrefillMetadata:
|
| 73 |
+
""" Prefill Specific Metadata for Ascend"""
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class ChunkedContextMetadata:
|
| 77 |
+
# New for MLA (compared to FlashAttention)
|
| 78 |
+
# For handling chunked prefill
|
| 79 |
+
cu_seq_lens: torch.Tensor
|
| 80 |
+
starts: torch.Tensor
|
| 81 |
+
seq_tot: list[int]
|
| 82 |
+
max_seq_lens: list[int]
|
| 83 |
+
workspace: torch.Tensor
|
| 84 |
+
chunk_seq_lens: torch.Tensor
|
| 85 |
+
|
| 86 |
+
attn_mask: torch.Tensor
|
| 87 |
+
query_lens: list[int]
|
| 88 |
+
seq_lens: list[int]
|
| 89 |
+
context_lens: torch.Tensor
|
| 90 |
+
input_positions: torch.Tensor
|
| 91 |
+
query_start_loc: torch.Tensor
|
| 92 |
+
block_table: torch.Tensor
|
| 93 |
+
max_query_len: int
|
| 94 |
+
max_seq_lens: int
|
| 95 |
+
chunked_context: Optional[ChunkedContextMetadata] = None
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@dataclass
|
| 99 |
+
class AscendMLADecodeMetadata:
|
| 100 |
+
# Input positions for rotrary embeddings since for MLA the rotary
|
| 101 |
+
# position embeddings are applied inside the attention backend
|
| 102 |
+
input_positions: torch.Tensor
|
| 103 |
+
block_table: torch.Tensor
|
| 104 |
+
seq_lens: torch.Tensor
|
| 105 |
+
max_seq_lens: int
|
| 106 |
+
seq_lens_list: list[int]
|
| 107 |
+
attn_mask: Optional[torch.Tensor] = None
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@dataclass
|
| 111 |
+
class AscendMLAMetadata:
|
| 112 |
+
"""Metadata for MLACommon.
|
| 113 |
+
|
| 114 |
+
NOTE: Please read the comment at the top of the file before trying to
|
| 115 |
+
understand this class
|
| 116 |
+
"""
|
| 117 |
+
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
| 118 |
+
# |---------- N-1 iteration --------|
|
| 119 |
+
# |---------------- N iteration ---------------------|
|
| 120 |
+
# |- tokenA -|......................|-- newTokens ---|
|
| 121 |
+
# |---------- context_len ----------|
|
| 122 |
+
# |-------------------- seq_len ---------------------|
|
| 123 |
+
# |-- query_len ---|
|
| 124 |
+
|
| 125 |
+
num_actual_tokens: int # Number of tokens excluding padding.
|
| 126 |
+
slot_mapping: torch.Tensor
|
| 127 |
+
query_start_loc: torch.Tensor
|
| 128 |
+
seq_lens: torch.Tensor
|
| 129 |
+
block_tables: torch.Tensor
|
| 130 |
+
|
| 131 |
+
# New for MLA (compared to FlashAttention)
|
| 132 |
+
# For handling prefill decode split
|
| 133 |
+
num_decodes: int
|
| 134 |
+
num_decode_tokens: int
|
| 135 |
+
num_prefills: int
|
| 136 |
+
|
| 137 |
+
# For logging.
|
| 138 |
+
num_input_tokens: int = 0 # Number of tokens including padding.
|
| 139 |
+
|
| 140 |
+
max_num_tokens_across_dp: int = 0
|
| 141 |
+
with_prefill_across_dp: bool = False
|
| 142 |
+
|
| 143 |
+
query_lens: Optional[list[int]] = None
|
| 144 |
+
# The dimension of the attention heads
|
| 145 |
+
head_dim: Optional[int] = None
|
| 146 |
+
attn_mask: torch.Tensor = None
|
| 147 |
+
# chunked prefill by default if no attn_states passed
|
| 148 |
+
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
| 149 |
+
|
| 150 |
+
decode: Optional[AscendMLADecodeMetadata] = None
|
| 151 |
+
prefill: Optional[AscendMLAPrefillMetadata] = None
|
| 152 |
+
|
| 153 |
+
def __post_init__(self):
|
| 154 |
+
pass
|
| 155 |
+
# supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
|
| 156 |
+
# if self.head_dim is not None and self.head_dim \
|
| 157 |
+
# not in supported_head_sizes:
|
| 158 |
+
# raise ValueError(
|
| 159 |
+
# f"Only {supported_head_sizes} are supported for head_dim,",
|
| 160 |
+
# f"received {self.head_dim}.")
|
| 161 |
+
|
| 162 |
+
def split_metadata_for_multistream(
|
| 163 |
+
self,
|
| 164 |
+
ms_split_config: MSAttentionMetadataSplitConfig,
|
| 165 |
+
) -> list["AscendMLAMetadata"]:
|
| 166 |
+
"""Split metadata for multi-stream with AscendMLAMetadata"""
|
| 167 |
+
return model_input_split_v1_mla_attn(
|
| 168 |
+
ms_split_config=ms_split_config,
|
| 169 |
+
attn_metadata=self,
|
| 170 |
+
_metadata_cls=AscendMLAMetadata,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
M = TypeVar("M", bound=AscendMLAMetadata)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class AscendMLAMetadataBuilder:
|
| 178 |
+
"""
|
| 179 |
+
NOTE: Please read the comment at the top of the file before trying to
|
| 180 |
+
understand this class
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
# _attn_mask_builder = None
|
| 184 |
+
def __init__(self,
|
| 185 |
+
runner,
|
| 186 |
+
metadata_cls: Optional[AscendMLAMetadata] = None):
|
| 187 |
+
self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
|
| 188 |
+
if metadata_cls is not None else AscendMLAMetadata # type: ignore
|
| 189 |
+
self.runner = runner
|
| 190 |
+
scheduler_config = runner.scheduler_config
|
| 191 |
+
model_config = runner.model_config
|
| 192 |
+
self.block_size = runner.block_size
|
| 193 |
+
self.chunked_prefill_enabled = runner.chunked_prefill_enabled
|
| 194 |
+
if self.chunked_prefill_enabled:
|
| 195 |
+
self.chunked_prefill_workspace_size = min(
|
| 196 |
+
# Max sure there is enough for 8 full length request or at least
|
| 197 |
+
# 4 pages of cache per request
|
| 198 |
+
max(8 * model_config.max_model_len,
|
| 199 |
+
4 * scheduler_config.max_num_seqs * self.block_size),
|
| 200 |
+
# For long-context models try not to over-allocate limiting
|
| 201 |
+
# kv-cache space, limiting it to 64k tokens,
|
| 202 |
+
# which would result in the workspace being:
|
| 203 |
+
# 2*(576)*(64*1024) = 144mb
|
| 204 |
+
# (assuming 576 MLA head dim, and fp16)
|
| 205 |
+
# which would result in up-projected context being
|
| 206 |
+
# 2*(192*128)*(64*1024) = 3gb
|
| 207 |
+
# (assuming 192 QK head dim, 128 heads, and fp16)
|
| 208 |
+
128 * 1024)
|
| 209 |
+
assert self.chunked_prefill_workspace_size >= \
|
| 210 |
+
scheduler_config.max_num_seqs * self.block_size
|
| 211 |
+
self.chunked_prefill_workspace = torch.empty(
|
| 212 |
+
(self.chunked_prefill_workspace_size,
|
| 213 |
+
model_config.get_head_size()),
|
| 214 |
+
dtype=model_config.dtype,
|
| 215 |
+
device=runner.device,
|
| 216 |
+
)
|
| 217 |
+
ascend_config = get_ascend_config()
|
| 218 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 219 |
+
|
| 220 |
+
def reorder_batch(self, input_batch: "InputBatch",
|
| 221 |
+
scheduler_output: "SchedulerOutput") -> bool:
|
| 222 |
+
# We now want to reorder the batch so that the "decode" requests are at
|
| 223 |
+
# the front and the "prefill" requests are at the using the least amount
|
| 224 |
+
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
| 225 |
+
# where attention is likely memory-bound and "prefill" to mean requests
|
| 226 |
+
# where attention is likely compute-bound, TODO(lucas): figure out a
|
| 227 |
+
# better naming here)
|
| 228 |
+
decodes = []
|
| 229 |
+
prefills = []
|
| 230 |
+
num_decode_tokens = 0
|
| 231 |
+
num_prefill_tokens = 0
|
| 232 |
+
|
| 233 |
+
for i, req_id in enumerate(input_batch.req_ids):
|
| 234 |
+
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
| 235 |
+
num_spec_tokens = len(
|
| 236 |
+
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
| 237 |
+
# For torch air graph mode we treat spec decoding as decode.
|
| 238 |
+
if self.torchair_graph_enabled:
|
| 239 |
+
if num_tokens - num_spec_tokens == 1:
|
| 240 |
+
decodes.append(i)
|
| 241 |
+
num_decode_tokens += num_tokens
|
| 242 |
+
else:
|
| 243 |
+
prefills.append(i)
|
| 244 |
+
num_prefill_tokens += num_tokens
|
| 245 |
+
# For eager mode we treat spec decoding as chunked prefill.
|
| 246 |
+
else:
|
| 247 |
+
if num_tokens == 1:
|
| 248 |
+
decodes.append(i)
|
| 249 |
+
num_decode_tokens += num_tokens
|
| 250 |
+
else:
|
| 251 |
+
prefills.append(i)
|
| 252 |
+
num_prefill_tokens += num_tokens
|
| 253 |
+
|
| 254 |
+
# We hope that this is fairly minimal since decodes
|
| 255 |
+
# should be around for a number of iterations so hopefully they are
|
| 256 |
+
# relatively stationary (and new request are generally appended to the
|
| 257 |
+
# persistent batch so already should be at the back)
|
| 258 |
+
# To achieve this we loop over the decodes in descending order and
|
| 259 |
+
# the prefills in ascending order. We swap decodes from the "back"
|
| 260 |
+
# i.e. past where the last decode should be in the reodorered with
|
| 261 |
+
# prefills from the front of the batch.
|
| 262 |
+
# `decodes` and `prefills` are already in ascending order just based on
|
| 263 |
+
# the above loop
|
| 264 |
+
num_decodes = len(decodes)
|
| 265 |
+
num_prefills = len(prefills)
|
| 266 |
+
first_prefill = 0
|
| 267 |
+
modified_batch = False
|
| 268 |
+
|
| 269 |
+
for i in range(1, min(num_decodes, num_prefills) + 1):
|
| 270 |
+
# If the decode is at the "back" of the batch, i, we can swap it
|
| 271 |
+
# with the prefill closest to the front of the batch
|
| 272 |
+
if decodes[num_decodes - i] >= num_decodes:
|
| 273 |
+
input_batch.swap_states(prefills[first_prefill],
|
| 274 |
+
decodes[num_decodes - i])
|
| 275 |
+
first_prefill += 1
|
| 276 |
+
modified_batch = True
|
| 277 |
+
else:
|
| 278 |
+
break
|
| 279 |
+
|
| 280 |
+
# Save for next `build` call
|
| 281 |
+
# TODO(lucas): this is a bit of a hack, we should probably have a
|
| 282 |
+
# better way of doing this
|
| 283 |
+
self._num_decodes = num_decodes
|
| 284 |
+
self._num_prefills = num_prefills
|
| 285 |
+
self._num_decode_tokens = num_decode_tokens
|
| 286 |
+
self._num_prefill_tokens = num_prefill_tokens
|
| 287 |
+
|
| 288 |
+
return modified_batch
|
| 289 |
+
|
| 290 |
+
def _get_graph_runner_block_tables(
|
| 291 |
+
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
|
| 292 |
+
|
| 293 |
+
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
| 294 |
+
assert max_batch_size >= num_seqs
|
| 295 |
+
|
| 296 |
+
if isinstance(self.runner.graph_block_tables, np.ndarray):
|
| 297 |
+
graph_block_tables = torch.zeros((max_batch_size, max_blocks),
|
| 298 |
+
dtype=block_tables.dtype,
|
| 299 |
+
device=block_tables.device)
|
| 300 |
+
else:
|
| 301 |
+
graph_block_tables = self.runner.graph_block_tables.to(
|
| 302 |
+
device=block_tables.device, dtype=block_tables.dtype)
|
| 303 |
+
|
| 304 |
+
num_blocks = block_tables.size(1)
|
| 305 |
+
if num_blocks <= max_blocks:
|
| 306 |
+
graph_block_tables[:num_seqs, :
|
| 307 |
+
num_blocks] = block_tables[:num_seqs, :
|
| 308 |
+
num_blocks]
|
| 309 |
+
else:
|
| 310 |
+
graph_block_tables[:num_seqs, :
|
| 311 |
+
max_blocks] = block_tables[:num_seqs, :
|
| 312 |
+
max_blocks]
|
| 313 |
+
|
| 314 |
+
return graph_block_tables[:num_seqs, :max_blocks]
|
| 315 |
+
|
| 316 |
+
def build_dummy(self, num_reqs: int,
|
| 317 |
+
num_actual_tokens: int) -> AscendMLAMetadata:
|
| 318 |
+
device = self.runner.device
|
| 319 |
+
_, max_blocks = self.runner.graph_block_tables.shape
|
| 320 |
+
block_table = torch.zeros((num_reqs, max_blocks),
|
| 321 |
+
dtype=torch.int32,
|
| 322 |
+
device=device)
|
| 323 |
+
block_table = self._get_graph_runner_block_tables(
|
| 324 |
+
num_reqs, block_table)
|
| 325 |
+
seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
|
| 326 |
+
input_positions = torch.zeros(num_reqs,
|
| 327 |
+
dtype=torch.int32,
|
| 328 |
+
device=device).long()
|
| 329 |
+
slot_mapping = torch.full((num_reqs, ),
|
| 330 |
+
PAD_SLOT_ID,
|
| 331 |
+
dtype=torch.int32,
|
| 332 |
+
device=device)
|
| 333 |
+
query_start_loc = torch.full((num_reqs, ),
|
| 334 |
+
-1,
|
| 335 |
+
dtype=torch.int32,
|
| 336 |
+
device=device)
|
| 337 |
+
decode_metadata = AscendMLADecodeMetadata(
|
| 338 |
+
input_positions=input_positions,
|
| 339 |
+
block_table=block_table,
|
| 340 |
+
seq_lens=seq_lens,
|
| 341 |
+
seq_lens_list=seq_lens.tolist(),
|
| 342 |
+
max_seq_lens=1,
|
| 343 |
+
attn_mask=self.runner.spec_attn_mask)
|
| 344 |
+
return self.metadata_cls( # type: ignore
|
| 345 |
+
num_input_tokens=num_actual_tokens,
|
| 346 |
+
num_actual_tokens=num_actual_tokens,
|
| 347 |
+
slot_mapping=slot_mapping,
|
| 348 |
+
head_dim=self.runner.model_config.get_head_size(),
|
| 349 |
+
num_decodes=1,
|
| 350 |
+
num_decode_tokens=1,
|
| 351 |
+
num_prefills=0,
|
| 352 |
+
attn_mask=self.runner.attn_mask,
|
| 353 |
+
attn_state=AscendAttentionState.DecodeOnly,
|
| 354 |
+
prefill=None,
|
| 355 |
+
decode=decode_metadata,
|
| 356 |
+
query_start_loc=query_start_loc,
|
| 357 |
+
seq_lens=seq_lens,
|
| 358 |
+
block_tables=block_table,
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
def build(
|
| 362 |
+
self,
|
| 363 |
+
num_reqs: int,
|
| 364 |
+
num_actual_tokens: int,
|
| 365 |
+
max_query_len: int,
|
| 366 |
+
common_attn_metadata: CommonAttentionMetadata,
|
| 367 |
+
common_prefix_len: Optional[int] = None,
|
| 368 |
+
graph_pad_size: int = -1,
|
| 369 |
+
max_num_tokens_across_dp: int = 0,
|
| 370 |
+
with_prefill_across_dp: bool = False,
|
| 371 |
+
) -> AscendMLAMetadata:
|
| 372 |
+
assert self._num_decodes + self._num_prefills == num_reqs
|
| 373 |
+
|
| 374 |
+
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
| 375 |
+
# function. We should avoid GPU -> CPU sync as much as possible because
|
| 376 |
+
# it blocks on all previous kernels.
|
| 377 |
+
device = self.runner.device
|
| 378 |
+
|
| 379 |
+
block_table = (self.runner.input_batch.block_table[0].
|
| 380 |
+
get_device_tensor()[:num_reqs])
|
| 381 |
+
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
| 382 |
+
device, non_blocking=True)
|
| 383 |
+
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
|
| 384 |
+
device, non_blocking=True).long()
|
| 385 |
+
|
| 386 |
+
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
|
| 387 |
+
query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
|
| 388 |
+
num_reqs]
|
| 389 |
+
seq_lens = seq_lens_cpu
|
| 390 |
+
max_query_len = query_lens.max().item()
|
| 391 |
+
max_seq_lens = seq_lens.max().item()
|
| 392 |
+
query_start_loc = common_attn_metadata.query_start_loc
|
| 393 |
+
|
| 394 |
+
prefill_metadata = None
|
| 395 |
+
chunked_context_metadata = None
|
| 396 |
+
if self._num_prefills > 0:
|
| 397 |
+
reqs_start = self._num_decodes # prefill_start
|
| 398 |
+
tokens_start = self._num_decode_tokens
|
| 399 |
+
max_query_len = query_lens[tokens_start:].max().item()
|
| 400 |
+
max_seq_lens = seq_lens[tokens_start:].max().item()
|
| 401 |
+
query_start_loc = common_attn_metadata.query_start_loc
|
| 402 |
+
prefill_query_start_loc = query_start_loc[
|
| 403 |
+
reqs_start:] - query_start_loc[reqs_start]
|
| 404 |
+
|
| 405 |
+
context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
|
| 406 |
+
reqs_start:num_reqs]
|
| 407 |
+
max_context_len_cpu = context_lens_cpu.max().item()
|
| 408 |
+
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
| 409 |
+
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
| 410 |
+
max_context_chunk = (self.chunked_prefill_workspace_size //
|
| 411 |
+
num_prefills_with_context_cpu)
|
| 412 |
+
max_context_chunk = round_down(max_context_chunk,
|
| 413 |
+
self.block_size)
|
| 414 |
+
|
| 415 |
+
assert max_context_chunk > 0
|
| 416 |
+
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
| 417 |
+
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
|
| 418 |
+
.unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
|
| 419 |
+
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
| 420 |
+
chunk_starts + max_context_chunk)
|
| 421 |
+
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
| 422 |
+
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
| 423 |
+
self._num_prefills + 1,
|
| 424 |
+
dtype=torch.int32,
|
| 425 |
+
pin_memory=True)
|
| 426 |
+
torch.cumsum(chunk_seq_lens,
|
| 427 |
+
dim=1,
|
| 428 |
+
out=cu_seq_lens_cpu[:, 1:],
|
| 429 |
+
dtype=torch.int32)
|
| 430 |
+
chunked_context_metadata = \
|
| 431 |
+
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
| 432 |
+
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
| 433 |
+
starts=chunk_starts.to(device, non_blocking=True),
|
| 434 |
+
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
| 435 |
+
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
| 436 |
+
chunk_seq_lens=chunk_seq_lens,
|
| 437 |
+
workspace=self.chunked_prefill_workspace,
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
prefill_metadata = AscendMLAPrefillMetadata(
|
| 441 |
+
attn_mask=self.runner.attn_mask,
|
| 442 |
+
query_lens=query_lens[tokens_start:],
|
| 443 |
+
seq_lens=seq_lens,
|
| 444 |
+
context_lens=seq_lens[tokens_start:],
|
| 445 |
+
input_positions=input_positions[tokens_start:],
|
| 446 |
+
block_table=block_table[reqs_start:, ...],
|
| 447 |
+
max_query_len=max_query_len,
|
| 448 |
+
max_seq_lens=max_seq_lens,
|
| 449 |
+
query_start_loc=prefill_query_start_loc,
|
| 450 |
+
chunked_context=chunked_context_metadata,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
decode_metadata = None
|
| 454 |
+
use_torchair_graph = graph_pad_size != -1
|
| 455 |
+
if self._num_decodes > 0:
|
| 456 |
+
max_seq_lens = seq_lens[:self._num_decodes].max().item()
|
| 457 |
+
seq_lens = seq_lens[:self._num_decode_tokens]
|
| 458 |
+
input_positions = input_positions[:self._num_decode_tokens]
|
| 459 |
+
block_table = block_table[:self._num_decode_tokens, ...]
|
| 460 |
+
if use_torchair_graph and self.runner.attn_state in [
|
| 461 |
+
AscendAttentionState.DecodeOnly,
|
| 462 |
+
AscendAttentionState.SpecDecoding
|
| 463 |
+
]:
|
| 464 |
+
num_seqs = len(seq_lens)
|
| 465 |
+
if graph_pad_size != 0:
|
| 466 |
+
pad_value = 1
|
| 467 |
+
padded_seq_lens = seq_lens.tolist() + [pad_value
|
| 468 |
+
] * graph_pad_size
|
| 469 |
+
else:
|
| 470 |
+
padded_seq_lens = seq_lens.tolist()
|
| 471 |
+
|
| 472 |
+
seq_lens = torch.from_numpy(
|
| 473 |
+
np.array(padded_seq_lens).astype(np.int32))
|
| 474 |
+
padding = torch.full((graph_pad_size, ),
|
| 475 |
+
PAD_SLOT_ID,
|
| 476 |
+
dtype=slot_mapping.dtype,
|
| 477 |
+
device=slot_mapping.device)
|
| 478 |
+
slot_mapping = torch.cat([slot_mapping, padding])
|
| 479 |
+
block_table_padding = torch.zeros(
|
| 480 |
+
(graph_pad_size, ) + block_table.shape[1:],
|
| 481 |
+
dtype=block_table.dtype,
|
| 482 |
+
device=block_table.device)
|
| 483 |
+
block_table = torch.cat([block_table, block_table_padding],
|
| 484 |
+
dim=0)
|
| 485 |
+
block_table = self._get_graph_runner_block_tables(
|
| 486 |
+
num_seqs + graph_pad_size, block_table)
|
| 487 |
+
padding_0 = torch.zeros(graph_pad_size,
|
| 488 |
+
dtype=input_positions.dtype,
|
| 489 |
+
device=input_positions.device)
|
| 490 |
+
input_positions = torch.cat([input_positions, padding_0])
|
| 491 |
+
|
| 492 |
+
decode_metadata = AscendMLADecodeMetadata(
|
| 493 |
+
input_positions=input_positions,
|
| 494 |
+
block_table=block_table,
|
| 495 |
+
seq_lens=seq_lens,
|
| 496 |
+
seq_lens_list=seq_lens.tolist(),
|
| 497 |
+
max_seq_lens=max_seq_lens,
|
| 498 |
+
attn_mask=self.runner.spec_attn_mask)
|
| 499 |
+
|
| 500 |
+
return self.metadata_cls( # type: ignore
|
| 501 |
+
num_actual_tokens=num_actual_tokens,
|
| 502 |
+
query_lens=query_lens.tolist(),
|
| 503 |
+
slot_mapping=slot_mapping,
|
| 504 |
+
head_dim=self.runner.model_config.get_head_size(),
|
| 505 |
+
num_decodes=self._num_decodes,
|
| 506 |
+
num_decode_tokens=self._num_decode_tokens,
|
| 507 |
+
num_prefills=self._num_prefills,
|
| 508 |
+
attn_mask=self.runner.attn_mask,
|
| 509 |
+
attn_state=self.runner.attn_state,
|
| 510 |
+
prefill=prefill_metadata,
|
| 511 |
+
decode=decode_metadata,
|
| 512 |
+
query_start_loc=query_start_loc,
|
| 513 |
+
block_tables=block_table,
|
| 514 |
+
seq_lens=seq_lens,
|
| 515 |
+
max_num_tokens_across_dp=max_num_tokens_across_dp,
|
| 516 |
+
with_prefill_across_dp=with_prefill_across_dp,
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class AscendMLAImpl(MLAAttentionImpl):
|
| 521 |
+
"""
|
| 522 |
+
NOTE: Please read the comment at the top of the file before trying to
|
| 523 |
+
understand this class
|
| 524 |
+
"""
|
| 525 |
+
|
| 526 |
+
def __init__(
|
| 527 |
+
self,
|
| 528 |
+
num_heads: int,
|
| 529 |
+
head_size: int,
|
| 530 |
+
scale: float,
|
| 531 |
+
num_kv_heads: int,
|
| 532 |
+
alibi_slopes: Optional[list[float]],
|
| 533 |
+
sliding_window: Optional[int],
|
| 534 |
+
kv_cache_dtype: str,
|
| 535 |
+
blocksparse_params: Optional[dict[str, Any]],
|
| 536 |
+
logits_soft_cap: Optional[float],
|
| 537 |
+
attn_type: str,
|
| 538 |
+
kv_sharing_target_layer_name: Optional[str] = None,
|
| 539 |
+
**kwargs,
|
| 540 |
+
) -> None:
|
| 541 |
+
self.num_heads = num_heads
|
| 542 |
+
self.head_size = head_size
|
| 543 |
+
self.scale = float(scale)
|
| 544 |
+
self.num_kv_heads = num_kv_heads
|
| 545 |
+
self.kv_cache_dtype = kv_cache_dtype
|
| 546 |
+
|
| 547 |
+
# MLA Args
|
| 548 |
+
self.q_lora_rank = kwargs['q_lora_rank']
|
| 549 |
+
self.kv_lora_rank = kwargs['kv_lora_rank']
|
| 550 |
+
self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
|
| 551 |
+
self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
|
| 552 |
+
self.qk_head_dim = kwargs['qk_head_dim']
|
| 553 |
+
self.v_head_dim = kwargs['v_head_dim']
|
| 554 |
+
self.rotary_emb = kwargs['rotary_emb']
|
| 555 |
+
self.q_proj = kwargs['q_proj']
|
| 556 |
+
self.kv_b_proj = kwargs['kv_b_proj']
|
| 557 |
+
self.o_proj = kwargs['o_proj']
|
| 558 |
+
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
| 559 |
+
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
| 560 |
+
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
| 561 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 562 |
+
|
| 563 |
+
ascend_config = get_ascend_config()
|
| 564 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 565 |
+
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
| 566 |
+
|
| 567 |
+
# Adapt torch air graph mode with spec decoding.
|
| 568 |
+
speculative_config = get_current_vllm_config().speculative_config
|
| 569 |
+
if speculative_config is not None:
|
| 570 |
+
self.spec_token_num = speculative_config.num_speculative_tokens
|
| 571 |
+
assert self.spec_token_num > 0
|
| 572 |
+
self.SHARE_MASK_TRIL_SPARSE = ~torch.tril(torch.ones((2048, 2048), dtype=torch.bool)).npu()
|
| 573 |
+
|
| 574 |
+
def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
|
| 575 |
+
# Convert from (B, N, L) to (N, B, L)
|
| 576 |
+
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
| 577 |
+
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
| 578 |
+
x = torch.bmm(x, self.W_UV)
|
| 579 |
+
# Convert from (N, B, V) to (B, N * V)
|
| 580 |
+
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
| 581 |
+
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
|
| 582 |
+
npu_prefetch(self.o_proj.weight,
|
| 583 |
+
x,
|
| 584 |
+
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
| 585 |
+
enabled=enable_multistream_mla)
|
| 586 |
+
return self.o_proj(x, is_prefill=False)[0]
|
| 587 |
+
|
| 588 |
+
# Return `ql_nope`, `q_pe`
|
| 589 |
+
def _q_proj_and_k_up_proj(self, x):
|
| 590 |
+
q_nope, q_pe = self.q_proj(x)[0]\
|
| 591 |
+
.view(-1, self.num_heads, self.qk_head_dim)\
|
| 592 |
+
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
| 593 |
+
|
| 594 |
+
# Convert from (B, N, P) to (N, B, P)
|
| 595 |
+
q_nope = q_nope.transpose(0, 1)
|
| 596 |
+
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
| 597 |
+
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
| 598 |
+
# Convert from (N, B, L) to (B, N, L)
|
| 599 |
+
return ql_nope.transpose(0, 1), q_pe
|
| 600 |
+
|
| 601 |
+
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
| 602 |
+
|
| 603 |
+
def get_layer_weight(layer):
|
| 604 |
+
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
| 605 |
+
for attr in WEIGHT_NAMES:
|
| 606 |
+
if hasattr(layer, attr):
|
| 607 |
+
return getattr(layer, attr)
|
| 608 |
+
raise AttributeError(
|
| 609 |
+
f"Layer '{layer}' has no recognized weight attribute:"
|
| 610 |
+
f" {WEIGHT_NAMES}.")
|
| 611 |
+
|
| 612 |
+
def get_and_maybe_dequant_weights(layer: LinearBase):
|
| 613 |
+
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
| 614 |
+
# NOTE: This should only be used offline, since it's O(N^3)
|
| 615 |
+
eye = torch.eye(layer.input_size_per_partition,
|
| 616 |
+
dtype=act_dtype,
|
| 617 |
+
device=get_layer_weight(layer).device)
|
| 618 |
+
dequant_weights = layer.quant_method.apply(layer,
|
| 619 |
+
eye,
|
| 620 |
+
bias=None)
|
| 621 |
+
del eye
|
| 622 |
+
# standardize to (output, input)
|
| 623 |
+
return dequant_weights.T
|
| 624 |
+
return layer.weight
|
| 625 |
+
|
| 626 |
+
# we currently do not have quantized bmm's which are needed for
|
| 627 |
+
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
|
| 628 |
+
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
| 629 |
+
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
| 630 |
+
assert kv_b_proj_weight.shape == (
|
| 631 |
+
self.kv_lora_rank,
|
| 632 |
+
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
| 633 |
+
f"{kv_b_proj_weight.shape=}, "
|
| 634 |
+
f"{self.kv_lora_rank=}, "
|
| 635 |
+
f"{self.num_heads=}, "
|
| 636 |
+
f"{self.qk_nope_head_dim=}, "
|
| 637 |
+
f"{self.v_head_dim=}")
|
| 638 |
+
kv_b_proj_weight = kv_b_proj_weight.view(
|
| 639 |
+
self.kv_lora_rank,
|
| 640 |
+
self.num_heads,
|
| 641 |
+
self.qk_nope_head_dim + self.v_head_dim,
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
W_UK, W_UV = kv_b_proj_weight.split(
|
| 645 |
+
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
| 646 |
+
|
| 647 |
+
# Convert from (L, N, V) to (N, L, V)
|
| 648 |
+
self.W_UV = W_UV.transpose(0, 1).contiguous()
|
| 649 |
+
# Convert from (L, N, P) to (N, P, L)
|
| 650 |
+
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
|
| 651 |
+
|
| 652 |
+
# Waiting for BMM NZ support
|
| 653 |
+
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
| 654 |
+
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
| 655 |
+
|
| 656 |
+
def _compute_prefill_context(
|
| 657 |
+
self,
|
| 658 |
+
query: torch.Tensor,
|
| 659 |
+
kv_c_and_k_pe_cache: torch.Tensor,
|
| 660 |
+
rope_dim: int,
|
| 661 |
+
attn_metadata: AscendMLAMetadata,
|
| 662 |
+
prefix_output: torch.Tensor,
|
| 663 |
+
prefix_lse: torch.Tensor,
|
| 664 |
+
):
|
| 665 |
+
prefill_metadata = attn_metadata.prefill
|
| 666 |
+
if prefill_metadata is None or prefill_metadata.chunked_context is None:
|
| 667 |
+
return prefix_output, prefix_lse
|
| 668 |
+
|
| 669 |
+
iters = len(prefill_metadata.chunked_context.seq_tot)
|
| 670 |
+
q_pe = query[..., self.qk_nope_head_dim:]
|
| 671 |
+
q_nope = query[..., :self.qk_nope_head_dim]
|
| 672 |
+
|
| 673 |
+
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
|
| 674 |
+
latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
|
| 675 |
+
cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
|
| 676 |
+
cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
|
| 677 |
+
for i in range(iters):
|
| 678 |
+
toks = prefill_metadata.chunked_context.seq_tot[i]
|
| 679 |
+
|
| 680 |
+
seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
|
| 681 |
+
seq_len = torch.stack([seq_len1, seq_len2])
|
| 682 |
+
kv_c_normed = torch.empty(toks,
|
| 683 |
+
kv_c_and_k_pe_cache.size(2),
|
| 684 |
+
latent_kv_dim,
|
| 685 |
+
dtype=query.dtype,
|
| 686 |
+
device=query.device)
|
| 687 |
+
k_pe = torch.empty(toks,
|
| 688 |
+
kv_c_and_k_pe_cache.size(2),
|
| 689 |
+
rope_dim,
|
| 690 |
+
dtype=query.dtype,
|
| 691 |
+
device=query.device)
|
| 692 |
+
|
| 693 |
+
torch_npu.atb.npu_paged_cache_load(
|
| 694 |
+
cache_kv_c,
|
| 695 |
+
cache_k_pe,
|
| 696 |
+
prefill_metadata.block_table,
|
| 697 |
+
seq_len2.to(query.device),
|
| 698 |
+
seq_starts=prefill_metadata.chunked_context.starts[i],
|
| 699 |
+
key=kv_c_normed,
|
| 700 |
+
value=k_pe,
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
kv_c_normed = kv_c_normed.squeeze()
|
| 704 |
+
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
| 705 |
+
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
| 706 |
+
k_nope, v = kv_nope\
|
| 707 |
+
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
| 708 |
+
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
| 709 |
+
mask = torch.triu(
|
| 710 |
+
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
| 711 |
+
1)
|
| 712 |
+
torch_npu.atb.npu_ring_mla(
|
| 713 |
+
q_nope=q_nope,
|
| 714 |
+
q_rope=q_pe,
|
| 715 |
+
k_nope=k_nope,
|
| 716 |
+
k_rope=k_pe,
|
| 717 |
+
value=v,
|
| 718 |
+
mask=mask,
|
| 719 |
+
seqlen=seq_len,
|
| 720 |
+
head_num=self.num_heads,
|
| 721 |
+
kv_head_num=self.num_heads,
|
| 722 |
+
pre_out=prefix_output,
|
| 723 |
+
prev_lse=prefix_lse,
|
| 724 |
+
qk_scale=self.scale,
|
| 725 |
+
kernel_type="kernel_type_high_precision",
|
| 726 |
+
mask_type="no_mask",
|
| 727 |
+
input_layout="type_bsnd",
|
| 728 |
+
calc_type="calc_type_default",
|
| 729 |
+
output=prefix_output,
|
| 730 |
+
softmax_lse=prefix_lse)
|
| 731 |
+
return prefix_output, prefix_lse
|
| 732 |
+
|
| 733 |
+
def _forward_prefill(
|
| 734 |
+
self,
|
| 735 |
+
query: torch.Tensor,
|
| 736 |
+
kv_c_normed: torch.Tensor,
|
| 737 |
+
k_pe: torch.Tensor,
|
| 738 |
+
kv_c_and_k_pe_cache: torch.Tensor,
|
| 739 |
+
attn_metadata: AscendMLAMetadata,
|
| 740 |
+
) -> torch.Tensor:
|
| 741 |
+
assert attn_metadata.prefill is not None
|
| 742 |
+
|
| 743 |
+
num_tokens = query.size(0)
|
| 744 |
+
attn_output = torch.empty(num_tokens,
|
| 745 |
+
self.num_heads,
|
| 746 |
+
self.v_head_dim,
|
| 747 |
+
dtype=query.dtype,
|
| 748 |
+
device=query.device)
|
| 749 |
+
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
| 750 |
+
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
|
| 751 |
+
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
| 752 |
+
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
| 753 |
+
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
|
| 754 |
+
ascend_config = get_ascend_config()
|
| 755 |
+
|
| 756 |
+
if attn_metadata.attn_state in [
|
| 757 |
+
AscendAttentionState.ChunkedPrefill,
|
| 758 |
+
AscendAttentionState.SpecDecoding,
|
| 759 |
+
AscendAttentionState.PrefillCacheHit
|
| 760 |
+
] and not ascend_config.chunked_prefill_for_mla:
|
| 761 |
+
attn_output_torch = torch.empty(num_tokens,
|
| 762 |
+
self.num_heads * self.v_head_dim,
|
| 763 |
+
dtype=query.dtype,
|
| 764 |
+
device=query.device)
|
| 765 |
+
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
| 766 |
+
vanilla_chunked_prefill_mla(
|
| 767 |
+
output=attn_output_torch,
|
| 768 |
+
query=query,
|
| 769 |
+
kv_cache=kv_c_and_k_pe_cache,
|
| 770 |
+
block_tables=attn_metadata.prefill.block_table,
|
| 771 |
+
query_lens=attn_metadata.prefill.query_lens,
|
| 772 |
+
context_lens=attn_metadata.prefill.context_lens,
|
| 773 |
+
kv_b_proj=self.kv_b_proj,
|
| 774 |
+
max_query_len=attn_metadata.prefill.max_query_len,
|
| 775 |
+
max_context_len=attn_metadata.prefill.max_seq_lens,
|
| 776 |
+
nope_dim=self.qk_nope_head_dim,
|
| 777 |
+
rope_dim=self.qk_rope_head_dim,
|
| 778 |
+
v_head_dim=self.v_head_dim,
|
| 779 |
+
scale=self.scale,
|
| 780 |
+
alibi_slopes=None,
|
| 781 |
+
causal=True)
|
| 782 |
+
elif attn_metadata.attn_state in [
|
| 783 |
+
AscendAttentionState.ChunkedPrefill,
|
| 784 |
+
AscendAttentionState.SpecDecoding,
|
| 785 |
+
AscendAttentionState.PrefillCacheHit
|
| 786 |
+
]:
|
| 787 |
+
attn_lse = torch.empty(self.num_heads,
|
| 788 |
+
num_tokens,
|
| 789 |
+
dtype=torch.float32,
|
| 790 |
+
device=query.device)
|
| 791 |
+
q_pe = query[..., self.qk_nope_head_dim:]
|
| 792 |
+
q_nope = query[..., :self.qk_nope_head_dim]
|
| 793 |
+
mask = torch.triu(
|
| 794 |
+
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
| 795 |
+
1) # 512: mask only support 512
|
| 796 |
+
if attn_metadata.num_prefills > 1:
|
| 797 |
+
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
|
| 798 |
+
1)
|
| 799 |
+
torch_npu.atb.npu_ring_mla(
|
| 800 |
+
q_nope=q_nope,
|
| 801 |
+
q_rope=q_pe,
|
| 802 |
+
k_nope=k_nope,
|
| 803 |
+
k_rope=k_pe,
|
| 804 |
+
value=value,
|
| 805 |
+
mask=mask,
|
| 806 |
+
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
| 807 |
+
dtype=torch.int32),
|
| 808 |
+
head_num=self.num_heads,
|
| 809 |
+
kv_head_num=self.num_heads,
|
| 810 |
+
pre_out=None,
|
| 811 |
+
prev_lse=None,
|
| 812 |
+
qk_scale=self.scale,
|
| 813 |
+
kernel_type="kernel_type_high_precision",
|
| 814 |
+
mask_type="mask_type_triu",
|
| 815 |
+
input_layout="type_bsnd",
|
| 816 |
+
calc_type="calc_type_first_ring",
|
| 817 |
+
output=attn_output,
|
| 818 |
+
softmax_lse=attn_lse)
|
| 819 |
+
attn_output, attn_lse = self._compute_prefill_context( \
|
| 820 |
+
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
| 821 |
+
|
| 822 |
+
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
| 823 |
+
key = torch.cat((k_nope, k_pe), dim=-1)
|
| 824 |
+
context_lens_list = torch.cumsum(attn_metadata.prefill.context_lens, dim=0).tolist()
|
| 825 |
+
attn_output = torch_npu.npu_fused_infer_attention_score(
|
| 826 |
+
query,
|
| 827 |
+
key,
|
| 828 |
+
value,
|
| 829 |
+
num_heads=self.num_heads,
|
| 830 |
+
input_layout="TND",
|
| 831 |
+
scale=self.scale,
|
| 832 |
+
sparse_mode=3,
|
| 833 |
+
atten_mask=self.SHARE_MASK_TRIL_SPARSE,
|
| 834 |
+
actual_seq_lengths=context_lens_list,
|
| 835 |
+
actual_seq_lengths_kv=context_lens_list,
|
| 836 |
+
inner_precise=0)[0]
|
| 837 |
+
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
|
| 838 |
+
else:
|
| 839 |
+
raise RuntimeError(
|
| 840 |
+
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
|
| 841 |
+
)
|
| 842 |
+
attn_output = attn_output.reshape(
|
| 843 |
+
[num_tokens, self.num_heads * self.v_head_dim])
|
| 844 |
+
if attn_metadata.attn_state in [
|
| 845 |
+
AscendAttentionState.ChunkedPrefill,
|
| 846 |
+
AscendAttentionState.SpecDecoding,
|
| 847 |
+
AscendAttentionState.PrefillCacheHit
|
| 848 |
+
] and not ascend_config.chunked_prefill_for_mla:
|
| 849 |
+
attn_output = attn_output_torch
|
| 850 |
+
|
| 851 |
+
current_ms_metadata = get_multistream_comm_context()
|
| 852 |
+
if current_ms_metadata is None:
|
| 853 |
+
return self.o_proj(attn_output, is_prefill=True)[0]
|
| 854 |
+
else:
|
| 855 |
+
current_ms_metadata.before_comm_event.record()
|
| 856 |
+
with torch.npu.stream(current_ms_metadata.comm_stream):
|
| 857 |
+
current_ms_metadata.before_comm_event.wait()
|
| 858 |
+
return self.o_proj(attn_output, is_prefill=True)[0]
|
| 859 |
+
|
| 860 |
+
def exec_kv(
|
| 861 |
+
self,
|
| 862 |
+
hidden_states: torch.Tensor,
|
| 863 |
+
cos: torch.Tensor,
|
| 864 |
+
sin: torch.Tensor,
|
| 865 |
+
kv_cache: Tuple,
|
| 866 |
+
slots: torch.Tensor,
|
| 867 |
+
):
|
| 868 |
+
|
| 869 |
+
B = hidden_states.shape[0]
|
| 870 |
+
N = self.num_kv_heads
|
| 871 |
+
S = 1
|
| 872 |
+
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
| 873 |
+
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
| 874 |
+
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
| 875 |
+
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
|
| 876 |
+
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
|
| 877 |
+
kv,
|
| 878 |
+
self.kv_a_layernorm.weight,
|
| 879 |
+
cos,
|
| 880 |
+
sin,
|
| 881 |
+
slots.to(torch.int64),
|
| 882 |
+
kv_cache[1],
|
| 883 |
+
kv_cache[0],
|
| 884 |
+
epsilon=self.kv_a_layernorm.variance_epsilon,
|
| 885 |
+
cache_mode=cache_mode,
|
| 886 |
+
)
|
| 887 |
+
return k_pe, k_nope, kv
|
| 888 |
+
|
| 889 |
+
def exec_kv_prefill(
|
| 890 |
+
self,
|
| 891 |
+
hidden_states: torch.Tensor,
|
| 892 |
+
cos: torch.Tensor,
|
| 893 |
+
sin: torch.Tensor,
|
| 894 |
+
kv_cache: Tuple,
|
| 895 |
+
slots: torch.Tensor,
|
| 896 |
+
):
|
| 897 |
+
|
| 898 |
+
B = hidden_states.shape[0]
|
| 899 |
+
N = self.num_kv_heads
|
| 900 |
+
S = 1
|
| 901 |
+
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
|
| 902 |
+
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
|
| 903 |
+
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
|
| 904 |
+
cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
|
| 905 |
+
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
| 906 |
+
kv,
|
| 907 |
+
self.kv_a_layernorm.weight,
|
| 908 |
+
cos,
|
| 909 |
+
sin,
|
| 910 |
+
slots.to(torch.int64),
|
| 911 |
+
kv_cache[1],
|
| 912 |
+
kv_cache[0],
|
| 913 |
+
epsilon=self.kv_a_layernorm.variance_epsilon,
|
| 914 |
+
cache_mode=cache_mode,
|
| 915 |
+
is_output_kv=True,
|
| 916 |
+
)
|
| 917 |
+
return k_pe, k_nope
|
| 918 |
+
|
| 919 |
+
def rope_single(
|
| 920 |
+
self,
|
| 921 |
+
x: torch.Tensor,
|
| 922 |
+
cos: torch.Tensor,
|
| 923 |
+
sin: torch.Tensor,
|
| 924 |
+
) -> torch.Tensor:
|
| 925 |
+
B, N, D = x.shape
|
| 926 |
+
S = 1
|
| 927 |
+
x = x.view(B, N, S, D)
|
| 928 |
+
x = torch_npu.npu_interleave_rope(x, cos, sin)
|
| 929 |
+
return x.view(B, N, D)
|
| 930 |
+
|
| 931 |
+
def _forward_decode(
|
| 932 |
+
self,
|
| 933 |
+
q_nope: torch.Tensor,
|
| 934 |
+
q_pe: torch.Tensor,
|
| 935 |
+
k_nope: torch.Tensor,
|
| 936 |
+
k_pe: torch.Tensor,
|
| 937 |
+
kv_c_and_k_pe_cache: torch.Tensor,
|
| 938 |
+
attn_metadata: AscendMLAMetadata,
|
| 939 |
+
enable_multistream_mla: bool = False,
|
| 940 |
+
) -> torch.Tensor:
|
| 941 |
+
decode_meta = attn_metadata.decode
|
| 942 |
+
assert decode_meta is not None
|
| 943 |
+
|
| 944 |
+
q = torch.cat([q_nope, q_pe], dim=-1)
|
| 945 |
+
num_tokens = q.size(0)
|
| 946 |
+
attn_output = torch.empty(
|
| 947 |
+
[num_tokens, self.num_heads, self.kv_lora_rank],
|
| 948 |
+
dtype=q.dtype,
|
| 949 |
+
device=q.device)
|
| 950 |
+
if self.running_in_graph:
|
| 951 |
+
# TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
|
| 952 |
+
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
| 953 |
+
assert num_tokens % self.spec_token_num == 0
|
| 954 |
+
q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
|
| 955 |
+
self.spec_token_num + 1, self.num_heads,
|
| 956 |
+
-1)
|
| 957 |
+
q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
|
| 958 |
+
self.spec_token_num + 1, self.num_heads, -1)
|
| 959 |
+
if not self.enable_kv_nz:
|
| 960 |
+
q_nope = q_nope.transpose(1, 2).contiguous()
|
| 961 |
+
q_pe = q_pe.transpose(1, 2).contiguous()
|
| 962 |
+
sparse_mode = 3
|
| 963 |
+
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
| 964 |
+
else:
|
| 965 |
+
if self.enable_kv_nz:
|
| 966 |
+
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
|
| 967 |
+
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
|
| 968 |
+
else:
|
| 969 |
+
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
| 970 |
+
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
| 971 |
+
sparse_mode = 0
|
| 972 |
+
spec_attn_mask = None
|
| 973 |
+
# shape of knope/k_pe for npu graph mode should be:
|
| 974 |
+
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
| 975 |
+
block_size = kv_c_and_k_pe_cache[0].shape[1]
|
| 976 |
+
if self.enable_kv_nz:
|
| 977 |
+
k_nope = k_nope.view(-1, self.num_kv_heads,
|
| 978 |
+
self.kv_lora_rank // 16, block_size, 16)
|
| 979 |
+
k_pe = k_pe.view(-1, self.num_kv_heads,
|
| 980 |
+
self.qk_rope_head_dim // 16, block_size, 16)
|
| 981 |
+
input_layout = "BSND"
|
| 982 |
+
else:
|
| 983 |
+
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
|
| 984 |
+
self.kv_lora_rank)
|
| 985 |
+
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
|
| 986 |
+
self.qk_rope_head_dim)
|
| 987 |
+
input_layout = "BNSD"
|
| 988 |
+
|
| 989 |
+
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
| 990 |
+
q_nope,
|
| 991 |
+
k_nope,
|
| 992 |
+
k_nope,
|
| 993 |
+
query_rope=q_pe,
|
| 994 |
+
key_rope=k_pe,
|
| 995 |
+
num_heads=self.num_heads,
|
| 996 |
+
num_key_value_heads=self.num_kv_heads,
|
| 997 |
+
input_layout=input_layout,
|
| 998 |
+
atten_mask=spec_attn_mask,
|
| 999 |
+
sparse_mode=sparse_mode,
|
| 1000 |
+
scale=self.scale,
|
| 1001 |
+
antiquant_mode=0,
|
| 1002 |
+
antiquant_scale=None,
|
| 1003 |
+
block_table=decode_meta.block_table,
|
| 1004 |
+
block_size=block_size,
|
| 1005 |
+
actual_seq_lengths_kv=decode_meta.seq_lens_list,
|
| 1006 |
+
)
|
| 1007 |
+
else:
|
| 1008 |
+
torch_npu._npu_paged_attention_mla(
|
| 1009 |
+
query=q,
|
| 1010 |
+
key_cache=kv_c_and_k_pe_cache,
|
| 1011 |
+
num_kv_heads=self.num_kv_heads,
|
| 1012 |
+
num_heads=self.num_heads,
|
| 1013 |
+
scale_value=self.scale,
|
| 1014 |
+
block_table=attn_metadata.decode.block_table, # type:ignore
|
| 1015 |
+
context_lens=attn_metadata.decode.seq_lens, # type:ignore
|
| 1016 |
+
mla_vheadsize=self.kv_lora_rank,
|
| 1017 |
+
out=attn_output)
|
| 1018 |
+
current_ms_metadata = get_multistream_comm_context()
|
| 1019 |
+
if current_ms_metadata is None:
|
| 1020 |
+
return self._v_up_proj_and_o_proj(attn_output,
|
| 1021 |
+
enable_multistream_mla)
|
| 1022 |
+
else:
|
| 1023 |
+
current_ms_metadata.before_comm_event.record()
|
| 1024 |
+
with torch.npu.stream(current_ms_metadata.comm_stream):
|
| 1025 |
+
current_ms_metadata.before_comm_event.wait()
|
| 1026 |
+
return self._v_up_proj_and_o_proj(attn_output)
|
| 1027 |
+
|
| 1028 |
+
def forward(
|
| 1029 |
+
self,
|
| 1030 |
+
layer: AttentionLayer,
|
| 1031 |
+
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
| 1032 |
+
hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn
|
| 1033 |
+
k_pe: torch.Tensor, # value in unified attn
|
| 1034 |
+
kv_cache: torch.Tensor,
|
| 1035 |
+
attn_metadata: M,
|
| 1036 |
+
output: Optional[torch.Tensor] = None,
|
| 1037 |
+
enable_multistream_mla: bool = False,
|
| 1038 |
+
ckq: Optional[torch.Tensor] = None,
|
| 1039 |
+
) -> torch.Tensor:
|
| 1040 |
+
assert output is not None, "Output tensor must be provided."
|
| 1041 |
+
if attn_metadata is None:
|
| 1042 |
+
# Profiling run.
|
| 1043 |
+
return output
|
| 1044 |
+
self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
|
| 1045 |
+
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
| 1046 |
+
]
|
| 1047 |
+
num_actual_toks = attn_metadata.num_actual_tokens
|
| 1048 |
+
if k_pe is None and not self.running_in_graph:
|
| 1049 |
+
if not self.torchair_graph_enabled:
|
| 1050 |
+
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
| 1051 |
+
hidden_states_or_kv_c_normed)[0].split(
|
| 1052 |
+
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
| 1053 |
+
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
| 1054 |
+
else:
|
| 1055 |
+
kv_c_normed = hidden_states_or_kv_c_normed
|
| 1056 |
+
assert attn_metadata.num_decodes is not None and \
|
| 1057 |
+
attn_metadata.num_prefills is not None and \
|
| 1058 |
+
attn_metadata.num_decode_tokens is not None
|
| 1059 |
+
has_decode = attn_metadata.num_decodes > 0
|
| 1060 |
+
has_prefill = attn_metadata.num_prefills > 0
|
| 1061 |
+
num_decode_tokens = attn_metadata.num_decode_tokens
|
| 1062 |
+
if not self.running_in_graph:
|
| 1063 |
+
# Inputs and outputs may be padded for CUDA graphs
|
| 1064 |
+
output_padded = output
|
| 1065 |
+
output = output[:num_actual_toks, ...]
|
| 1066 |
+
if not self.torchair_graph_enabled:
|
| 1067 |
+
kv_c_normed = kv_c_normed[:num_actual_toks, ...]
|
| 1068 |
+
prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
|
| 1069 |
+
if not self.running_in_graph:
|
| 1070 |
+
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
| 1071 |
+
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
| 1072 |
+
if not self.torchair_graph_enabled:
|
| 1073 |
+
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
| 1074 |
+
k_pe = k_pe[:num_actual_toks, ...]
|
| 1075 |
+
k_pe = k_pe.unsqueeze(1)
|
| 1076 |
+
decode_k_pe = k_pe[:num_decode_tokens]
|
| 1077 |
+
prefill_k_pe = k_pe[num_decode_tokens:]
|
| 1078 |
+
else:
|
| 1079 |
+
decode_hs_or_q_c = hidden_states_or_q_c
|
| 1080 |
+
if has_decode:
|
| 1081 |
+
decode_k_nope = None
|
| 1082 |
+
assert attn_metadata.decode is not None
|
| 1083 |
+
if self.running_in_graph:
|
| 1084 |
+
seq_len = self.rotary_emb.max_position_embeddings * \
|
| 1085 |
+
getattr(self.rotary_emb, "scaling_factor", 1)
|
| 1086 |
+
cos = self.rotary_emb.cos_cached[:seq_len].to(
|
| 1087 |
+
dtype=decode_hs_or_q_c.dtype)
|
| 1088 |
+
sin = self.rotary_emb.sin_cached[:seq_len].to(
|
| 1089 |
+
dtype=decode_hs_or_q_c.dtype)
|
| 1090 |
+
cos = cos[attn_metadata.decode.input_positions]
|
| 1091 |
+
sin = sin[attn_metadata.decode.input_positions]
|
| 1092 |
+
cos = cos[:, None, None, :]
|
| 1093 |
+
sin = sin[:, None, None, :]
|
| 1094 |
+
with npu_stream_switch("mla_secondary",
|
| 1095 |
+
0,
|
| 1096 |
+
enabled=enable_multistream_mla):
|
| 1097 |
+
npu_wait_tensor(hidden_states_or_kv_c_normed,
|
| 1098 |
+
ckq,
|
| 1099 |
+
enabled=enable_multistream_mla)
|
| 1100 |
+
decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
|
| 1101 |
+
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
| 1102 |
+
attn_metadata.slot_mapping)
|
| 1103 |
+
# Without explicitly controlling the order, IndexByTensor operations
|
| 1104 |
+
# would be placed after `matmul W_KV_T` hindering the overlapping of
|
| 1105 |
+
# KvRmsNormRopeCache and SingleRope.
|
| 1106 |
+
npu_wait_tensor(decode_hs_or_q_c,
|
| 1107 |
+
cos,
|
| 1108 |
+
enabled=enable_multistream_mla)
|
| 1109 |
+
npu_wait_tensor(decode_hs_or_q_c,
|
| 1110 |
+
sin,
|
| 1111 |
+
enabled=enable_multistream_mla)
|
| 1112 |
+
npu_wait_tensor(decode_hs_or_q_c,
|
| 1113 |
+
decode_kv,
|
| 1114 |
+
enabled=enable_multistream_mla)
|
| 1115 |
+
|
| 1116 |
+
decode_ql_nope, decode_q_pe = \
|
| 1117 |
+
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
| 1118 |
+
if self.running_in_graph:
|
| 1119 |
+
with npu_stream_switch("mla_secondary",
|
| 1120 |
+
0,
|
| 1121 |
+
enabled=enable_multistream_mla):
|
| 1122 |
+
npu_wait_tensor(decode_q_pe,
|
| 1123 |
+
decode_k_pe,
|
| 1124 |
+
enabled=enable_multistream_mla)
|
| 1125 |
+
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
|
| 1126 |
+
else:
|
| 1127 |
+
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
| 1128 |
+
attn_metadata.decode.input_positions,
|
| 1129 |
+
decode_q_pe.contiguous(),
|
| 1130 |
+
decode_k_pe,
|
| 1131 |
+
max_seq_len=attn_metadata.decode.max_seq_lens)
|
| 1132 |
+
if has_prefill:
|
| 1133 |
+
assert attn_metadata.prefill is not None
|
| 1134 |
+
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
| 1135 |
+
.view(-1, self.num_heads, self.qk_head_dim)
|
| 1136 |
+
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
| 1137 |
+
prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
|
| 1138 |
+
if self.torchair_graph_enabled:
|
| 1139 |
+
num_tokens = prefill_hs_or_q_c.shape[0]
|
| 1140 |
+
seq_len = self.rotary_emb.max_position_embeddings * \
|
| 1141 |
+
getattr(self.rotary_emb, "scaling_factor", 1)
|
| 1142 |
+
cos = self.rotary_emb.cos_cached[:seq_len].to(
|
| 1143 |
+
dtype=prefill_q_pe.dtype)
|
| 1144 |
+
sin = self.rotary_emb.sin_cached[:seq_len].to(
|
| 1145 |
+
dtype=prefill_q_pe.dtype)
|
| 1146 |
+
cos = cos[attn_metadata.prefill.input_positions]
|
| 1147 |
+
sin = sin[attn_metadata.prefill.input_positions]
|
| 1148 |
+
cos = cos[:, None, None, :]
|
| 1149 |
+
sin = sin[:, None, None, :]
|
| 1150 |
+
|
| 1151 |
+
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
| 1152 |
+
prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
|
| 1153 |
+
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
|
| 1154 |
+
attn_metadata.slot_mapping)
|
| 1155 |
+
|
| 1156 |
+
kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
|
| 1157 |
+
prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
|
| 1158 |
+
prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
|
| 1159 |
+
-1)
|
| 1160 |
+
prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
|
| 1161 |
+
else:
|
| 1162 |
+
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
| 1163 |
+
attn_metadata.prefill.input_positions,
|
| 1164 |
+
prefill_q_pe.contiguous(),
|
| 1165 |
+
prefill_k_pe,
|
| 1166 |
+
max_seq_len=attn_metadata.prefill.max_seq_lens)
|
| 1167 |
+
if self.torchair_graph_enabled:
|
| 1168 |
+
if len(kv_cache) > 0 and kv_cache[0].numel(
|
| 1169 |
+
) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
| 1170 |
+
slots = attn_metadata.slot_mapping
|
| 1171 |
+
# NOTE: Separate the kv cache in advance to avoid OOM or other issues
|
| 1172 |
+
torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
|
| 1173 |
+
num_tokens, self.num_kv_heads, -1),
|
| 1174 |
+
value=prefill_k_pe,
|
| 1175 |
+
key_cache=kv_cache[0],
|
| 1176 |
+
value_cache=kv_cache[1],
|
| 1177 |
+
slot_indices=slots)
|
| 1178 |
+
elif kv_cache.numel() > 0:
|
| 1179 |
+
key = torch.cat([
|
| 1180 |
+
kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
|
| 1181 |
+
k_pe
|
| 1182 |
+
],
|
| 1183 |
+
dim=2)
|
| 1184 |
+
torch_npu._npu_reshape_and_cache_siso(
|
| 1185 |
+
key=key,
|
| 1186 |
+
key_cache=kv_cache,
|
| 1187 |
+
slot_indices=attn_metadata.slot_mapping.flatten())
|
| 1188 |
+
if has_prefill:
|
| 1189 |
+
# FIX: aicore move should be also placed on the comm stream in dbo,
|
| 1190 |
+
# otherwise it may affect the accuracy
|
| 1191 |
+
# TODO: use an elegant way to overlap
|
| 1192 |
+
output_prefill = self._forward_prefill(prefill_q,
|
| 1193 |
+
prefill_k_c_normed,
|
| 1194 |
+
prefill_k_pe, kv_cache,
|
| 1195 |
+
attn_metadata)
|
| 1196 |
+
current_ms_metadata = get_multistream_comm_context()
|
| 1197 |
+
if current_ms_metadata is not None:
|
| 1198 |
+
with torch.npu.stream(current_ms_metadata.comm_stream):
|
| 1199 |
+
output[num_decode_tokens:] = output_prefill
|
| 1200 |
+
current_ms_metadata.after_comm_event.record()
|
| 1201 |
+
else:
|
| 1202 |
+
output[num_decode_tokens:] = output_prefill
|
| 1203 |
+
|
| 1204 |
+
if has_decode:
|
| 1205 |
+
if self.running_in_graph:
|
| 1206 |
+
return self._forward_decode(decode_ql_nope, decode_q_pe,
|
| 1207 |
+
decode_k_nope, decode_k_pe,
|
| 1208 |
+
kv_cache, attn_metadata,
|
| 1209 |
+
enable_multistream_mla)
|
| 1210 |
+
else:
|
| 1211 |
+
output_decode = self._forward_decode(decode_ql_nope,
|
| 1212 |
+
decode_q_pe,
|
| 1213 |
+
decode_k_nope,
|
| 1214 |
+
decode_k_pe, kv_cache,
|
| 1215 |
+
attn_metadata)
|
| 1216 |
+
current_ms_metadata = get_multistream_comm_context()
|
| 1217 |
+
if current_ms_metadata is not None:
|
| 1218 |
+
with torch.npu.stream(current_ms_metadata.comm_stream):
|
| 1219 |
+
output[:num_decode_tokens] = output_decode
|
| 1220 |
+
current_ms_metadata.after_comm_event.record()
|
| 1221 |
+
else:
|
| 1222 |
+
output[:num_decode_tokens] = output_decode
|
| 1223 |
+
|
| 1224 |
+
return output_padded
|
inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
from .pangu_reasoning_parser import PanguReasoningParser
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"PanguReasoningParser"
|
| 6 |
+
]
|
inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
| 3 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 4 |
+
|
| 5 |
+
from collections.abc import Sequence
|
| 6 |
+
from typing import Optional, Union
|
| 7 |
+
|
| 8 |
+
from transformers import PreTrainedTokenizerBase
|
| 9 |
+
|
| 10 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 11 |
+
DeltaMessage)
|
| 12 |
+
from vllm.logger import init_logger
|
| 13 |
+
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
| 14 |
+
|
| 15 |
+
logger = init_logger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@ReasoningParserManager.register_module("pangu")
|
| 19 |
+
class PanguReasoningParser(ReasoningParser):
|
| 20 |
+
"""
|
| 21 |
+
Reasoning parser for Pangu model.
|
| 22 |
+
|
| 23 |
+
The Pangu model uses [unused16]...[unused17] tokens to denote reasoning
|
| 24 |
+
text. This parser extracts the reasoning content from the model output.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
start_token_id: int
|
| 28 |
+
end_token_id: int
|
| 29 |
+
|
| 30 |
+
start_token: str = "[unused16]"
|
| 31 |
+
end_token: str = "[unused17]"
|
| 32 |
+
|
| 33 |
+
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
| 34 |
+
super().__init__(tokenizer)
|
| 35 |
+
|
| 36 |
+
if not self.model_tokenizer:
|
| 37 |
+
raise ValueError(
|
| 38 |
+
"The model tokenizer must be passed to the ReasoningParser "
|
| 39 |
+
"constructor during construction.")
|
| 40 |
+
|
| 41 |
+
self.start_token_id = self.vocab.get(self.start_token)
|
| 42 |
+
self.end_token_id = self.vocab.get(self.end_token)
|
| 43 |
+
if self.start_token_id is None or self.end_token_id is None:
|
| 44 |
+
raise RuntimeError(
|
| 45 |
+
"Pangu reasoning parser could not locate think start/end "
|
| 46 |
+
"tokens in the tokenizer!")
|
| 47 |
+
|
| 48 |
+
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
| 49 |
+
return self.end_token_id in input_ids
|
| 50 |
+
|
| 51 |
+
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
| 52 |
+
"""
|
| 53 |
+
Extract the content after the end tokens
|
| 54 |
+
"""
|
| 55 |
+
if self.end_token_id not in input_ids[:-1]:
|
| 56 |
+
return []
|
| 57 |
+
else:
|
| 58 |
+
return input_ids[input_ids.index(self.end_token_id) + 1:]
|
| 59 |
+
|
| 60 |
+
def extract_reasoning_content_streaming(
|
| 61 |
+
self,
|
| 62 |
+
previous_text: str,
|
| 63 |
+
current_text: str,
|
| 64 |
+
delta_text: str,
|
| 65 |
+
previous_token_ids: Sequence[int],
|
| 66 |
+
current_token_ids: Sequence[int],
|
| 67 |
+
delta_token_ids: Sequence[int],
|
| 68 |
+
) -> Union[DeltaMessage, None]:
|
| 69 |
+
"""
|
| 70 |
+
Extract reasoning content from a delta message.
|
| 71 |
+
Handles streaming output where previous + delta = current.
|
| 72 |
+
Uses token IDs for faster processing.
|
| 73 |
+
For text [unused16]abc[unused17]xyz:
|
| 74 |
+
- 'abc' goes to reasoning_content
|
| 75 |
+
- 'xyz' goes to content
|
| 76 |
+
"""
|
| 77 |
+
# Skip single special tokens
|
| 78 |
+
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
|
| 79 |
+
self.start_token_id, self.end_token_id
|
| 80 |
+
]):
|
| 81 |
+
return None
|
| 82 |
+
|
| 83 |
+
# Check if [unused16] is present in previous or delta.
|
| 84 |
+
# Keep compatibility with models that don't generate [unused16] tokens.
|
| 85 |
+
if self.start_token_id in previous_token_ids:
|
| 86 |
+
if self.end_token_id in delta_token_ids:
|
| 87 |
+
# [unused16] in previous, [unused17] in delta,
|
| 88 |
+
# extract reasoning content
|
| 89 |
+
end_index = delta_text.find(self.end_token)
|
| 90 |
+
reasoning_content = delta_text[:end_index]
|
| 91 |
+
content = delta_text[end_index + len(self.end_token):]
|
| 92 |
+
return DeltaMessage(
|
| 93 |
+
reasoning_content=reasoning_content,
|
| 94 |
+
content=content if content else None,
|
| 95 |
+
)
|
| 96 |
+
elif self.end_token_id in previous_token_ids:
|
| 97 |
+
# [unused16] in previous, [unused17] in previous,
|
| 98 |
+
# reasoning content continues
|
| 99 |
+
return DeltaMessage(content=delta_text)
|
| 100 |
+
else:
|
| 101 |
+
# [unused16] in previous, no [unused17] in previous or delta,
|
| 102 |
+
# reasoning content continues
|
| 103 |
+
return DeltaMessage(reasoning_content=delta_text)
|
| 104 |
+
elif self.start_token_id in delta_token_ids:
|
| 105 |
+
if self.end_token_id in delta_token_ids:
|
| 106 |
+
# [unused16] in delta, [unused17] in delta, extract reasoning content
|
| 107 |
+
start_index = delta_text.find(self.start_token)
|
| 108 |
+
end_index = delta_text.find(self.end_token)
|
| 109 |
+
reasoning_content = delta_text[start_index +
|
| 110 |
+
len(self.start_token):end_index]
|
| 111 |
+
content = delta_text[end_index + len(self.end_token):]
|
| 112 |
+
return DeltaMessage(
|
| 113 |
+
reasoning_content=reasoning_content,
|
| 114 |
+
content=content if content else None,
|
| 115 |
+
)
|
| 116 |
+
else:
|
| 117 |
+
# [unused16] in delta, no [unused17] in delta,
|
| 118 |
+
# reasoning content continues
|
| 119 |
+
return DeltaMessage(reasoning_content=delta_text)
|
| 120 |
+
else:
|
| 121 |
+
# No [unused16] in previous or delta, also need to check for [unused17].
|
| 122 |
+
# Because the model may have generated [unused17] without [unused16]
|
| 123 |
+
if self.end_token_id in delta_token_ids:
|
| 124 |
+
# [unused17] in delta with more tokens,
|
| 125 |
+
# extract reasoning content and content
|
| 126 |
+
end_index = delta_text.find(self.end_token)
|
| 127 |
+
reasoning_content = delta_text[:end_index]
|
| 128 |
+
content = delta_text[end_index + len(self.end_token):]
|
| 129 |
+
return DeltaMessage(
|
| 130 |
+
reasoning_content=reasoning_content,
|
| 131 |
+
content=content if content else None,
|
| 132 |
+
)
|
| 133 |
+
elif self.end_token_id in previous_token_ids:
|
| 134 |
+
# [unused17] in previous, thinking content ends
|
| 135 |
+
return DeltaMessage(content=delta_text)
|
| 136 |
+
else:
|
| 137 |
+
# no [unused17] in previous or delta, reasoning content continues
|
| 138 |
+
return DeltaMessage(reasoning_content=delta_text)
|
| 139 |
+
|
| 140 |
+
def extract_reasoning_content(
|
| 141 |
+
self, model_output: str, request: ChatCompletionRequest
|
| 142 |
+
) -> tuple[Optional[str], Optional[str]]:
|
| 143 |
+
"""
|
| 144 |
+
Extract reasoning content from the model output.
|
| 145 |
+
|
| 146 |
+
For text [unused16]abc[unused17]xyz:
|
| 147 |
+
- 'abc' goes to reasoning_content
|
| 148 |
+
- 'xyz' goes to content
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
tuple[Optional[str], Optional[str]]: reasoning content and content
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
# Check if the start token is present in the model output, remove it
|
| 155 |
+
# if it is present.
|
| 156 |
+
model_output_parts = model_output.partition(self.start_token)
|
| 157 |
+
model_output = model_output_parts[2] if model_output_parts[
|
| 158 |
+
1] else model_output_parts[0]
|
| 159 |
+
|
| 160 |
+
# Thus we assume the reasoning content is always at the start.
|
| 161 |
+
if self.end_token not in model_output:
|
| 162 |
+
return model_output, None
|
| 163 |
+
else:
|
| 164 |
+
reasoning_content, _, content = model_output.partition(
|
| 165 |
+
self.end_token)
|
| 166 |
+
# If the end token is not found, return the model output as is.
|
| 167 |
+
# It should not happen since we already checked for the presence
|
| 168 |
+
# of the end token.
|
| 169 |
+
# If generation stops right after end-of-think, return null content
|
| 170 |
+
final_content = content or None
|
| 171 |
+
return reasoning_content, final_content
|
inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
from .pangu_tool_parser import PanguToolParser
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"PanguToolParser"
|
| 6 |
+
]
|
inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
# Copyright 2023 The vLLM team.
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
from json import JSONDecodeError, JSONDecoder
|
| 7 |
+
from typing import Dict, List, Sequence, Union, Optional
|
| 8 |
+
from pydantic import Field
|
| 9 |
+
import partial_json_parser
|
| 10 |
+
from partial_json_parser.core.options import Allow
|
| 11 |
+
from transformers import PreTrainedTokenizerBase
|
| 12 |
+
|
| 13 |
+
from vllm.entrypoints.chat_utils import random_tool_call_id
|
| 14 |
+
from vllm.entrypoints.openai.tool_parsers.utils import (
|
| 15 |
+
extract_intermediate_diff)
|
| 16 |
+
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
| 17 |
+
DeltaFunctionCall, DeltaMessage,
|
| 18 |
+
DeltaToolCall,
|
| 19 |
+
ExtractedToolCallInformation,
|
| 20 |
+
FunctionCall, ToolCall,
|
| 21 |
+
)
|
| 22 |
+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
| 23 |
+
ToolParser, ToolParserManager)
|
| 24 |
+
from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
|
| 25 |
+
is_complete_json)
|
| 26 |
+
from vllm.logger import init_logger
|
| 27 |
+
import os
|
| 28 |
+
|
| 29 |
+
logger = init_logger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@ToolParserManager.register_module("pangu")
|
| 33 |
+
class PanguToolParser(ToolParser):
|
| 34 |
+
|
| 35 |
+
def __init__(self, tokenizer: PreTrainedTokenizerBase, enable_reasoning=False):
|
| 36 |
+
super().__init__(tokenizer)
|
| 37 |
+
|
| 38 |
+
# initialize properties used for state when parsing tool calls in
|
| 39 |
+
# streaming mode
|
| 40 |
+
self.prev_tool_call_arr: List[Dict] = []
|
| 41 |
+
self.current_tool_id: int = -1
|
| 42 |
+
self.current_tool_name_sent: bool = False
|
| 43 |
+
self.streamed_args_for_tool: List[str] = [
|
| 44 |
+
] # map what has been streamed for each tool so far to a list
|
| 45 |
+
|
| 46 |
+
self.tool_call_start_token = "[unused11]"
|
| 47 |
+
self.tool_call_end_token = "[unused12]"
|
| 48 |
+
self.pattern = re.escape(self.tool_call_start_token) \
|
| 49 |
+
+ "(.*?)" + re.escape(self.tool_call_end_token)
|
| 50 |
+
self.tool_call_regex = re.compile(self.pattern, re.DOTALL)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
self.tool_call_start_token_id = self.vocab.get(
|
| 54 |
+
self.tool_call_start_token)
|
| 55 |
+
self.tool_call_end_token_id = self.vocab.get(
|
| 56 |
+
self.tool_call_end_token)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
if (self.tool_call_start_token_id is None
|
| 60 |
+
or self.tool_call_end_token_id is None):
|
| 61 |
+
raise RuntimeError(
|
| 62 |
+
"Pangu Tool parser could not locate tool calls start/end "
|
| 63 |
+
"tokens in the tokenizer!")
|
| 64 |
+
self.is_complete = []
|
| 65 |
+
self.text_after_start_token = ""
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def extract_tool_calls(
|
| 69 |
+
self, model_output: str,
|
| 70 |
+
request: ChatCompletionRequest
|
| 71 |
+
) -> ExtractedToolCallInformation:
|
| 72 |
+
"""
|
| 73 |
+
Extract the tool calls from a complete model response.
|
| 74 |
+
"""
|
| 75 |
+
# case -- if a tool call token is not present, return a text response
|
| 76 |
+
if not (self.tool_call_start_token in model_output and \
|
| 77 |
+
model_output.find(self.tool_call_end_token) != -1):
|
| 78 |
+
return ExtractedToolCallInformation(tools_called=False,
|
| 79 |
+
tool_calls=[],
|
| 80 |
+
content=model_output)
|
| 81 |
+
|
| 82 |
+
try:
|
| 83 |
+
raw_function_calls = []
|
| 84 |
+
# use a regex to find the tool call between the tags
|
| 85 |
+
function_call_tuples = self.tool_call_regex.findall(model_output)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# load the JSON, and then use it to build the Function and
|
| 89 |
+
# Tool Call
|
| 90 |
+
for function_call_str in function_call_tuples:
|
| 91 |
+
function_call = json.loads(function_call_str)
|
| 92 |
+
raw_function_calls.extend(function_call)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
tool_calls: List[ToolCall] = [
|
| 96 |
+
ToolCall(
|
| 97 |
+
type="function",
|
| 98 |
+
function=FunctionCall(
|
| 99 |
+
name=function_call["name"],
|
| 100 |
+
# function call args are JSON but as a string
|
| 101 |
+
arguments=json.dumps(function_call["arguments"] \
|
| 102 |
+
if "arguments" in function_call \
|
| 103 |
+
else function_call["parameters"], ensure_ascii=False)))
|
| 104 |
+
for function_call in raw_function_calls
|
| 105 |
+
]
|
| 106 |
+
content = model_output[:model_output.
|
| 107 |
+
find(self.tool_call_start_token)]
|
| 108 |
+
|
| 109 |
+
# get any content before the tool call
|
| 110 |
+
ret = ExtractedToolCallInformation(tools_called=True,
|
| 111 |
+
tool_calls=tool_calls,
|
| 112 |
+
content=content if content else None)
|
| 113 |
+
|
| 114 |
+
return ret
|
| 115 |
+
|
| 116 |
+
except Exception:
|
| 117 |
+
logger.exception("Error in extracting tool call from response.")
|
| 118 |
+
# return information to just treat the tool call as regular JSON
|
| 119 |
+
return ExtractedToolCallInformation(tools_called=False,
|
| 120 |
+
tool_calls=[],
|
| 121 |
+
content=model_output)
|
| 122 |
+
|
| 123 |
+
def extract_tool_calls_streaming(
|
| 124 |
+
self,
|
| 125 |
+
previous_text: str,
|
| 126 |
+
current_text: str,
|
| 127 |
+
delta_text: str,
|
| 128 |
+
previous_token_ids: Sequence[int],
|
| 129 |
+
current_token_ids: Sequence[int],
|
| 130 |
+
delta_token_ids: Sequence[int],
|
| 131 |
+
request: ChatCompletionRequest,
|
| 132 |
+
) -> Union[DeltaMessage, None]:
|
| 133 |
+
|
| 134 |
+
if (self.tool_call_end_token_id in delta_token_ids
|
| 135 |
+
and len(delta_token_ids) == 1):
|
| 136 |
+
# if it's the only token, return None, so we don't send a chat
|
| 137 |
+
# completion and don't send a control token
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
if (self.tool_call_end_token in current_text
|
| 141 |
+
and self.tool_call_end_token not in delta_text):
|
| 142 |
+
return DeltaMessage(content=delta_text)
|
| 143 |
+
|
| 144 |
+
if self.tool_call_start_token not in current_text:
|
| 145 |
+
return DeltaMessage(content=delta_text)
|
| 146 |
+
|
| 147 |
+
if self.tool_call_start_token in delta_text:
|
| 148 |
+
texts = delta_text.split(self.tool_call_start_token)
|
| 149 |
+
text_before_start_token = texts[0]
|
| 150 |
+
if text_before_start_token:
|
| 151 |
+
return DeltaMessage(content=text_before_start_token)
|
| 152 |
+
|
| 153 |
+
if (self.tool_call_start_token_id in delta_token_ids
|
| 154 |
+
and len(delta_token_ids) == 1):
|
| 155 |
+
# if it's the only token, return None, so we don't send a chat
|
| 156 |
+
# completion and don't send a control token
|
| 157 |
+
return None
|
| 158 |
+
|
| 159 |
+
# bit mask flags for partial JSON parsing. If the name hasn't been
|
| 160 |
+
# sent yet, don't allow sending
|
| 161 |
+
# an incomplete string since OpenAI only ever (as far as I have
|
| 162 |
+
# seen) allows sending the entire tool/ function name at once.
|
| 163 |
+
flags = Allow.ALL if self.current_tool_name_sent \
|
| 164 |
+
else Allow.ALL & ~Allow.STR
|
| 165 |
+
try:
|
| 166 |
+
|
| 167 |
+
tool_call_portion = current_text.split(
|
| 168 |
+
self.tool_call_start_token)[-1].split(self.tool_call_end_token)[0]
|
| 169 |
+
try:
|
| 170 |
+
tool_call_arr: list[dict] = partial_json_parser.loads(
|
| 171 |
+
tool_call_portion, flags)
|
| 172 |
+
|
| 173 |
+
self.is_complete.append(
|
| 174 |
+
is_complete_json(tool_call_portion))
|
| 175 |
+
except partial_json_parser.core.exceptions.MalformedJSON:
|
| 176 |
+
logger.debug('not enough tokens to parse into JSON yet')
|
| 177 |
+
return None
|
| 178 |
+
|
| 179 |
+
# select as the current tool call the one we're on the state at
|
| 180 |
+
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
|
| 181 |
+
if len(tool_call_arr) > 0 else {}
|
| 182 |
+
|
| 183 |
+
# case -- if no tokens have been streamed for the tool, e.g.
|
| 184 |
+
# only the array brackets, stream nothing
|
| 185 |
+
if len(tool_call_arr) == 0:
|
| 186 |
+
return None
|
| 187 |
+
|
| 188 |
+
# case: we are starting a new tool in the array
|
| 189 |
+
# -> array has > 0 length AND length has moved past cursor
|
| 190 |
+
elif (len(tool_call_arr) > 0
|
| 191 |
+
and len(tool_call_arr) > self.current_tool_id + 1):
|
| 192 |
+
|
| 193 |
+
# if we're moving on to a new call, first make sure we
|
| 194 |
+
# haven't missed anything in the previous one that was
|
| 195 |
+
# auto-generated due to JSON completions, but wasn't
|
| 196 |
+
# streamed to the client yet.
|
| 197 |
+
if self.current_tool_id >= 0:
|
| 198 |
+
cur_arguments = current_tool_call.get("arguments")
|
| 199 |
+
if cur_arguments:
|
| 200 |
+
cur_args_json = json.dumps(cur_arguments,
|
| 201 |
+
ensure_ascii=False)
|
| 202 |
+
sent = len(
|
| 203 |
+
self.streamed_args_for_tool[self.current_tool_id])
|
| 204 |
+
argument_diff = cur_args_json[sent:]
|
| 205 |
+
|
| 206 |
+
logger.debug("got arguments diff: %s", argument_diff)
|
| 207 |
+
delta = DeltaMessage(tool_calls=[
|
| 208 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 209 |
+
function=DeltaFunctionCall(
|
| 210 |
+
arguments=argument_diff).
|
| 211 |
+
model_dump(exclude_none=True))
|
| 212 |
+
])
|
| 213 |
+
self.streamed_args_for_tool[
|
| 214 |
+
self.current_tool_id] += argument_diff
|
| 215 |
+
else:
|
| 216 |
+
delta = None
|
| 217 |
+
else:
|
| 218 |
+
delta = None
|
| 219 |
+
# re-set stuff pertaining to progress in the current tool
|
| 220 |
+
self.current_tool_id = len(tool_call_arr) - 1
|
| 221 |
+
self.current_tool_name_sent = False
|
| 222 |
+
self.streamed_args_for_tool.append("")
|
| 223 |
+
self.is_complete = []
|
| 224 |
+
logger.debug("starting on new tool %d", self.current_tool_id)
|
| 225 |
+
return delta
|
| 226 |
+
|
| 227 |
+
# if the current tool name hasn't been sent, send if available
|
| 228 |
+
# - otherwise send nothing
|
| 229 |
+
elif not self.current_tool_name_sent:
|
| 230 |
+
function_name = current_tool_call.get("name")
|
| 231 |
+
if function_name:
|
| 232 |
+
delta = DeltaMessage(tool_calls=[
|
| 233 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 234 |
+
type="function",
|
| 235 |
+
id=random_tool_call_id(),
|
| 236 |
+
function=DeltaFunctionCall(
|
| 237 |
+
name=function_name).model_dump(
|
| 238 |
+
exclude_none=True))
|
| 239 |
+
])
|
| 240 |
+
self.current_tool_name_sent = True
|
| 241 |
+
else:
|
| 242 |
+
delta = None
|
| 243 |
+
|
| 244 |
+
# now we know we're on the same tool call and we're streaming
|
| 245 |
+
# arguments
|
| 246 |
+
else:
|
| 247 |
+
cur_arguments = current_tool_call.get("arguments")
|
| 248 |
+
delta = None
|
| 249 |
+
if (self.is_complete[-1] and not cur_arguments
|
| 250 |
+
and not self.streamed_args_for_tool[-1]):
|
| 251 |
+
argument_diff = "{}"
|
| 252 |
+
delta = DeltaMessage(tool_calls=[
|
| 253 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 254 |
+
function=DeltaFunctionCall(
|
| 255 |
+
arguments=argument_diff).
|
| 256 |
+
model_dump(exclude_none=True))
|
| 257 |
+
])
|
| 258 |
+
self.streamed_args_for_tool[
|
| 259 |
+
self.current_tool_id] += argument_diff
|
| 260 |
+
|
| 261 |
+
if cur_arguments:
|
| 262 |
+
sent = len(
|
| 263 |
+
self.streamed_args_for_tool[self.current_tool_id])
|
| 264 |
+
cur_args_json = json.dumps(cur_arguments,
|
| 265 |
+
ensure_ascii=False)
|
| 266 |
+
prev_arguments = self.prev_tool_call_arr[
|
| 267 |
+
self.current_tool_id].get("arguments")
|
| 268 |
+
|
| 269 |
+
argument_diff = None
|
| 270 |
+
if self.is_complete[-1]:
|
| 271 |
+
argument_diff = cur_args_json[sent:]
|
| 272 |
+
elif prev_arguments:
|
| 273 |
+
prev_args_json = json.dumps(prev_arguments,
|
| 274 |
+
ensure_ascii=False)
|
| 275 |
+
if cur_args_json != prev_args_json:
|
| 276 |
+
|
| 277 |
+
prefix = find_common_prefix(
|
| 278 |
+
prev_args_json, cur_args_json)
|
| 279 |
+
argument_diff = prefix[sent:]
|
| 280 |
+
|
| 281 |
+
if argument_diff is not None:
|
| 282 |
+
delta = DeltaMessage(tool_calls=[
|
| 283 |
+
DeltaToolCall(index=self.current_tool_id,
|
| 284 |
+
function=DeltaFunctionCall(
|
| 285 |
+
arguments=argument_diff).
|
| 286 |
+
model_dump(exclude_none=True))
|
| 287 |
+
])
|
| 288 |
+
self.streamed_args_for_tool[
|
| 289 |
+
self.current_tool_id] += argument_diff
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
self.prev_tool_call_arr = tool_call_arr
|
| 293 |
+
return delta
|
| 294 |
+
|
| 295 |
+
except Exception:
|
| 296 |
+
logger.exception("Error trying to handle streaming tool call.")
|
| 297 |
+
logger.debug(
|
| 298 |
+
"Skipping chunk as a result of tool streaming extraction "
|
| 299 |
+
"error")
|
| 300 |
+
return None
|
inference/vllm_ascend/envs.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# This file is mainly Adapted from vllm-project/vllm/vllm/envs.py
|
| 6 |
+
# Copyright 2023 The vLLM team.
|
| 7 |
+
#
|
| 8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 9 |
+
# you may not use this file except in compliance with the License.
|
| 10 |
+
# You may obtain a copy of the License at
|
| 11 |
+
#
|
| 12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 13 |
+
#
|
| 14 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 17 |
+
# See the License for the specific language governing permissions and
|
| 18 |
+
# limitations under the License.
|
| 19 |
+
#
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
from typing import Any, Callable, Dict
|
| 23 |
+
|
| 24 |
+
# The begin-* and end* here are used by the documentation generator
|
| 25 |
+
# to extract the used env vars.
|
| 26 |
+
|
| 27 |
+
# begin-env-vars-definition
|
| 28 |
+
|
| 29 |
+
env_variables: Dict[str, Callable[[], Any]] = {
|
| 30 |
+
# max compile thread number for package building. Usually, it is set to
|
| 31 |
+
# the number of CPU cores. If not set, the default value is None, which
|
| 32 |
+
# means all number of CPU cores will be used.
|
| 33 |
+
"MAX_JOBS":
|
| 34 |
+
lambda: os.getenv("MAX_JOBS", None),
|
| 35 |
+
# The build type of the package. It can be one of the following values:
|
| 36 |
+
# Release, Debug, RelWithDebugInfo. If not set, the default value is Release.
|
| 37 |
+
"CMAKE_BUILD_TYPE":
|
| 38 |
+
lambda: os.getenv("CMAKE_BUILD_TYPE"),
|
| 39 |
+
# Whether to compile custom kernels. If not set, the default value is True.
|
| 40 |
+
# If set to False, the custom kernels will not be compiled. Please note that
|
| 41 |
+
# the sleep mode feature will be disabled as well if custom kernels are not
|
| 42 |
+
# compiled.
|
| 43 |
+
"COMPILE_CUSTOM_KERNELS":
|
| 44 |
+
lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
|
| 45 |
+
# The CXX compiler used for compiling the package. If not set, the default
|
| 46 |
+
# value is None, which means the system default CXX compiler will be used.
|
| 47 |
+
"CXX_COMPILER":
|
| 48 |
+
lambda: os.getenv("CXX_COMPILER", None),
|
| 49 |
+
# The C compiler used for compiling the package. If not set, the default
|
| 50 |
+
# value is None, which means the system default C compiler will be used.
|
| 51 |
+
"C_COMPILER":
|
| 52 |
+
lambda: os.getenv("C_COMPILER", None),
|
| 53 |
+
# The version of the Ascend chip. If not set, the default value is
|
| 54 |
+
# ASCEND910B1. It's used for package building. Please make sure that the
|
| 55 |
+
# version is correct.
|
| 56 |
+
"SOC_VERSION":
|
| 57 |
+
lambda: os.getenv("SOC_VERSION", "ASCEND910B1"),
|
| 58 |
+
# If set, vllm-ascend will print verbose logs during compilation
|
| 59 |
+
"VERBOSE":
|
| 60 |
+
lambda: bool(int(os.getenv('VERBOSE', '0'))),
|
| 61 |
+
# The home path for CANN toolkit. If not set, the default value is
|
| 62 |
+
# /usr/local/Ascend/ascend-toolkit/latest
|
| 63 |
+
"ASCEND_HOME_PATH":
|
| 64 |
+
lambda: os.getenv("ASCEND_HOME_PATH", None),
|
| 65 |
+
# The path for HCCN Tool, the tool will be called by disaggregated prefilling
|
| 66 |
+
# case.
|
| 67 |
+
"HCCN_PATH":
|
| 68 |
+
lambda: os.getenv("HCCN_PATH", "/usr/local/Ascend/driver/tools/hccn_tool"),
|
| 69 |
+
# The path for HCCL library, it's used by pyhccl communicator backend. If
|
| 70 |
+
# not set, the default value is libhccl.so。
|
| 71 |
+
"HCCL_SO_PATH":
|
| 72 |
+
# The prefill device id for disaggregated prefilling case.
|
| 73 |
+
lambda: os.environ.get("HCCL_SO_PATH", None),
|
| 74 |
+
"PROMPT_DEVICE_ID":
|
| 75 |
+
lambda: os.getenv("PROMPT_DEVICE_ID", None),
|
| 76 |
+
# The decode device id for disaggregated prefilling case.
|
| 77 |
+
"DECODE_DEVICE_ID":
|
| 78 |
+
lambda: os.getenv("DECODE_DEVICE_ID", None),
|
| 79 |
+
# The port number for llmdatadist communication. If not set, the default
|
| 80 |
+
# value is 26000.
|
| 81 |
+
"LLMDATADIST_COMM_PORT":
|
| 82 |
+
lambda: os.getenv("LLMDATADIST_COMM_PORT", "26000"),
|
| 83 |
+
# The wait time for llmdatadist sync cache. If not set, the default value is
|
| 84 |
+
# 5000ms.
|
| 85 |
+
"LLMDATADIST_SYNC_CACHE_WAIT_TIME":
|
| 86 |
+
lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000"),
|
| 87 |
+
# The version of vllm is installed. This value is used for developers who
|
| 88 |
+
# installed vllm from source locally. In this case, the version of vllm is
|
| 89 |
+
# usually changed. For example, if the version of vllm is "0.9.0", but when
|
| 90 |
+
# it's installed from source, the version of vllm is usually set to "0.9.1".
|
| 91 |
+
# In this case, developers need to set this value to "0.9.0" to make sure
|
| 92 |
+
# that the correct package is installed.
|
| 93 |
+
"VLLM_VERSION":
|
| 94 |
+
lambda: os.getenv("VLLM_VERSION", None),
|
| 95 |
+
# Whether to enable the trace recompiles from pytorch.
|
| 96 |
+
"VLLM_ASCEND_TRACE_RECOMPILES":
|
| 97 |
+
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
|
| 98 |
+
# Whether to enable fused_experts_allgather_ep. MoeInitRoutingV3 and
|
| 99 |
+
# GroupedMatmulFinalizeRouting operators are combined to implement EP.
|
| 100 |
+
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
|
| 101 |
+
lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
|
| 102 |
+
),
|
| 103 |
+
"VLLM_ASCEND_ENABLE_DBO":
|
| 104 |
+
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
|
| 105 |
+
# Whether to enable the model execute time observe profile. Disable it when
|
| 106 |
+
# running vllm ascend in production environment.
|
| 107 |
+
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
|
| 108 |
+
lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
|
| 109 |
+
),
|
| 110 |
+
# MOE_ALL2ALL_BUFFER:
|
| 111 |
+
# 0: default, normal init.
|
| 112 |
+
# 1: enable moe_all2all_buffer.
|
| 113 |
+
"MOE_ALL2ALL_BUFFER":
|
| 114 |
+
lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))),
|
| 115 |
+
# Some models are optimized by vllm ascend. While in some case, e.g. rlhf
|
| 116 |
+
# training, the optimized model may not be suitable. In this case, set this
|
| 117 |
+
# value to False to disable the optimized model.
|
| 118 |
+
"USE_OPTIMIZED_MODEL":
|
| 119 |
+
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
|
| 120 |
+
# SELECT_GATING_TOPK_SOTFMAX_EXPERTS is the equivalent of select_experts in non-quantized scenarios.
|
| 121 |
+
# In theory, it should have better performance than select_experts.
|
| 122 |
+
# Subsequent versions will remove the SELECT_GATING_TOPK_SOTFMAX_EXPERTS tag and use it as the default mode.
|
| 123 |
+
"SELECT_GATING_TOPK_SOTFMAX_EXPERTS":
|
| 124 |
+
lambda: bool(int(os.getenv("SELECT_GATING_TOPK_SOTFMAX_EXPERTS", '0'))),
|
| 125 |
+
# The tolerance of the kv cache size, if the difference between the
|
| 126 |
+
# actual kv cache size and the cached kv cache size is less than this value,
|
| 127 |
+
# then the cached kv cache size will be used.
|
| 128 |
+
"VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
|
| 129 |
+
lambda: int(
|
| 130 |
+
os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
|
| 131 |
+
# Whether to enable the topk optimization. It's disabled by default for experimental support
|
| 132 |
+
# We'll make it enabled by default in the future.
|
| 133 |
+
"VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
|
| 134 |
+
lambda: bool(
|
| 135 |
+
int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
|
| 136 |
+
# Whether to enable top n sigma sampling
|
| 137 |
+
"VLLM_ASCEND_ENABLE_TOP_N_SIGMA":
|
| 138 |
+
lambda: bool(
|
| 139 |
+
int(os.getenv("VLLM_ASCEND_ENABLE_TOP_N_SIGMA", '0'))),
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
# end-env-vars-definition
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def __getattr__(name: str):
|
| 146 |
+
# lazy evaluation of environment variables
|
| 147 |
+
if name in env_variables:
|
| 148 |
+
return env_variables[name]()
|
| 149 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def __dir__():
|
| 153 |
+
return list(env_variables.keys())
|
inference/vllm_ascend/models/__init__.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from vllm import ModelRegistry
|
| 2 |
+
|
| 3 |
+
import vllm_ascend.envs as envs
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def register_model():
|
| 7 |
+
from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401
|
| 8 |
+
from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
|
| 9 |
+
from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
|
| 10 |
+
from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401
|
| 11 |
+
from .open_pangu import PanguUltraMoEForCausalLM # noqa: F401
|
| 12 |
+
from .open_pangu import PanguEmbeddedForCausalLM # noqa: F401
|
| 13 |
+
from .qwen2_5_vl import \
|
| 14 |
+
AscendQwen2_5_VLForConditionalGeneration # noqa: F401
|
| 15 |
+
from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
|
| 16 |
+
|
| 17 |
+
ModelRegistry.register_model(
|
| 18 |
+
"DeepSeekMTPModel",
|
| 19 |
+
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
|
| 20 |
+
|
| 21 |
+
ModelRegistry.register_model(
|
| 22 |
+
"Qwen2VLForConditionalGeneration",
|
| 23 |
+
"vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
|
| 24 |
+
|
| 25 |
+
if envs.USE_OPTIMIZED_MODEL:
|
| 26 |
+
ModelRegistry.register_model(
|
| 27 |
+
"Qwen2_5_VLForConditionalGeneration",
|
| 28 |
+
"vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
|
| 29 |
+
)
|
| 30 |
+
else:
|
| 31 |
+
ModelRegistry.register_model(
|
| 32 |
+
"Qwen2_5_VLForConditionalGeneration",
|
| 33 |
+
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
if envs.VLLM_ASCEND_ENABLE_DBO:
|
| 37 |
+
ModelRegistry.register_model(
|
| 38 |
+
"DeepseekV2ForCausalLM",
|
| 39 |
+
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
|
| 40 |
+
|
| 41 |
+
ModelRegistry.register_model(
|
| 42 |
+
"DeepseekV3ForCausalLM",
|
| 43 |
+
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
|
| 44 |
+
|
| 45 |
+
else:
|
| 46 |
+
ModelRegistry.register_model(
|
| 47 |
+
"DeepseekV2ForCausalLM",
|
| 48 |
+
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
|
| 49 |
+
|
| 50 |
+
ModelRegistry.register_model(
|
| 51 |
+
"DeepseekV3ForCausalLM",
|
| 52 |
+
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
|
| 53 |
+
|
| 54 |
+
ModelRegistry.register_model(
|
| 55 |
+
"Qwen3MoeForCausalLM",
|
| 56 |
+
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
|
| 57 |
+
|
| 58 |
+
ModelRegistry.register_model(
|
| 59 |
+
"PanguProMoEForCausalLM",
|
| 60 |
+
"vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
|
| 61 |
+
|
| 62 |
+
ModelRegistry.register_model(
|
| 63 |
+
"PanguUltraMoEForCausalLM",
|
| 64 |
+
"vllm_ascend.models.open_pangu:PanguUltraMoEForCausalLM")
|
| 65 |
+
|
| 66 |
+
ModelRegistry.register_model(
|
| 67 |
+
"PanguEmbeddedForCausalLM",
|
| 68 |
+
"vllm_ascend.models.open_pangu:PanguEmbeddedForCausalLM")
|
inference/vllm_ascend/models/open_pangu.py
ADDED
|
@@ -0,0 +1,1127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# Copyright 2023 The vLLM team.
|
| 4 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 5 |
+
#
|
| 6 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 7 |
+
# and OPT implementations in this library. It has been modified from its
|
| 8 |
+
# original forms to accommodate minor architectural differences compared
|
| 9 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 10 |
+
#
|
| 11 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 12 |
+
# you may not use this file except in compliance with the License.
|
| 13 |
+
# You may obtain a copy of the License at
|
| 14 |
+
#
|
| 15 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 16 |
+
#
|
| 17 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 18 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 19 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 20 |
+
# See the License for the specific language governing permissions and
|
| 21 |
+
# limitations under the License.
|
| 22 |
+
|
| 23 |
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
| 24 |
+
import torch
|
| 25 |
+
import torch_npu
|
| 26 |
+
import vllm.envs as envs
|
| 27 |
+
from torch import nn
|
| 28 |
+
from transformers import PretrainedConfig
|
| 29 |
+
from vllm.compilation.decorators import support_torch_compile
|
| 30 |
+
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
| 31 |
+
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
| 32 |
+
from vllm.distributed import (get_tensor_model_parallel_rank,
|
| 33 |
+
get_tensor_model_parallel_world_size,
|
| 34 |
+
get_tp_group, split_tensor_along_last_dim,
|
| 35 |
+
tensor_model_parallel_all_gather,
|
| 36 |
+
tensor_model_parallel_all_reduce,
|
| 37 |
+
tensor_model_parallel_reduce_scatter)
|
| 38 |
+
from vllm.distributed.parallel_state import get_dp_group
|
| 39 |
+
from vllm.forward_context import get_forward_context
|
| 40 |
+
from vllm.model_executor.layers.activation import SiluAndMul
|
| 41 |
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
| 42 |
+
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
| 43 |
+
MergedColumnParallelLinear,
|
| 44 |
+
ReplicatedLinear,
|
| 45 |
+
RowParallelLinear,
|
| 46 |
+
UnquantizedLinearMethod,
|
| 47 |
+
QKVParallelLinear)
|
| 48 |
+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
| 49 |
+
from vllm.model_executor.layers.quantization import QuantizationConfig
|
| 50 |
+
from vllm.model_executor.layers.rotary_embedding import get_rope, _rotate_gptj
|
| 51 |
+
from vllm.model_executor.layers.sampler import get_sampler
|
| 52 |
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
| 53 |
+
ParallelLMHead, VocabParallelEmbedding)
|
| 54 |
+
from vllm.model_executor.model_loader.weight_utils import (
|
| 55 |
+
default_weight_loader, maybe_remap_kv_scale_name)
|
| 56 |
+
from vllm.model_executor.models.utils import (
|
| 57 |
+
make_layers, maybe_prefix, extract_layer_index)
|
| 58 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 59 |
+
from vllm_ascend.distributed.parallel_state import get_ep_group
|
| 60 |
+
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
| 61 |
+
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
| 62 |
+
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
| 63 |
+
from vllm_ascend.utils import dispose_tensor, npu_prefetch, get_fused_moe_state, FusedMoEState
|
| 64 |
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class OpenPanguMergedReplicatedLinear(ReplicatedLinear):
|
| 68 |
+
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
input_size: int,
|
| 72 |
+
output_sizes: list[int],
|
| 73 |
+
bias: bool = True,
|
| 74 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 75 |
+
prefix: str = "",
|
| 76 |
+
):
|
| 77 |
+
self.output_sizes = output_sizes
|
| 78 |
+
super().__init__(input_size,
|
| 79 |
+
sum(output_sizes),
|
| 80 |
+
bias=bias,
|
| 81 |
+
quant_config=quant_config,
|
| 82 |
+
prefix=prefix)
|
| 83 |
+
|
| 84 |
+
def weight_loader(self, param: torch.nn.Parameter,
|
| 85 |
+
loaded_weight: torch.Tensor, loaded_shard_id: int):
|
| 86 |
+
# With no support for GGUF format yet.
|
| 87 |
+
if getattr(param, "is_gguf_weight", False) or getattr(param, "is_gguf_weight_type", False):
|
| 88 |
+
raise ValueError('With no support for GGUF format yet.')
|
| 89 |
+
if loaded_shard_id >= len(self.output_sizes):
|
| 90 |
+
raise ValueError(f'loaded_shard_id {loaded_shard_id} >= len(self.output_sizes) {len(self.output_sizes)}.')
|
| 91 |
+
shard_offset = sum(self.output_sizes[:loaded_shard_id])
|
| 92 |
+
shard_size = self.output_sizes[loaded_shard_id]
|
| 93 |
+
shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
|
| 94 |
+
if shard.size() != loaded_weight.size():
|
| 95 |
+
raise ValueError(f"Tried to load weights of size {loaded_weight.size()} "
|
| 96 |
+
f"to a parameter shard of id {loaded_shard_id} size {shard.size()}.")
|
| 97 |
+
shard.copy_(loaded_weight)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class OpenPanguRowParallelLinearReplaceAllreduce(RowParallelLinear):
|
| 101 |
+
|
| 102 |
+
def forward(
|
| 103 |
+
self,
|
| 104 |
+
input_,
|
| 105 |
+
is_prefill=True
|
| 106 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
|
| 107 |
+
if self.input_is_parallel:
|
| 108 |
+
input_parallel = input_
|
| 109 |
+
else:
|
| 110 |
+
tp_rank = get_tensor_model_parallel_rank()
|
| 111 |
+
splitted_input = split_tensor_along_last_dim(
|
| 112 |
+
input_, num_partitions=self.tp_size)
|
| 113 |
+
input_parallel = splitted_input[tp_rank].contiguous()
|
| 114 |
+
|
| 115 |
+
# Matrix multiply.
|
| 116 |
+
if self.quant_method is None:
|
| 117 |
+
raise ValueError('self.quant_method is None.')
|
| 118 |
+
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
| 119 |
+
# bias will not get added more than once in TP>1 case)
|
| 120 |
+
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
| 121 |
+
output_parallel = self.quant_method.apply(self,
|
| 122 |
+
input_parallel,
|
| 123 |
+
bias=bias_)
|
| 124 |
+
if self.reduce_results and self.tp_size > 1:
|
| 125 |
+
if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
|
| 126 |
+
output = tensor_model_parallel_reduce_scatter(output_parallel,
|
| 127 |
+
dim=0)
|
| 128 |
+
else:
|
| 129 |
+
output = tensor_model_parallel_all_reduce(output_parallel)
|
| 130 |
+
else:
|
| 131 |
+
output = output_parallel
|
| 132 |
+
|
| 133 |
+
output_bias = self.bias if self.skip_bias_add else None
|
| 134 |
+
|
| 135 |
+
if not self.return_bias:
|
| 136 |
+
return output
|
| 137 |
+
return output, output_bias
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class OpenPanguRowParallelLinear(RowParallelLinear):
|
| 141 |
+
|
| 142 |
+
def forward(
|
| 143 |
+
self,
|
| 144 |
+
input_,
|
| 145 |
+
is_prefill=True
|
| 146 |
+
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
|
| 147 |
+
return super().forward(input_)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class OpenPanguRotaryEmbedding(nn.Module):
|
| 151 |
+
def __init__(self,
|
| 152 |
+
head_size: int,
|
| 153 |
+
rotary_dim: int,
|
| 154 |
+
max_position_embeddings: int,
|
| 155 |
+
base: float,
|
| 156 |
+
):
|
| 157 |
+
super().__init__()
|
| 158 |
+
self.dim = rotary_dim
|
| 159 |
+
self.max_position_embeddings = max_position_embeddings
|
| 160 |
+
self.base = base
|
| 161 |
+
self._set_cos_sin_cache(
|
| 162 |
+
seq_len=max_position_embeddings,
|
| 163 |
+
device='npu',
|
| 164 |
+
dtype=torch.get_default_dtype(),
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
def _set_cos_sin_cache(self,
|
| 168 |
+
seq_len: int,
|
| 169 |
+
device: str,
|
| 170 |
+
dtype: torch.dtype
|
| 171 |
+
):
|
| 172 |
+
self.max_seq_len = seq_len
|
| 173 |
+
inv_freq = 1.0 / (
|
| 174 |
+
self.base
|
| 175 |
+
** (torch.arange(0, self.dim, 2, dtype=torch.float32, device='npu') / self.dim)
|
| 176 |
+
)
|
| 177 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 178 |
+
t = torch.arange(seq_len, device='npu', dtype=torch.float32)
|
| 179 |
+
freqs = torch.outer(t, inv_freq)
|
| 180 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 181 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 182 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 183 |
+
|
| 184 |
+
def forward(self,
|
| 185 |
+
positions: torch.Tensor,
|
| 186 |
+
query: torch.Tensor,
|
| 187 |
+
key: torch.Tensor,
|
| 188 |
+
offsets: Optional[torch.Tensor] = None,
|
| 189 |
+
max_seq_len: Optional[int] = None,
|
| 190 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 191 |
+
if max_seq_len is not None and max_seq_len > self.max_seq_len:
|
| 192 |
+
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
|
| 193 |
+
idx = torch.add(positions, offsets) if offsets is not None else positions
|
| 194 |
+
cos = self.cos_cached[idx]
|
| 195 |
+
sin = self.sin_cached[idx]
|
| 196 |
+
# Adapt: adapt cos and sin shape
|
| 197 |
+
cos = cos.view(-1, 1, cos.shape[-1])
|
| 198 |
+
sin = sin.view(-1, 1, sin.shape[-1])
|
| 199 |
+
# Adapt end.
|
| 200 |
+
query_rot = query * cos + _rotate_gptj(query) * sin
|
| 201 |
+
if key is not None:
|
| 202 |
+
key_rot = key * cos + _rotate_gptj(key) * sin
|
| 203 |
+
return query_rot, key_rot
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class OpenPanguSiluAndMul(SiluAndMul):
|
| 207 |
+
|
| 208 |
+
def __init__(self,
|
| 209 |
+
*,
|
| 210 |
+
weight_scale: Optional[Callable[[], torch.Tensor]] = None):
|
| 211 |
+
super().__init__()
|
| 212 |
+
self.weight_scale = weight_scale
|
| 213 |
+
|
| 214 |
+
def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
|
| 215 |
+
torch.Tensor]]):
|
| 216 |
+
if isinstance(x, tuple):
|
| 217 |
+
if self.weight_scale is None:
|
| 218 |
+
raise ValueError('self.weight_scale is None.')
|
| 219 |
+
quantized_x, dynamic_scale = x
|
| 220 |
+
return torch_npu.npu_dequant_swiglu_quant(
|
| 221 |
+
x=quantized_x,
|
| 222 |
+
weight_scale=self.weight_scale(),
|
| 223 |
+
activation_scale=dynamic_scale,
|
| 224 |
+
activate_left=True,
|
| 225 |
+
quant_mode=1)
|
| 226 |
+
else:
|
| 227 |
+
return super().forward_oot(x)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def check_ffn_act_fn(act_fn: str):
|
| 231 |
+
if act_fn != "silu":
|
| 232 |
+
raise ValueError(
|
| 233 |
+
f"Unsupported activation: {act_fn}. Only silu is supported for now.")
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class OpenPanguMLP(nn.Module):
|
| 237 |
+
|
| 238 |
+
def __init__(
|
| 239 |
+
self,
|
| 240 |
+
hidden_size: int,
|
| 241 |
+
intermediate_size: int,
|
| 242 |
+
hidden_act: str,
|
| 243 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 244 |
+
bias: bool = False,
|
| 245 |
+
reduce_results: bool = True,
|
| 246 |
+
force_replicate: bool = False,
|
| 247 |
+
prefix: str = "",
|
| 248 |
+
) -> None:
|
| 249 |
+
super().__init__()
|
| 250 |
+
if not force_replicate:
|
| 251 |
+
self.gate_up_proj = MergedColumnParallelLinear(
|
| 252 |
+
hidden_size, [intermediate_size] * 2,
|
| 253 |
+
bias=bias,
|
| 254 |
+
quant_config=quant_config,
|
| 255 |
+
prefix=f"{prefix}.gate_up_proj")
|
| 256 |
+
self.down_proj = RowParallelLinear(intermediate_size,
|
| 257 |
+
hidden_size,
|
| 258 |
+
bias=bias,
|
| 259 |
+
quant_config=quant_config,
|
| 260 |
+
reduce_results=reduce_results,
|
| 261 |
+
prefix=f"{prefix}.down_proj")
|
| 262 |
+
else:
|
| 263 |
+
self.gate_up_proj = OpenPanguMergedReplicatedLinear(
|
| 264 |
+
hidden_size, [intermediate_size] * 2,
|
| 265 |
+
bias=bias,
|
| 266 |
+
quant_config=quant_config,
|
| 267 |
+
prefix=f"{prefix}.gate_up_proj")
|
| 268 |
+
self.down_proj = ReplicatedLinear(intermediate_size,
|
| 269 |
+
hidden_size,
|
| 270 |
+
bias=bias,
|
| 271 |
+
quant_config=quant_config,
|
| 272 |
+
prefix=f"{prefix}.down_proj")
|
| 273 |
+
|
| 274 |
+
check_ffn_act_fn(hidden_act)
|
| 275 |
+
|
| 276 |
+
quant_method = self.gate_up_proj.quant_method
|
| 277 |
+
if isinstance(quant_method, UnquantizedLinearMethod):
|
| 278 |
+
self.act_fn = OpenPanguSiluAndMul()
|
| 279 |
+
elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
|
| 280 |
+
quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
|
| 281 |
+
# TODO(sdmyzlp): Currently preserved as before:
|
| 282 |
+
# 1. The only quantization supported for silu is W8A8Dynamic
|
| 283 |
+
# 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
|
| 284 |
+
#
|
| 285 |
+
# Maybe one can implement a better and more general configuration
|
| 286 |
+
# scheme, e.g. by somehow passing around the tweaked `quant_config`
|
| 287 |
+
self.act_fn = OpenPanguSiluAndMul(
|
| 288 |
+
# Use lazy binding, for `weight_scale_fp32` is accessible
|
| 289 |
+
# only after `process_weights_after_loading`.
|
| 290 |
+
weight_scale=lambda: self.gate_up_proj.weight_scale_fp32)
|
| 291 |
+
# To be consumed by AscendW8A8DynamicLinearMethod.apply()
|
| 292 |
+
self.gate_up_proj._ascend_quant_config = {
|
| 293 |
+
"output_dtype": torch.int32,
|
| 294 |
+
"pertoken_scale": False,
|
| 295 |
+
"return_scale": True,
|
| 296 |
+
}
|
| 297 |
+
self.down_proj._ascend_quant_config = {
|
| 298 |
+
"output_dtype": torch.bfloat16,
|
| 299 |
+
"pertoken_scale": True,
|
| 300 |
+
"return_scale": False,
|
| 301 |
+
}
|
| 302 |
+
else:
|
| 303 |
+
raise NotImplementedError(
|
| 304 |
+
f"Quantization with [{type(quant_method)}] is NOT supported")
|
| 305 |
+
|
| 306 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 307 |
+
return self.down_proj(self.act_fn(self.gate_up_proj(x)[0]))[0]
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class OpenPanguMoE(nn.Module):
|
| 311 |
+
|
| 312 |
+
top_k: int
|
| 313 |
+
|
| 314 |
+
def __init__(
|
| 315 |
+
self,
|
| 316 |
+
config: PretrainedConfig,
|
| 317 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 318 |
+
prefix: str = "",
|
| 319 |
+
):
|
| 320 |
+
super().__init__()
|
| 321 |
+
ascend_config = get_ascend_config()
|
| 322 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 323 |
+
self.enable_multistream_moe = \
|
| 324 |
+
ascend_config.torchair_graph_config.enable_multistream_moe
|
| 325 |
+
self.routed_scaling_factor = config.routed_scaling_factor
|
| 326 |
+
check_ffn_act_fn(config.hidden_act)
|
| 327 |
+
|
| 328 |
+
self.gate = ReplicatedLinear(config.hidden_size,
|
| 329 |
+
config.num_routed_experts,
|
| 330 |
+
bias=False,
|
| 331 |
+
quant_config=None,
|
| 332 |
+
prefix=f"{prefix}.gate")
|
| 333 |
+
|
| 334 |
+
self.experts = AscendFusedMoE(
|
| 335 |
+
num_experts=config.num_routed_experts,
|
| 336 |
+
top_k=config.num_experts_per_tok,
|
| 337 |
+
hidden_size=config.hidden_size,
|
| 338 |
+
intermediate_size=config.moe_intermediate_size,
|
| 339 |
+
reduce_results=False,
|
| 340 |
+
renormalize=config.norm_topk_prob,
|
| 341 |
+
quant_config=quant_config,
|
| 342 |
+
use_grouped_topk=True,
|
| 343 |
+
num_expert_group=1,
|
| 344 |
+
topk_group=1,
|
| 345 |
+
prefix=f"{prefix}.experts",
|
| 346 |
+
scoring_func='sigmoid',
|
| 347 |
+
e_score_correction_bias=None)
|
| 348 |
+
|
| 349 |
+
if config.num_shared_experts is not None:
|
| 350 |
+
self.all_reduce_merge = self.experts.all_reduce_merge
|
| 351 |
+
reduce_results = not self.all_reduce_merge
|
| 352 |
+
intermediate_size = (config.moe_intermediate_size * config.num_shared_experts)
|
| 353 |
+
self.shared_experts = OpenPanguMLP(
|
| 354 |
+
hidden_size=config.hidden_size,
|
| 355 |
+
intermediate_size=intermediate_size,
|
| 356 |
+
hidden_act=config.hidden_act,
|
| 357 |
+
quant_config=quant_config,
|
| 358 |
+
reduce_results=reduce_results,
|
| 359 |
+
force_replicate=self.enable_multistream_moe,
|
| 360 |
+
prefix=f"{prefix}.shared_experts",
|
| 361 |
+
)
|
| 362 |
+
else:
|
| 363 |
+
self.shared_experts = None # type: ignore
|
| 364 |
+
|
| 365 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 366 |
+
self.dp_size = get_dp_group().world_size
|
| 367 |
+
self.tp_group = get_tp_group().device_group
|
| 368 |
+
self.tp_rank = get_tp_group().rank_in_group
|
| 369 |
+
self.ep_group = get_ep_group()
|
| 370 |
+
|
| 371 |
+
self.params_dtype = torch.get_default_dtype()
|
| 372 |
+
self.rm_router_logits = self.experts.rm_router_logits
|
| 373 |
+
|
| 374 |
+
self.__class__.top_k = config.num_experts_per_tok
|
| 375 |
+
|
| 376 |
+
def forward(self,
|
| 377 |
+
hidden_states: torch.Tensor,
|
| 378 |
+
attn_metadata: Optional[AttentionMetadata] = None,
|
| 379 |
+
replace_allreduce: bool = False) -> torch.Tensor:
|
| 380 |
+
|
| 381 |
+
if attn_metadata is None:
|
| 382 |
+
attn_metadata = get_forward_context().attn_metadata
|
| 383 |
+
# when profile runs, force experts to load balanced tokens
|
| 384 |
+
# to avoid high memory consumption on a single rank.
|
| 385 |
+
# TODO: need a better flag to indicate whether in profile run or not.
|
| 386 |
+
if attn_metadata is None:
|
| 387 |
+
# for profile run
|
| 388 |
+
is_prefill = True
|
| 389 |
+
fused_moe_state = get_fused_moe_state(self.ep_group.world_size, is_prefill, True)
|
| 390 |
+
enable_force_load_balance = fused_moe_state != FusedMoEState.AllGatherEP
|
| 391 |
+
else:
|
| 392 |
+
is_prefill = attn_metadata.num_prefills > 0
|
| 393 |
+
enable_force_load_balance = False
|
| 394 |
+
if hasattr(attn_metadata, 'with_prefill_across_dp'):
|
| 395 |
+
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
| 396 |
+
fused_moe_state = get_fused_moe_state(self.ep_group.world_size, is_prefill, True)
|
| 397 |
+
|
| 398 |
+
# router_logits: (num_tokens, n_experts)
|
| 399 |
+
router_logits = None
|
| 400 |
+
if not self.rm_router_logits or fused_moe_state == FusedMoEState.All2All:
|
| 401 |
+
router_logits, _ = self.gate(hidden_states.float())
|
| 402 |
+
|
| 403 |
+
routed_hidden_states, shared_hidden_states = self.experts(
|
| 404 |
+
hidden_states=hidden_states,
|
| 405 |
+
router_logits=router_logits,
|
| 406 |
+
is_prefill=is_prefill,
|
| 407 |
+
top_k=self.__class__.top_k,
|
| 408 |
+
enable_force_load_balance=enable_force_load_balance,
|
| 409 |
+
shared_experts=self.shared_experts,
|
| 410 |
+
gate=self.gate,
|
| 411 |
+
replace_allreduce=replace_allreduce)
|
| 412 |
+
|
| 413 |
+
if self.all_reduce_merge and fused_moe_state == FusedMoEState.All2All:
|
| 414 |
+
shared_hidden_states = tensor_model_parallel_all_reduce(shared_hidden_states)
|
| 415 |
+
hidden_states = routed_hidden_states * self.routed_scaling_factor + shared_hidden_states
|
| 416 |
+
if self.all_reduce_merge and fused_moe_state != FusedMoEState.All2All:
|
| 417 |
+
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
| 418 |
+
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
| 419 |
+
|
| 420 |
+
return hidden_states
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
class OpenPanguMLAAttention(nn.Module):
|
| 424 |
+
|
| 425 |
+
def __init__(
|
| 426 |
+
self,
|
| 427 |
+
config: PretrainedConfig,
|
| 428 |
+
hidden_size: int,
|
| 429 |
+
num_heads: int,
|
| 430 |
+
attention_qk_dim: int,
|
| 431 |
+
attention_qk_rope_dim: int,
|
| 432 |
+
attention_v_dim: int,
|
| 433 |
+
attention_q_lora_dim: Optional[int],
|
| 434 |
+
attention_kv_lora_dim: int,
|
| 435 |
+
rope_theta: float = 10000,
|
| 436 |
+
max_position_embeddings: int = 8192,
|
| 437 |
+
cache_config: Optional[CacheConfig] = None,
|
| 438 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 439 |
+
prefix: str = "",
|
| 440 |
+
) -> None:
|
| 441 |
+
super().__init__()
|
| 442 |
+
ascend_config = get_ascend_config()
|
| 443 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 444 |
+
self.enable_multistream_mla = ascend_config.torchair_graph_config.enable_multistream_mla
|
| 445 |
+
|
| 446 |
+
self.hidden_size = hidden_size
|
| 447 |
+
self.num_heads = num_heads
|
| 448 |
+
self.attention_qk_dim = attention_qk_dim
|
| 449 |
+
self.attention_qk_rope_dim = attention_qk_rope_dim
|
| 450 |
+
self.qk_head_dim = attention_qk_dim + attention_qk_rope_dim
|
| 451 |
+
self.attention_v_dim = attention_v_dim
|
| 452 |
+
self.attention_q_lora_dim = attention_q_lora_dim
|
| 453 |
+
self.attention_kv_lora_dim = attention_kv_lora_dim
|
| 454 |
+
self.rope_theta = rope_theta
|
| 455 |
+
|
| 456 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 457 |
+
if num_heads % tp_size != 0:
|
| 458 |
+
raise ValueError(f'num_heads {num_heads} is not divisible by tp_size {tp_size}.')
|
| 459 |
+
self.num_local_heads = num_heads // tp_size
|
| 460 |
+
|
| 461 |
+
self.scaling = self.qk_head_dim**-0.5
|
| 462 |
+
self.max_position_embeddings = max_position_embeddings
|
| 463 |
+
|
| 464 |
+
self.prefix = prefix
|
| 465 |
+
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
| 466 |
+
|
| 467 |
+
if self.attention_q_lora_dim is not None:
|
| 468 |
+
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
| 469 |
+
self.attention_q_lora_dim,
|
| 470 |
+
bias=False,
|
| 471 |
+
quant_config=quant_config,
|
| 472 |
+
prefix=f"{prefix}.q_a_proj")
|
| 473 |
+
self.q_a_layernorm = RMSNorm(self.attention_q_lora_dim, eps=config.rms_norm_eps)
|
| 474 |
+
self.q_b_proj = ColumnParallelLinear(attention_q_lora_dim,
|
| 475 |
+
self.num_heads * self.qk_head_dim,
|
| 476 |
+
bias=False,
|
| 477 |
+
quant_config=quant_config,
|
| 478 |
+
prefix=f"{prefix}.q_b_proj")
|
| 479 |
+
else:
|
| 480 |
+
self.q_proj = ColumnParallelLinear(self.hidden_size,
|
| 481 |
+
self.num_heads * self.qk_head_dim,
|
| 482 |
+
bias=False,
|
| 483 |
+
quant_config=quant_config,
|
| 484 |
+
prefix=f"{prefix}.q_proj")
|
| 485 |
+
|
| 486 |
+
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
| 487 |
+
self.hidden_size,
|
| 488 |
+
self.attention_kv_lora_dim + self.attention_qk_rope_dim,
|
| 489 |
+
bias=False,
|
| 490 |
+
quant_config=quant_config,
|
| 491 |
+
prefix=f"{prefix}.kv_a_proj_with_mqa")
|
| 492 |
+
self.kv_a_layernorm = RMSNorm(self.attention_kv_lora_dim,
|
| 493 |
+
eps=config.rms_norm_eps)
|
| 494 |
+
self.kv_b_proj = ColumnParallelLinear(
|
| 495 |
+
self.attention_kv_lora_dim,
|
| 496 |
+
self.num_heads * (self.attention_qk_dim + self.attention_v_dim),
|
| 497 |
+
bias=False,
|
| 498 |
+
quant_config=quant_config,
|
| 499 |
+
prefix=f"{prefix}.kv_b_proj")
|
| 500 |
+
if (config.num_routed_experts is not None
|
| 501 |
+
and self.debug_layer_idx >= config.num_dense_layers and
|
| 502 |
+
ascend_config.torchair_graph_config.enable_multistream_moe):
|
| 503 |
+
self.o_proj = OpenPanguRowParallelLinearReplaceAllreduce(
|
| 504 |
+
self.num_heads * self.attention_v_dim,
|
| 505 |
+
self.hidden_size,
|
| 506 |
+
bias=False,
|
| 507 |
+
quant_config=quant_config,
|
| 508 |
+
prefix=f"{prefix}.o_proj")
|
| 509 |
+
else:
|
| 510 |
+
self.o_proj = OpenPanguRowParallelLinear(
|
| 511 |
+
self.num_heads * self.attention_v_dim,
|
| 512 |
+
self.hidden_size,
|
| 513 |
+
bias=False,
|
| 514 |
+
quant_config=quant_config,
|
| 515 |
+
prefix=f"{prefix}.o_proj")
|
| 516 |
+
|
| 517 |
+
self.rotary_emb = OpenPanguRotaryEmbedding(attention_qk_rope_dim,
|
| 518 |
+
rotary_dim=attention_qk_rope_dim,
|
| 519 |
+
max_position_embeddings=max_position_embeddings,
|
| 520 |
+
base=rope_theta)
|
| 521 |
+
|
| 522 |
+
self.mla_attn = Attention(
|
| 523 |
+
num_heads=self.num_local_heads,
|
| 524 |
+
head_size=self.attention_kv_lora_dim + self.attention_qk_rope_dim,
|
| 525 |
+
scale=self.scaling,
|
| 526 |
+
num_kv_heads=1,
|
| 527 |
+
cache_config=cache_config,
|
| 528 |
+
quant_config=quant_config,
|
| 529 |
+
prefix=f"{prefix}.attn",
|
| 530 |
+
use_mla=True,
|
| 531 |
+
# MLA Args
|
| 532 |
+
q_lora_rank=self.attention_q_lora_dim,
|
| 533 |
+
kv_lora_rank=self.attention_kv_lora_dim,
|
| 534 |
+
qk_nope_head_dim=self.attention_qk_dim,
|
| 535 |
+
qk_rope_head_dim=self.attention_qk_rope_dim,
|
| 536 |
+
qk_head_dim=self.qk_head_dim,
|
| 537 |
+
v_head_dim=self.attention_v_dim,
|
| 538 |
+
rotary_emb=self.rotary_emb,
|
| 539 |
+
q_proj=self.q_proj if self.attention_q_lora_dim is None else self.q_b_proj,
|
| 540 |
+
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
| 541 |
+
kv_a_layernorm=self.kv_a_layernorm,
|
| 542 |
+
kv_b_proj=self.kv_b_proj,
|
| 543 |
+
o_proj=self.o_proj,
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
def forward(
|
| 547 |
+
self,
|
| 548 |
+
positions: torch.Tensor,
|
| 549 |
+
hidden_states: torch.Tensor,
|
| 550 |
+
kv_cache: Optional[torch.Tensor] = None,
|
| 551 |
+
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
| 552 |
+
enable_multistream_mla = (self.enable_multistream_mla
|
| 553 |
+
and attn_metadata is not None
|
| 554 |
+
and not attn_metadata.with_prefill_across_dp
|
| 555 |
+
and attn_metadata.num_decodes > 0)
|
| 556 |
+
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
|
| 557 |
+
if self.attention_q_lora_dim is not None:
|
| 558 |
+
npu_prefetch(self.q_a_proj.weight,
|
| 559 |
+
hidden_states,
|
| 560 |
+
enabled=enable_multistream_mla)
|
| 561 |
+
ckq = self.q_a_proj(hidden_states)[0]
|
| 562 |
+
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
| 563 |
+
forward_kwargs['ckq'] = ckq
|
| 564 |
+
else:
|
| 565 |
+
hidden_states_or_q_c = hidden_states
|
| 566 |
+
if self.torchair_graph_enabled:
|
| 567 |
+
if envs.VLLM_USE_V1:
|
| 568 |
+
output_shape = hidden_states.shape
|
| 569 |
+
output = torch.empty(output_shape,
|
| 570 |
+
dtype=hidden_states_or_q_c.dtype,
|
| 571 |
+
device=hidden_states_or_q_c.device)
|
| 572 |
+
forward_kwargs['output'] = output
|
| 573 |
+
|
| 574 |
+
output = self.mla_attn.impl.forward(self.mla_attn,
|
| 575 |
+
hidden_states_or_q_c,
|
| 576 |
+
hidden_states, None, kv_cache,
|
| 577 |
+
attn_metadata,
|
| 578 |
+
**forward_kwargs)
|
| 579 |
+
if envs.VLLM_USE_V1:
|
| 580 |
+
output = output.view(-1, output_shape[-1])
|
| 581 |
+
return output
|
| 582 |
+
else:
|
| 583 |
+
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
| 584 |
+
[self.attention_kv_lora_dim, self.attention_qk_rope_dim], dim=-1)
|
| 585 |
+
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
| 586 |
+
return self.mla_attn(hidden_states_or_q_c,
|
| 587 |
+
kv_c_normed,
|
| 588 |
+
k_pe,
|
| 589 |
+
output_shape=hidden_states.shape)
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
class OpenPanguEmbeddedAttention(nn.Module):
|
| 593 |
+
|
| 594 |
+
def __init__(
|
| 595 |
+
self,
|
| 596 |
+
config: PretrainedConfig,
|
| 597 |
+
hidden_size: int,
|
| 598 |
+
num_heads: int,
|
| 599 |
+
num_kv_heads: int,
|
| 600 |
+
rope_theta: float = 10000,
|
| 601 |
+
rope_scaling: Optional[dict[str, Any]] = None,
|
| 602 |
+
max_position_embeddings: int = 8192,
|
| 603 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 604 |
+
bias: bool = False,
|
| 605 |
+
bias_o_proj: bool = False,
|
| 606 |
+
cache_config: Optional[CacheConfig] = None,
|
| 607 |
+
prefix: str = "",
|
| 608 |
+
attn_type: str = AttentionType.DECODER,
|
| 609 |
+
) -> None:
|
| 610 |
+
super().__init__()
|
| 611 |
+
layer_idx = extract_layer_index(prefix)
|
| 612 |
+
self.hidden_size = hidden_size
|
| 613 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 614 |
+
self.total_num_heads = num_heads
|
| 615 |
+
if self.total_num_heads % tp_size != 0:
|
| 616 |
+
raise ValueError(f'total_num_heads {total_num_heads} is not divisible by tp_size {tp_size}.')
|
| 617 |
+
self.num_heads = self.total_num_heads // tp_size
|
| 618 |
+
self.total_num_kv_heads = num_kv_heads
|
| 619 |
+
if self.total_num_kv_heads >= tp_size and self.total_num_kv_heads % tp_size != 0:
|
| 620 |
+
# Number of KV heads is greater than TP size, so we partition
|
| 621 |
+
# the KV heads across multiple tensor parallel NPUs.
|
| 622 |
+
raise ValueError(f'Number of KV heads is less than TP size, but total_num_kv_heads {self.total_num_kv_heads} '
|
| 623 |
+
f'is not divisible by tp_size {tp_size}.')
|
| 624 |
+
elif self.total_num_kv_heads < tp_size and tp_size % self.total_num_kv_heads != 0:
|
| 625 |
+
# Number of KV heads is less than TP size, so we replicate
|
| 626 |
+
# the KV heads across multiple tensor parallel NPUs.
|
| 627 |
+
raise ValueError(f'Number of KV heads is less than TP size, but tp_size {tp_size} '
|
| 628 |
+
f'is not divisible by total_num_kv_heads {self.total_num_kv_heads}.')
|
| 629 |
+
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
| 630 |
+
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
|
| 631 |
+
head_dim = getattr(config, "head_dim", None)
|
| 632 |
+
if head_dim is None:
|
| 633 |
+
head_dim = self.hidden_size // self.total_num_heads
|
| 634 |
+
self.head_dim = head_dim
|
| 635 |
+
# Phi models introduced a partial_rotary_factor parameter in the config
|
| 636 |
+
self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
|
| 637 |
+
self.q_size = self.num_heads * self.head_dim
|
| 638 |
+
self.kv_size = self.num_kv_heads * self.head_dim
|
| 639 |
+
self.scaling = self.head_dim**-0.5
|
| 640 |
+
self.rope_theta = rope_theta
|
| 641 |
+
self.max_position_embeddings = max_position_embeddings
|
| 642 |
+
|
| 643 |
+
self.qkv_proj = QKVParallelLinear(
|
| 644 |
+
hidden_size=hidden_size,
|
| 645 |
+
head_size=self.head_dim,
|
| 646 |
+
total_num_heads=self.total_num_heads,
|
| 647 |
+
total_num_kv_heads=self.total_num_kv_heads,
|
| 648 |
+
bias=bias,
|
| 649 |
+
quant_config=quant_config,
|
| 650 |
+
prefix=f"{prefix}.qkv_proj",
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
self.o_proj = RowParallelLinear(
|
| 654 |
+
input_size=self.total_num_heads * self.head_dim,
|
| 655 |
+
output_size=hidden_size,
|
| 656 |
+
bias=bias_o_proj,
|
| 657 |
+
quant_config=quant_config,
|
| 658 |
+
prefix=f"{prefix}.o_proj",
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
self._init_rotary_emb(config,
|
| 662 |
+
rope_scaling=rope_scaling,
|
| 663 |
+
quant_config=quant_config)
|
| 664 |
+
|
| 665 |
+
if hasattr(config, "interleaved_sliding_window"):
|
| 666 |
+
interleaved_sliding_window = config.interleaved_sliding_window
|
| 667 |
+
if isinstance(interleaved_sliding_window, int):
|
| 668 |
+
sliding_window = interleaved_sliding_window
|
| 669 |
+
elif isinstance(interleaved_sliding_window, list):
|
| 670 |
+
sw_idx = layer_idx % len(interleaved_sliding_window)
|
| 671 |
+
sliding_window = interleaved_sliding_window[sw_idx]
|
| 672 |
+
else:
|
| 673 |
+
raise ValueError(
|
| 674 |
+
f"{type(interleaved_sliding_window)} is not supported.")
|
| 675 |
+
else:
|
| 676 |
+
sliding_window = None
|
| 677 |
+
|
| 678 |
+
self.attn = Attention(
|
| 679 |
+
self.num_heads,
|
| 680 |
+
self.head_dim,
|
| 681 |
+
self.scaling,
|
| 682 |
+
num_kv_heads=self.num_kv_heads,
|
| 683 |
+
cache_config=cache_config,
|
| 684 |
+
quant_config=quant_config,
|
| 685 |
+
per_layer_sliding_window=sliding_window,
|
| 686 |
+
attn_type=attn_type,
|
| 687 |
+
prefix=f"{prefix}.attn",
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
def forward(
|
| 691 |
+
self,
|
| 692 |
+
positions: torch.Tensor,
|
| 693 |
+
hidden_states: torch.Tensor,
|
| 694 |
+
kv_cache: Optional[torch.Tensor] = None,
|
| 695 |
+
attn_metadata: Optional[AttentionMetadata] = None
|
| 696 |
+
) -> torch.Tensor:
|
| 697 |
+
qkv, _ = self.qkv_proj(hidden_states)
|
| 698 |
+
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
| 699 |
+
q, k = self.rotary_emb(positions, q, k)
|
| 700 |
+
attn_output = self.attn(q, k, v)
|
| 701 |
+
output, _ = self.o_proj(attn_output)
|
| 702 |
+
return output
|
| 703 |
+
|
| 704 |
+
def _init_rotary_emb(self, config: PretrainedConfig,
|
| 705 |
+
rope_scaling: Optional[dict[str, Any]],
|
| 706 |
+
quant_config: Optional[QuantizationConfig]) -> None:
|
| 707 |
+
is_neox_style = True
|
| 708 |
+
is_gguf = quant_config and quant_config.get_name() == "gguf"
|
| 709 |
+
if is_gguf and config.model_type == "Pangu":
|
| 710 |
+
is_neox_style = False
|
| 711 |
+
|
| 712 |
+
self.rotary_emb = get_rope(
|
| 713 |
+
self.head_dim,
|
| 714 |
+
rotary_dim=self.head_dim,
|
| 715 |
+
max_position=self.max_position_embeddings,
|
| 716 |
+
base=self.rope_theta,
|
| 717 |
+
rope_scaling=rope_scaling,
|
| 718 |
+
is_neox_style=is_neox_style,
|
| 719 |
+
#partial_rotary_factor=self.partial_rotary_factor,
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
class OpenPanguDecoderLayer(nn.Module):
|
| 724 |
+
|
| 725 |
+
def __init__(
|
| 726 |
+
self,
|
| 727 |
+
config: PretrainedConfig,
|
| 728 |
+
prefix: str,
|
| 729 |
+
model_config: ModelConfig,
|
| 730 |
+
cache_config: Optional[CacheConfig] = None,
|
| 731 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 732 |
+
) -> None:
|
| 733 |
+
super().__init__()
|
| 734 |
+
self.hidden_size = config.hidden_size
|
| 735 |
+
rope_theta = getattr(config, "rope_theta", 10000)
|
| 736 |
+
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
| 737 |
+
|
| 738 |
+
layer_idx = int(prefix.split(sep='.')[-1])
|
| 739 |
+
self.layer_idx = layer_idx
|
| 740 |
+
self.layers = config.num_hidden_layers
|
| 741 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 742 |
+
self.tp_rank = get_tp_group().rank_in_group
|
| 743 |
+
ascend_config = get_ascend_config()
|
| 744 |
+
|
| 745 |
+
self.use_mla = hasattr(config, 'attention_qk_dim') and hasattr(config, 'attention_qk_rope_dim') \
|
| 746 |
+
and hasattr(config, 'attention_v_dim') and hasattr(config, 'attention_kv_lora_dim')
|
| 747 |
+
if self.use_mla:
|
| 748 |
+
self.self_attn = OpenPanguMLAAttention(
|
| 749 |
+
config=config,
|
| 750 |
+
hidden_size=self.hidden_size,
|
| 751 |
+
num_heads=config.num_attention_heads,
|
| 752 |
+
attention_qk_dim=config.attention_qk_dim,
|
| 753 |
+
attention_qk_rope_dim=config.attention_qk_rope_dim,
|
| 754 |
+
attention_v_dim=config.attention_v_dim,
|
| 755 |
+
attention_q_lora_dim=config.attention_q_lora_dim
|
| 756 |
+
if hasattr(config, "attention_q_lora_dim") else None,
|
| 757 |
+
attention_kv_lora_dim=config.attention_kv_lora_dim,
|
| 758 |
+
rope_theta=rope_theta,
|
| 759 |
+
max_position_embeddings=max_position_embeddings,
|
| 760 |
+
cache_config=cache_config,
|
| 761 |
+
quant_config=quant_config,
|
| 762 |
+
prefix=f"{prefix}.self_attn",
|
| 763 |
+
)
|
| 764 |
+
else:
|
| 765 |
+
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
| 766 |
+
config, "bias", False)
|
| 767 |
+
bias_o_proj = attention_bias
|
| 768 |
+
if hasattr(config, 'qkv_bias'):
|
| 769 |
+
attention_bias = config.qkv_bias
|
| 770 |
+
# By default, PanguEmbedded uses causal attention as it is a decoder-only model.
|
| 771 |
+
# You can override the HF config with `is_causal=False` to enable
|
| 772 |
+
# bidirectional attention, which is used in some embedding models
|
| 773 |
+
if getattr(config, "is_causal", True):
|
| 774 |
+
attn_type = AttentionType.DECODER
|
| 775 |
+
else:
|
| 776 |
+
attn_type = AttentionType.ENCODER_ONLY
|
| 777 |
+
self.self_attn = OpenPanguEmbeddedAttention(
|
| 778 |
+
config=config,
|
| 779 |
+
hidden_size=self.hidden_size,
|
| 780 |
+
num_heads=config.num_attention_heads,
|
| 781 |
+
num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads),
|
| 782 |
+
rope_theta=rope_theta,
|
| 783 |
+
rope_scaling=getattr(config, "rope_scaling", None),
|
| 784 |
+
max_position_embeddings=max_position_embeddings,
|
| 785 |
+
quant_config=quant_config,
|
| 786 |
+
bias=attention_bias,
|
| 787 |
+
bias_o_proj=bias_o_proj,
|
| 788 |
+
cache_config=cache_config,
|
| 789 |
+
prefix=f"{prefix}.self_attn",
|
| 790 |
+
attn_type=attn_type,
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
if getattr(config, 'num_routed_experts', None) is not None and layer_idx >= config.num_dense_layers:
|
| 794 |
+
self.mlp = OpenPanguMoE(
|
| 795 |
+
config=config,
|
| 796 |
+
quant_config=quant_config,
|
| 797 |
+
prefix=f"{prefix}.mlp",
|
| 798 |
+
)
|
| 799 |
+
self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
|
| 800 |
+
and model_config.use_mla and envs.VLLM_USE_V1 and self.tp_size > 1
|
| 801 |
+
else:
|
| 802 |
+
self.mlp = OpenPanguMLP(
|
| 803 |
+
hidden_size=self.hidden_size,
|
| 804 |
+
intermediate_size=config.intermediate_size,
|
| 805 |
+
hidden_act=config.hidden_act,
|
| 806 |
+
quant_config=quant_config,
|
| 807 |
+
bias=getattr(config, "mlp_bias", False),
|
| 808 |
+
prefix=f"{prefix}.mlp",
|
| 809 |
+
)
|
| 810 |
+
self.mla_moe_communication = False
|
| 811 |
+
self.routed_scaling_factor = getattr(config, 'routed_scaling_factor', None)
|
| 812 |
+
self.num_dense_layers = getattr(config, 'num_dense_layers', None)
|
| 813 |
+
|
| 814 |
+
self.input_layernorm = RMSNorm(config.hidden_size,
|
| 815 |
+
eps=config.rms_norm_eps)
|
| 816 |
+
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
| 817 |
+
eps=config.rms_norm_eps)
|
| 818 |
+
if getattr(config, 'sandwich_norm', False):
|
| 819 |
+
self.sandwich_norm = True
|
| 820 |
+
self.pre_mlp_layernorm = RMSNorm(config.hidden_size,
|
| 821 |
+
eps=config.rms_norm_eps)
|
| 822 |
+
self.post_mlp_layernorm = RMSNorm(config.hidden_size,
|
| 823 |
+
eps=config.rms_norm_eps)
|
| 824 |
+
else:
|
| 825 |
+
self.sandwich_norm = False
|
| 826 |
+
|
| 827 |
+
def forward(
|
| 828 |
+
self,
|
| 829 |
+
positions: torch.Tensor,
|
| 830 |
+
hidden_states: torch.Tensor,
|
| 831 |
+
residual: Optional[torch.Tensor],
|
| 832 |
+
kv_cache: Optional[torch.Tensor] = None,
|
| 833 |
+
attn_metadata: Optional[AttentionMetadata] = None,
|
| 834 |
+
replace_allreduce: bool = False,
|
| 835 |
+
) -> torch.Tensor:
|
| 836 |
+
# Self Attention
|
| 837 |
+
if self.use_mla and attn_metadata is not None and attn_metadata.num_decodes > 0:
|
| 838 |
+
mla_moe_communication = self.mla_moe_communication and replace_allreduce
|
| 839 |
+
else:
|
| 840 |
+
mla_moe_communication = False
|
| 841 |
+
if residual is None:
|
| 842 |
+
residual = hidden_states
|
| 843 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 844 |
+
else:
|
| 845 |
+
previous_hidden_states, previous_residual = hidden_states, residual
|
| 846 |
+
hidden_states, residual = self.input_layernorm(
|
| 847 |
+
hidden_states, residual)
|
| 848 |
+
# Dispose hidden_states and residual from the previous layer
|
| 849 |
+
# to save npu memory because they're no longer used.
|
| 850 |
+
dispose_tensor(previous_hidden_states)
|
| 851 |
+
dispose_tensor(previous_residual)
|
| 852 |
+
if mla_moe_communication and self.layer_idx > self.num_dense_layers:
|
| 853 |
+
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
| 854 |
+
dim=0)
|
| 855 |
+
|
| 856 |
+
hidden_states = self.self_attn(
|
| 857 |
+
positions=positions,
|
| 858 |
+
hidden_states=hidden_states,
|
| 859 |
+
kv_cache=kv_cache,
|
| 860 |
+
attn_metadata=attn_metadata,
|
| 861 |
+
)
|
| 862 |
+
|
| 863 |
+
if mla_moe_communication and residual.shape[0] != hidden_states.shape[0]:
|
| 864 |
+
chunk_hidden_states = torch.tensor_split(residual,
|
| 865 |
+
self.tp_size,
|
| 866 |
+
dim=0)
|
| 867 |
+
residual = chunk_hidden_states[self.tp_rank]
|
| 868 |
+
|
| 869 |
+
if self.routed_scaling_factor is not None and hidden_states.dtype == torch.float16:
|
| 870 |
+
# Fix FP16 overflow
|
| 871 |
+
# We scale both hidden_states and residual before
|
| 872 |
+
# rmsnorm, and rmsnorm result would not affect by scale.
|
| 873 |
+
hidden_states *= 1. / self.routed_scaling_factor
|
| 874 |
+
if self.layer_idx == 0:
|
| 875 |
+
# The residual is shared by all layers, we only scale it on
|
| 876 |
+
# first layer.
|
| 877 |
+
residual *= 1. / self.routed_scaling_factor
|
| 878 |
+
|
| 879 |
+
if self.sandwich_norm:
|
| 880 |
+
hidden_states = self.post_attention_layernorm(
|
| 881 |
+
hidden_states)
|
| 882 |
+
hidden_states, residual = self.pre_mlp_layernorm(
|
| 883 |
+
hidden_states, residual)
|
| 884 |
+
else:
|
| 885 |
+
hidden_states, residual = self.post_attention_layernorm(
|
| 886 |
+
hidden_states, residual)
|
| 887 |
+
|
| 888 |
+
# Fully Connected
|
| 889 |
+
if isinstance(self.mlp, OpenPanguMoE):
|
| 890 |
+
hidden_states = self.mlp(hidden_states,
|
| 891 |
+
attn_metadata,
|
| 892 |
+
replace_allreduce=mla_moe_communication)
|
| 893 |
+
else:
|
| 894 |
+
hidden_states = self.mlp(hidden_states)
|
| 895 |
+
|
| 896 |
+
if self.routed_scaling_factor is not None and isinstance(self.mlp, OpenPanguMLP) \
|
| 897 |
+
and hidden_states.dtype == torch.float16:
|
| 898 |
+
hidden_states *= 1. / self.routed_scaling_factor
|
| 899 |
+
|
| 900 |
+
if self.sandwich_norm:
|
| 901 |
+
hidden_states = self.post_mlp_layernorm(hidden_states)
|
| 902 |
+
|
| 903 |
+
if mla_moe_communication and self.layer_idx == self.layers - 1:
|
| 904 |
+
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
| 905 |
+
dim=0)
|
| 906 |
+
residual = tensor_model_parallel_all_gather(residual, dim=0)
|
| 907 |
+
|
| 908 |
+
return hidden_states, residual
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
@support_torch_compile
|
| 912 |
+
class OpenPanguModel(nn.Module):
|
| 913 |
+
|
| 914 |
+
fall_back_to_pt_during_load = False
|
| 915 |
+
|
| 916 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 917 |
+
super().__init__()
|
| 918 |
+
|
| 919 |
+
config = vllm_config.model_config.hf_config
|
| 920 |
+
model_config = vllm_config.model_config
|
| 921 |
+
cache_config = vllm_config.cache_config
|
| 922 |
+
quant_config = vllm_config.quant_config
|
| 923 |
+
|
| 924 |
+
self.padding_idx = config.pad_token_id
|
| 925 |
+
self.vocab_size = config.vocab_size
|
| 926 |
+
self.tp_size = get_tensor_model_parallel_world_size()
|
| 927 |
+
|
| 928 |
+
self.embed_tokens = VocabParallelEmbedding(
|
| 929 |
+
config.vocab_size,
|
| 930 |
+
config.hidden_size,
|
| 931 |
+
quant_config=quant_config,
|
| 932 |
+
prefix=f"{prefix}.embed_tokens")
|
| 933 |
+
|
| 934 |
+
self.start_layer, self.end_layer, self.layers = make_layers(
|
| 935 |
+
config.num_hidden_layers,
|
| 936 |
+
lambda prefix: OpenPanguDecoderLayer(
|
| 937 |
+
config,
|
| 938 |
+
prefix,
|
| 939 |
+
model_config=model_config,
|
| 940 |
+
cache_config=cache_config,
|
| 941 |
+
quant_config=quant_config,
|
| 942 |
+
),
|
| 943 |
+
prefix=f"{prefix}.layers")
|
| 944 |
+
|
| 945 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 946 |
+
|
| 947 |
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 948 |
+
return self.embed_tokens(input_ids)
|
| 949 |
+
|
| 950 |
+
def forward(
|
| 951 |
+
self,
|
| 952 |
+
input_ids: torch.Tensor,
|
| 953 |
+
positions: torch.Tensor,
|
| 954 |
+
kv_caches: Optional[List[torch.Tensor]] = None,
|
| 955 |
+
attn_metadata: Optional[AttentionMetadata] = None,
|
| 956 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 957 |
+
**kwargs,
|
| 958 |
+
) -> torch.Tensor:
|
| 959 |
+
if inputs_embeds is not None:
|
| 960 |
+
hidden_states = inputs_embeds
|
| 961 |
+
else:
|
| 962 |
+
hidden_states = self.get_input_embeddings(input_ids)
|
| 963 |
+
residual = None
|
| 964 |
+
|
| 965 |
+
replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
|
| 966 |
+
|
| 967 |
+
for i in range(self.start_layer, self.end_layer):
|
| 968 |
+
layer = self.layers[i]
|
| 969 |
+
hidden_states, residual = layer(
|
| 970 |
+
positions,
|
| 971 |
+
hidden_states,
|
| 972 |
+
residual,
|
| 973 |
+
kv_caches[i -
|
| 974 |
+
self.start_layer] if kv_caches is not None else None,
|
| 975 |
+
attn_metadata,
|
| 976 |
+
replace_allreduce=replace_allreduce)
|
| 977 |
+
|
| 978 |
+
hidden_states, _ = self.norm(hidden_states, residual)
|
| 979 |
+
return hidden_states
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
class OpenPanguForCausalLM(nn.Module):
|
| 983 |
+
packed_modules_mapping = {
|
| 984 |
+
"gate_up_proj": ["gate_proj", "up_proj"],
|
| 985 |
+
"experts":
|
| 986 |
+
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
| 987 |
+
}
|
| 988 |
+
|
| 989 |
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
| 990 |
+
super().__init__()
|
| 991 |
+
config = vllm_config.model_config.hf_config
|
| 992 |
+
quant_config = vllm_config.quant_config
|
| 993 |
+
self.config = config
|
| 994 |
+
self.quant_config = quant_config
|
| 995 |
+
self.model = OpenPanguModel(vllm_config=vllm_config,
|
| 996 |
+
prefix=maybe_prefix(prefix, "model"))
|
| 997 |
+
self.lm_head = ParallelLMHead(config.vocab_size,
|
| 998 |
+
config.hidden_size,
|
| 999 |
+
quant_config=quant_config,
|
| 1000 |
+
prefix=maybe_prefix(prefix, "lm_head"))
|
| 1001 |
+
self.logits_processor = LogitsProcessor(config.vocab_size)
|
| 1002 |
+
self.sampler = get_sampler()
|
| 1003 |
+
|
| 1004 |
+
def load_attn_mlp_weight(self,
|
| 1005 |
+
attn_mlp_replace_mapping: List[Tuple[str, str, int]],
|
| 1006 |
+
params_dict: Dict[str, Any],
|
| 1007 |
+
weight_name: str,
|
| 1008 |
+
loaded_weight: torch.Tensor,
|
| 1009 |
+
loaded_params: set[str]) -> bool:
|
| 1010 |
+
for (param_name, origin_name, shard_id) in attn_mlp_replace_mapping:
|
| 1011 |
+
if origin_name not in weight_name or \
|
| 1012 |
+
(("mlp.experts." in weight_name) and weight_name not in params_dict):
|
| 1013 |
+
continue
|
| 1014 |
+
weight_name = weight_name.replace(origin_name, param_name)
|
| 1015 |
+
if weight_name.endswith(".bias") and weight_name not in params_dict:
|
| 1016 |
+
continue
|
| 1017 |
+
param = params_dict[weight_name]
|
| 1018 |
+
weight_loader = param.weight_loader
|
| 1019 |
+
weight_loader(param, loaded_weight, shard_id)
|
| 1020 |
+
loaded_params.add(weight_name)
|
| 1021 |
+
return True
|
| 1022 |
+
return False
|
| 1023 |
+
|
| 1024 |
+
def load_expert_weight(self,
|
| 1025 |
+
expert_merge_mapping: List[Tuple[str, str, int, str]],
|
| 1026 |
+
params_dict: Dict[str, Any],
|
| 1027 |
+
weight_name: str,
|
| 1028 |
+
loaded_weight: torch.Tensor,
|
| 1029 |
+
loaded_params: set[str]) -> bool:
|
| 1030 |
+
for mapping in expert_merge_mapping:
|
| 1031 |
+
param_name, origin_name, expert_id, shard_id = mapping
|
| 1032 |
+
if origin_name not in weight_name:
|
| 1033 |
+
continue
|
| 1034 |
+
weight_name = weight_name.replace(origin_name, param_name)
|
| 1035 |
+
param = params_dict[weight_name]
|
| 1036 |
+
weight_loader = param.weight_loader
|
| 1037 |
+
weight_loader(param,
|
| 1038 |
+
loaded_weight,
|
| 1039 |
+
weight_name,
|
| 1040 |
+
shard_id=shard_id,
|
| 1041 |
+
expert_id=expert_id,
|
| 1042 |
+
return_success=False)
|
| 1043 |
+
loaded_params.add(weight_name)
|
| 1044 |
+
return True
|
| 1045 |
+
return False
|
| 1046 |
+
|
| 1047 |
+
def load_weights(self, weights: Iterable[tuple[str,
|
| 1048 |
+
torch.Tensor]]) -> set[str]:
|
| 1049 |
+
# (param_name, shard_name, shard_id)
|
| 1050 |
+
attn_mlp_replace_mapping = [
|
| 1051 |
+
(".qkv_proj", ".q_proj", "q"),
|
| 1052 |
+
(".qkv_proj", ".k_proj", "k"),
|
| 1053 |
+
(".qkv_proj", ".v_proj", "v"),
|
| 1054 |
+
(".gate_up_proj", ".gate_proj", 0),
|
| 1055 |
+
(".gate_up_proj", ".up_proj", 1),
|
| 1056 |
+
]
|
| 1057 |
+
has_experts = hasattr(self.config, 'num_routed_experts')
|
| 1058 |
+
if has_experts:
|
| 1059 |
+
expert_merge_mapping = AscendFusedMoE.make_expert_params_mapping(
|
| 1060 |
+
ckpt_gate_proj_name="gate_proj",
|
| 1061 |
+
ckpt_down_proj_name="down_proj",
|
| 1062 |
+
ckpt_up_proj_name="up_proj",
|
| 1063 |
+
num_experts=self.config.num_routed_experts)
|
| 1064 |
+
|
| 1065 |
+
params_dict = dict(self.named_parameters())
|
| 1066 |
+
loaded_params: set[str] = set()
|
| 1067 |
+
for name, loaded_weight in weights:
|
| 1068 |
+
if "rotary_emb.inv_freq" in name:
|
| 1069 |
+
continue
|
| 1070 |
+
if 'layers' in name: # skip spec decode layers for main model
|
| 1071 |
+
layer_idx = int(name.split('layers.')[-1].split('.')[0])
|
| 1072 |
+
if layer_idx > self.config.num_hidden_layers:
|
| 1073 |
+
continue
|
| 1074 |
+
|
| 1075 |
+
if 'layers' in name and hasattr(self.config, "num_mtp_layers") \
|
| 1076 |
+
and (self.config.num_mtp_layers > 0):
|
| 1077 |
+
layer_idx = int(name.split('layers.')[-1].split('.')[0])
|
| 1078 |
+
mtp_idx = layer_idx - self.config.num_hidden_layers
|
| 1079 |
+
if mtp_idx >= 0 and mtp_idx < self.config.num_mtp_layers:
|
| 1080 |
+
continue # skip spec decode layers for main model
|
| 1081 |
+
if self.load_attn_mlp_weight(attn_mlp_replace_mapping, params_dict, name, loaded_weight, loaded_params):
|
| 1082 |
+
continue
|
| 1083 |
+
elif has_experts and self.load_expert_weight(expert_merge_mapping, params_dict, name, loaded_weight, loaded_params):
|
| 1084 |
+
continue
|
| 1085 |
+
else:
|
| 1086 |
+
if name.endswith(".bias") and name not in params_dict:
|
| 1087 |
+
continue
|
| 1088 |
+
name = maybe_remap_kv_scale_name(name, params_dict)
|
| 1089 |
+
if name is None:
|
| 1090 |
+
continue
|
| 1091 |
+
param = params_dict[name]
|
| 1092 |
+
weight_loader = getattr(param, "weight_loader",
|
| 1093 |
+
default_weight_loader)
|
| 1094 |
+
weight_loader(param, loaded_weight)
|
| 1095 |
+
loaded_params.add(name)
|
| 1096 |
+
if self.config.tie_word_embeddings:
|
| 1097 |
+
self.lm_head.weight = self.model.embed_tokens.weight
|
| 1098 |
+
return loaded_params
|
| 1099 |
+
|
| 1100 |
+
def forward(
|
| 1101 |
+
self,
|
| 1102 |
+
input_ids: torch.Tensor,
|
| 1103 |
+
positions: torch.Tensor,
|
| 1104 |
+
kv_caches: Optional[List[torch.Tensor]] = None,
|
| 1105 |
+
attn_metadata: Optional[AttentionMetadata] = None,
|
| 1106 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 1107 |
+
**kwargs,
|
| 1108 |
+
) -> torch.Tensor:
|
| 1109 |
+
hidden_states = self.model(input_ids, positions, kv_caches,
|
| 1110 |
+
attn_metadata, inputs_embeds)
|
| 1111 |
+
return hidden_states
|
| 1112 |
+
|
| 1113 |
+
def compute_logits(
|
| 1114 |
+
self,
|
| 1115 |
+
hidden_states: torch.Tensor,
|
| 1116 |
+
sampling_metadata: SamplingMetadata,
|
| 1117 |
+
) -> Optional[torch.Tensor]:
|
| 1118 |
+
logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
|
| 1119 |
+
return logits
|
| 1120 |
+
|
| 1121 |
+
|
| 1122 |
+
class PanguUltraMoEForCausalLM(OpenPanguForCausalLM):
|
| 1123 |
+
pass
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
+
class PanguEmbeddedForCausalLM(OpenPanguForCausalLM):
|
| 1127 |
+
pass
|
inference/vllm_ascend/ops/fused_moe.py
ADDED
|
@@ -0,0 +1,1530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 2 |
+
# Copyright 2023 The vLLM team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# This file is a part of the vllm-ascend project.
|
| 16 |
+
# Adapted from vllm/tests/kernels/test_moe.py
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
import torch_npu
|
| 24 |
+
from torch import nn
|
| 25 |
+
from vllm.config import get_current_vllm_config
|
| 26 |
+
from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
|
| 27 |
+
get_tensor_model_parallel_world_size,
|
| 28 |
+
tensor_model_parallel_all_reduce)
|
| 29 |
+
from vllm.distributed.parallel_state import get_dp_group, get_tp_group
|
| 30 |
+
from vllm.forward_context import get_forward_context
|
| 31 |
+
from vllm.model_executor.layers.fused_moe.config import \
|
| 32 |
+
FusedMoEConfig # isort: skip
|
| 33 |
+
from vllm.model_executor.layers.fused_moe.config import \
|
| 34 |
+
FusedMoEParallelConfig # isort: skip
|
| 35 |
+
from vllm.model_executor.layers.fused_moe.layer import (
|
| 36 |
+
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
| 37 |
+
from vllm.model_executor.layers.quantization.base_config import \
|
| 38 |
+
QuantizationConfig
|
| 39 |
+
|
| 40 |
+
import vllm_ascend.envs as envs_ascend
|
| 41 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 42 |
+
from vllm_ascend.distributed.communication_op import \
|
| 43 |
+
data_parallel_reduce_scatter
|
| 44 |
+
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
|
| 45 |
+
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
| 46 |
+
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
|
| 47 |
+
get_all_reduce_merge_state, get_fused_moe_state,
|
| 48 |
+
get_rm_router_logits_state, is_310p,
|
| 49 |
+
npu_stream_switch, npu_wait_tensor)
|
| 50 |
+
|
| 51 |
+
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
| 52 |
+
SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
| 56 |
+
max_row_per_ep_rank: int, num_tokens: int,
|
| 57 |
+
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
| 58 |
+
original_total_elements = num_tokens * top_k
|
| 59 |
+
device = topk_ids.device
|
| 60 |
+
original_dtype = topk_ids.dtype
|
| 61 |
+
|
| 62 |
+
if original_total_elements == 0:
|
| 63 |
+
output_len = ep_size * max_row_per_ep_rank
|
| 64 |
+
topk_ids_pad = torch.full((output_len, ),
|
| 65 |
+
expert_num,
|
| 66 |
+
dtype=original_dtype,
|
| 67 |
+
device=device)
|
| 68 |
+
unpad_indices = torch.full((original_total_elements, ),
|
| 69 |
+
-1,
|
| 70 |
+
dtype=torch.long,
|
| 71 |
+
device=device)
|
| 72 |
+
return topk_ids_pad, unpad_indices
|
| 73 |
+
|
| 74 |
+
experts_per_ep_rank_val = expert_num // ep_size
|
| 75 |
+
if experts_per_ep_rank_val == 0:
|
| 76 |
+
raise ValueError(
|
| 77 |
+
"expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
|
| 78 |
+
"Ensure expert_num >= ep_size.")
|
| 79 |
+
|
| 80 |
+
assigned_ep_rank = (topk_ids.float() /
|
| 81 |
+
experts_per_ep_rank_val).to(original_dtype)
|
| 82 |
+
indices_arange = torch.arange(topk_ids.shape[0], device=device)
|
| 83 |
+
|
| 84 |
+
is_new_segment = torch.cat(
|
| 85 |
+
(torch.tensor([True], device=device), assigned_ep_rank[1:]
|
| 86 |
+
!= assigned_ep_rank[:-1]))
|
| 87 |
+
temp_start_markers = torch.full_like(indices_arange,
|
| 88 |
+
-1,
|
| 89 |
+
dtype=indices_arange.dtype)
|
| 90 |
+
temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
|
| 91 |
+
start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
|
| 92 |
+
token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
|
| 93 |
+
is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
|
| 94 |
+
cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
|
| 95 |
+
indices_in_rec_cond_list_for_all = cumsum_kept - 1
|
| 96 |
+
unpad_indices = torch.where(
|
| 97 |
+
is_kept_mask, indices_in_rec_cond_list_for_all,
|
| 98 |
+
torch.tensor(-1, device=device, dtype=torch.long))
|
| 99 |
+
output_len = ep_size * max_row_per_ep_rank
|
| 100 |
+
topk_ids_pad = torch.full((output_len, ),
|
| 101 |
+
expert_num,
|
| 102 |
+
dtype=original_dtype,
|
| 103 |
+
device=device)
|
| 104 |
+
if topk_ids.shape[0] > 0:
|
| 105 |
+
all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
|
| 106 |
+
temp_pad_buffer = torch.full((output_len + 1, ),
|
| 107 |
+
expert_num,
|
| 108 |
+
dtype=original_dtype,
|
| 109 |
+
device=device)
|
| 110 |
+
output_len_tensor = torch.tensor(output_len,
|
| 111 |
+
dtype=torch.long,
|
| 112 |
+
device=device)
|
| 113 |
+
scatter_indices = torch.where(is_kept_mask, all_destination_indices,
|
| 114 |
+
output_len_tensor)
|
| 115 |
+
temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
|
| 116 |
+
topk_ids_pad = temp_pad_buffer[:output_len]
|
| 117 |
+
return topk_ids_pad, unpad_indices
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def fused_experts_with_mc2(
|
| 121 |
+
hidden_states: torch.Tensor,
|
| 122 |
+
w1: torch.Tensor,
|
| 123 |
+
w2: torch.Tensor,
|
| 124 |
+
topk_weights: torch.Tensor,
|
| 125 |
+
topk_ids: torch.Tensor,
|
| 126 |
+
top_k: int,
|
| 127 |
+
expert_map: torch.Tensor = None,
|
| 128 |
+
moe_all_to_all_group_name: Optional[str] = None,
|
| 129 |
+
shared_experts: Optional[Any] = None
|
| 130 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 131 |
+
global_bs = 0
|
| 132 |
+
moe_expert_num = len(expert_map)
|
| 133 |
+
kwargs_mc2 = {
|
| 134 |
+
"x": hidden_states,
|
| 135 |
+
"expert_ids": topk_ids,
|
| 136 |
+
"expert_shard_type": 0,
|
| 137 |
+
"shared_expert_rank_num": 0,
|
| 138 |
+
"moe_expert_num": moe_expert_num,
|
| 139 |
+
"global_bs": global_bs,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
rank = torch.distributed.get_rank()
|
| 143 |
+
|
| 144 |
+
quant_mode = 0
|
| 145 |
+
ep_group = get_ep_group().device_group
|
| 146 |
+
local_rank = torch.distributed.get_rank(group=ep_group)
|
| 147 |
+
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
|
| 148 |
+
|
| 149 |
+
tp_size = get_etp_group().world_size
|
| 150 |
+
tp_rank = rank % tp_size
|
| 151 |
+
|
| 152 |
+
stage1_kwargs = {
|
| 153 |
+
"scales": None,
|
| 154 |
+
"quant_mode": quant_mode,
|
| 155 |
+
"group_ep": moe_all_to_all_group_name,
|
| 156 |
+
"ep_world_size": all_to_all_group_size,
|
| 157 |
+
"ep_rank_id": local_rank,
|
| 158 |
+
# "group_tp": self.moe_rs_group_name,
|
| 159 |
+
"group_tp": moe_all_to_all_group_name,
|
| 160 |
+
"tp_world_size": tp_size,
|
| 161 |
+
"tp_rank_id": tp_rank,
|
| 162 |
+
}
|
| 163 |
+
kwargs_mc2.update(stage1_kwargs)
|
| 164 |
+
|
| 165 |
+
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
| 166 |
+
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
| 167 |
+
0:5]
|
| 168 |
+
|
| 169 |
+
if shared_experts is not None:
|
| 170 |
+
with npu_stream_switch("moe_secondary", 0):
|
| 171 |
+
npu_wait_tensor(hidden_states, topk_weights)
|
| 172 |
+
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
| 173 |
+
npu_wait_tensor(shared_gate_up, expand_x)
|
| 174 |
+
shared_act = shared_experts.act_fn(shared_gate_up)
|
| 175 |
+
|
| 176 |
+
w1 = w1.transpose(1, 2)
|
| 177 |
+
|
| 178 |
+
group_list = expert_token_nums.to(torch.int64)
|
| 179 |
+
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
| 180 |
+
x=[expand_x],
|
| 181 |
+
weight=[w1],
|
| 182 |
+
split_item=2,
|
| 183 |
+
# 1 means count mode, to avoid cumulative operation of the group list
|
| 184 |
+
group_list_type=1,
|
| 185 |
+
group_type=0,
|
| 186 |
+
group_list=group_list,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# TODO: Remove this in the future.
|
| 190 |
+
gate_up_out = torch.cat(gate_up_out_list, dim=0)
|
| 191 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 192 |
+
|
| 193 |
+
w2 = w2.transpose(1, 2)
|
| 194 |
+
down_out_list = torch_npu.npu_grouped_matmul(
|
| 195 |
+
x=[gate_up_out],
|
| 196 |
+
weight=[w2],
|
| 197 |
+
split_item=2,
|
| 198 |
+
group_list_type=1,
|
| 199 |
+
group_type=0,
|
| 200 |
+
group_list=group_list,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
down_out_list = torch.cat(down_out_list, dim=0)
|
| 204 |
+
|
| 205 |
+
# moeCombine
|
| 206 |
+
kwargs_mc2 = {
|
| 207 |
+
"expand_x": down_out_list,
|
| 208 |
+
"expert_ids": topk_ids,
|
| 209 |
+
"expand_idx": expand_idx,
|
| 210 |
+
"expert_scales": topk_weights.to(torch.float32),
|
| 211 |
+
"expert_shard_type": 0,
|
| 212 |
+
"shared_expert_rank_num": 0,
|
| 213 |
+
"moe_expert_num": moe_expert_num,
|
| 214 |
+
"global_bs": 0,
|
| 215 |
+
}
|
| 216 |
+
tp_recv_counts = output[5]
|
| 217 |
+
stage3_kwargs = {
|
| 218 |
+
"ep_send_counts": ep_recv_counts,
|
| 219 |
+
"group_ep": moe_all_to_all_group_name,
|
| 220 |
+
"ep_world_size": all_to_all_group_size,
|
| 221 |
+
"ep_rank_id": local_rank,
|
| 222 |
+
"tp_send_counts": tp_recv_counts,
|
| 223 |
+
# "group_tp": self.moe_rs_group_name,
|
| 224 |
+
"group_tp": moe_all_to_all_group_name,
|
| 225 |
+
"tp_world_size": tp_size,
|
| 226 |
+
"tp_rank_id": tp_rank,
|
| 227 |
+
}
|
| 228 |
+
kwargs_mc2.update(stage3_kwargs)
|
| 229 |
+
|
| 230 |
+
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
| 231 |
+
|
| 232 |
+
if shared_experts is None:
|
| 233 |
+
return hidden_states
|
| 234 |
+
else:
|
| 235 |
+
with npu_stream_switch("moe_secondary", 0):
|
| 236 |
+
npu_wait_tensor(shared_act, down_out_list)
|
| 237 |
+
shared_hidden_states, _ = shared_experts.down_proj(shared_act)
|
| 238 |
+
return hidden_states, shared_hidden_states
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
| 242 |
+
w1: torch.Tensor,
|
| 243 |
+
w2: torch.Tensor,
|
| 244 |
+
group_list: torch.Tensor,
|
| 245 |
+
group_list_type: int = 1) -> torch.Tensor:
|
| 246 |
+
"""
|
| 247 |
+
apply MLP: gate_up_proj -> swiglu -> down_proj
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
|
| 251 |
+
w1: expert weights1 with shape
|
| 252 |
+
(num_experts, hidden_size, intermediate_size * 2)
|
| 253 |
+
w2: expert weights2 with shape
|
| 254 |
+
(num_experts, intermediate_size, hidden_size)
|
| 255 |
+
group_list: number of tokens for each expert, follow cumsum mode, and
|
| 256 |
+
with shape (num_experts).
|
| 257 |
+
transpose_weight:
|
| 258 |
+
w1: (num_experts, intermediate_size * 2, hidden_size) ->
|
| 259 |
+
(num_experts, hidden_size, intermediate_size * 2)
|
| 260 |
+
w2: (num_experts, hidden_size, intermediate_size) ->
|
| 261 |
+
(num_experts, intermediate_size, hidden_size)
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
hidden_states: output hidden states after MLP.
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
assert len(hidden_states_wrapper) == 1
|
| 268 |
+
hidden_states = hidden_states_wrapper.pop()
|
| 269 |
+
|
| 270 |
+
w1 = w1.transpose(1, 2)
|
| 271 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 272 |
+
x=[hidden_states],
|
| 273 |
+
weight=[w1],
|
| 274 |
+
split_item=2,
|
| 275 |
+
group_list_type=group_list_type,
|
| 276 |
+
group_type=0,
|
| 277 |
+
group_list=group_list,
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
hidden_states = torch.cat(hidden_states, dim=0)
|
| 281 |
+
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
| 282 |
+
|
| 283 |
+
w2 = w2.transpose(1, 2)
|
| 284 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 285 |
+
x=[hidden_states],
|
| 286 |
+
weight=[w2],
|
| 287 |
+
split_item=2,
|
| 288 |
+
group_list_type=group_list_type,
|
| 289 |
+
group_type=0,
|
| 290 |
+
group_list=group_list,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
hidden_states = torch.cat(hidden_states, dim=0)
|
| 294 |
+
return hidden_states
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def fused_experts_with_all2all(
|
| 298 |
+
hidden_states: torch.Tensor,
|
| 299 |
+
w1: torch.Tensor,
|
| 300 |
+
w2: torch.Tensor,
|
| 301 |
+
topk_weights: torch.Tensor,
|
| 302 |
+
topk_ids: torch.Tensor,
|
| 303 |
+
top_k: int,
|
| 304 |
+
expert_map: torch.Tensor = None,
|
| 305 |
+
ep_group: GroupCoordinator = None,
|
| 306 |
+
):
|
| 307 |
+
original_shape = hidden_states.shape
|
| 308 |
+
if len(original_shape) == 3:
|
| 309 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 310 |
+
|
| 311 |
+
num_tokens, _ = hidden_states.shape
|
| 312 |
+
num_experts = w1.shape[0]
|
| 313 |
+
device = hidden_states.device
|
| 314 |
+
|
| 315 |
+
if expert_map is not None:
|
| 316 |
+
global_num_experts = len(expert_map)
|
| 317 |
+
local_num_experts = global_num_experts // ep_group.world_size
|
| 318 |
+
row_idx_len = num_tokens * top_k
|
| 319 |
+
row_idx = (torch.arange(0,
|
| 320 |
+
row_idx_len,
|
| 321 |
+
dtype=torch.int32,
|
| 322 |
+
device=device).view(top_k, -1).permute(
|
| 323 |
+
1, 0).contiguous())
|
| 324 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 325 |
+
hidden_states,
|
| 326 |
+
row_idx=row_idx,
|
| 327 |
+
expert_idx=topk_ids,
|
| 328 |
+
active_num=num_tokens)
|
| 329 |
+
|
| 330 |
+
global_expert_tokens = torch.bincount(expanded_expert_idx,
|
| 331 |
+
minlength=global_num_experts)
|
| 332 |
+
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
|
| 333 |
+
-1).sum(-1)
|
| 334 |
+
|
| 335 |
+
gather_sizes = torch.empty_like(scatter_sizes)
|
| 336 |
+
dist.all_to_all_single(gather_sizes,
|
| 337 |
+
scatter_sizes,
|
| 338 |
+
group=ep_group.device_group)
|
| 339 |
+
scatter_size_list = scatter_sizes.cpu().tolist()
|
| 340 |
+
gather_size_list = gather_sizes.cpu().tolist()
|
| 341 |
+
|
| 342 |
+
expanded_expert_idx = expanded_expert_idx % local_num_experts
|
| 343 |
+
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
| 344 |
+
scatter_size_list,
|
| 345 |
+
gather_size_list)
|
| 346 |
+
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
|
| 347 |
+
scatter_size_list,
|
| 348 |
+
gather_size_list)
|
| 349 |
+
|
| 350 |
+
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
|
| 351 |
+
|
| 352 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 353 |
+
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
| 354 |
+
|
| 355 |
+
hidden_states = hidden_states[sorted_idx]
|
| 356 |
+
else:
|
| 357 |
+
row_idx_len = num_tokens * top_k
|
| 358 |
+
row_idx = torch.arange(0,
|
| 359 |
+
row_idx_len,
|
| 360 |
+
dtype=torch.int32,
|
| 361 |
+
device=topk_weights.device).view(
|
| 362 |
+
top_k, -1).permute(1, 0).contiguous()
|
| 363 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 364 |
+
hidden_states,
|
| 365 |
+
row_idx=row_idx,
|
| 366 |
+
expert_idx=topk_ids,
|
| 367 |
+
active_num=num_tokens)
|
| 368 |
+
|
| 369 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 370 |
+
expanded_expert_idx, num_experts)
|
| 371 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 372 |
+
|
| 373 |
+
w1 = w1.transpose(1, 2)
|
| 374 |
+
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
| 375 |
+
x=[hidden_states],
|
| 376 |
+
weight=[w1],
|
| 377 |
+
split_item=2,
|
| 378 |
+
group_list_type=0,
|
| 379 |
+
group_type=0,
|
| 380 |
+
group_list=expert_tokens,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
# TODO: Remove this in the future.
|
| 384 |
+
hidden_states = torch.cat(gate_up_out_list, dim=0)
|
| 385 |
+
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
| 386 |
+
|
| 387 |
+
w2 = w2.transpose(1, 2)
|
| 388 |
+
down_out_list = torch_npu.npu_grouped_matmul(
|
| 389 |
+
x=[hidden_states],
|
| 390 |
+
weight=[w2],
|
| 391 |
+
split_item=2,
|
| 392 |
+
group_list_type=0,
|
| 393 |
+
group_type=0,
|
| 394 |
+
group_list=expert_tokens,
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
hidden_states = torch.cat(down_out_list, dim=0)
|
| 398 |
+
|
| 399 |
+
if expert_map is not None:
|
| 400 |
+
resorted_idx = torch.argsort(sorted_idx)
|
| 401 |
+
hidden_states = hidden_states[resorted_idx]
|
| 402 |
+
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
| 403 |
+
gather_size_list,
|
| 404 |
+
scatter_size_list)
|
| 405 |
+
|
| 406 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 407 |
+
hidden_states,
|
| 408 |
+
skip1=None,
|
| 409 |
+
skip2=None,
|
| 410 |
+
bias=None,
|
| 411 |
+
scales=topk_weights,
|
| 412 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 413 |
+
export_for_source_row=topk_ids,
|
| 414 |
+
)
|
| 415 |
+
else:
|
| 416 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 417 |
+
# implementation here when suitable operators become available.
|
| 418 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 419 |
+
hidden_states,
|
| 420 |
+
skip1=None,
|
| 421 |
+
skip2=None,
|
| 422 |
+
bias=None,
|
| 423 |
+
scales=topk_weights,
|
| 424 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 425 |
+
export_for_source_row=topk_ids,
|
| 426 |
+
)
|
| 427 |
+
if len(original_shape) == 3:
|
| 428 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 429 |
+
return final_hidden_states
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
# currently expert parallelism implemented with all2all
|
| 433 |
+
# is under-optimized.
|
| 434 |
+
def fused_experts_with_all2all_buffer(
|
| 435 |
+
hidden_states: torch.Tensor,
|
| 436 |
+
w1: torch.Tensor,
|
| 437 |
+
w2: torch.Tensor,
|
| 438 |
+
topk_weights: torch.Tensor,
|
| 439 |
+
topk_ids: torch.Tensor,
|
| 440 |
+
top_k: int,
|
| 441 |
+
max_model_len: int,
|
| 442 |
+
global_batch_size: int,
|
| 443 |
+
expert_map: torch.Tensor = None,
|
| 444 |
+
ep_group: GroupCoordinator = None,
|
| 445 |
+
):
|
| 446 |
+
original_shape = hidden_states.shape
|
| 447 |
+
if len(original_shape) == 3:
|
| 448 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 449 |
+
|
| 450 |
+
num_tokens, _ = hidden_states.shape
|
| 451 |
+
device = hidden_states.device
|
| 452 |
+
|
| 453 |
+
global_num_experts = len(expert_map)
|
| 454 |
+
local_num_experts = global_num_experts // ep_group.world_size
|
| 455 |
+
row_idx_len = num_tokens * top_k
|
| 456 |
+
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
|
| 457 |
+
device=device).view(top_k,
|
| 458 |
+
-1).permute(1, 0).contiguous())
|
| 459 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 460 |
+
hidden_states,
|
| 461 |
+
row_idx=row_idx,
|
| 462 |
+
expert_idx=topk_ids,
|
| 463 |
+
active_num=num_tokens)
|
| 464 |
+
|
| 465 |
+
max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
|
| 466 |
+
max_model_len // ep_group.world_size +
|
| 467 |
+
1) * top_k * 2
|
| 468 |
+
expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
|
| 469 |
+
expanded_expert_idx, global_num_experts, ep_group.world_size,
|
| 470 |
+
max_row_per_ep_rank, num_tokens, top_k)
|
| 471 |
+
hidden_states_pad_idx = torch.zeros(
|
| 472 |
+
expert_idx_buffer_scatter.shape,
|
| 473 |
+
dtype=expert_idx_buffer_scatter.dtype,
|
| 474 |
+
device=expert_idx_buffer_scatter.device)
|
| 475 |
+
non_pad_len = torch.sum((expert_idx_buffer_scatter
|
| 476 |
+
!= global_num_experts).to(torch.int32))
|
| 477 |
+
hidden_states_pad_idx[expert_idx_buffer_scatter !=
|
| 478 |
+
global_num_experts] = torch.arange(
|
| 479 |
+
non_pad_len,
|
| 480 |
+
dtype=expert_idx_buffer_scatter.dtype,
|
| 481 |
+
device=hidden_states.device)
|
| 482 |
+
|
| 483 |
+
hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
|
| 484 |
+
expert_idx_buffer_gather = torch.empty_like(
|
| 485 |
+
expert_idx_buffer_scatter,
|
| 486 |
+
dtype=expert_idx_buffer_scatter.dtype,
|
| 487 |
+
device=expert_idx_buffer_scatter.device)
|
| 488 |
+
hidden_states_buffer_gather = torch.empty_like(
|
| 489 |
+
hidden_states_buffer_scatter,
|
| 490 |
+
dtype=hidden_states_buffer_scatter.dtype,
|
| 491 |
+
device=hidden_states_buffer_scatter.device)
|
| 492 |
+
dist.all_to_all_single(expert_idx_buffer_gather,
|
| 493 |
+
expert_idx_buffer_scatter,
|
| 494 |
+
group=ep_group.device_group)
|
| 495 |
+
dist.all_to_all_single(hidden_states_buffer_gather,
|
| 496 |
+
hidden_states_buffer_scatter,
|
| 497 |
+
group=ep_group.device_group)
|
| 498 |
+
mask = expert_idx_buffer_gather != global_num_experts
|
| 499 |
+
local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
|
| 500 |
+
global_num_experts // ep_group.world_size)
|
| 501 |
+
hidden_states = hidden_states_buffer_gather[mask]
|
| 502 |
+
idx_type = local_expert_idx.dtype
|
| 503 |
+
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
|
| 504 |
+
sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)
|
| 505 |
+
|
| 506 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 507 |
+
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
| 508 |
+
hidden_states = hidden_states[sorted_idx]
|
| 509 |
+
group_list_type = 0
|
| 510 |
+
|
| 511 |
+
hidden_states_wrapper = [hidden_states]
|
| 512 |
+
del hidden_states
|
| 513 |
+
|
| 514 |
+
hidden_states = apply_mlp(hidden_states_wrapper,
|
| 515 |
+
w1,
|
| 516 |
+
w2,
|
| 517 |
+
expert_tokens,
|
| 518 |
+
group_list_type=group_list_type)
|
| 519 |
+
|
| 520 |
+
resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
|
| 521 |
+
hidden_states = hidden_states[resorted_idx]
|
| 522 |
+
hidden_states_scatter = torch.zeros(
|
| 523 |
+
(mask.shape[0], hidden_states.shape[1]),
|
| 524 |
+
dtype=hidden_states.dtype,
|
| 525 |
+
device=hidden_states.device)
|
| 526 |
+
hidden_states_scatter[mask] = hidden_states
|
| 527 |
+
hidden_states_gatter = torch.empty_like(
|
| 528 |
+
hidden_states_scatter,
|
| 529 |
+
dtype=hidden_states_scatter.dtype,
|
| 530 |
+
device=hidden_states_scatter.device)
|
| 531 |
+
dist.all_to_all_single(hidden_states_gatter,
|
| 532 |
+
hidden_states_scatter,
|
| 533 |
+
group=ep_group.device_group)
|
| 534 |
+
hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter !=
|
| 535 |
+
global_num_experts]
|
| 536 |
+
if hidden_states_gatter.shape[0] != row_idx_len:
|
| 537 |
+
hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
|
| 538 |
+
dtype=hidden_states.dtype,
|
| 539 |
+
device=hidden_states.device)
|
| 540 |
+
hidden_states[unpad_indices != -1] = hidden_states_gatter
|
| 541 |
+
else:
|
| 542 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 543 |
+
hidden_states = hidden_states_gatter
|
| 544 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 545 |
+
hidden_states,
|
| 546 |
+
skip1=None,
|
| 547 |
+
skip2=None,
|
| 548 |
+
bias=None,
|
| 549 |
+
scales=topk_weights,
|
| 550 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 551 |
+
export_for_source_row=topk_ids,
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
if len(original_shape) == 3:
|
| 555 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 556 |
+
return final_hidden_states
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def fused_experts_moge(
|
| 560 |
+
hidden_states: torch.Tensor,
|
| 561 |
+
w1: torch.Tensor,
|
| 562 |
+
w2: torch.Tensor,
|
| 563 |
+
topk_weights: torch.Tensor,
|
| 564 |
+
topk_ids: torch.Tensor,
|
| 565 |
+
top_k: int,
|
| 566 |
+
global_num_experts: int,
|
| 567 |
+
expert_map: torch.Tensor = None,
|
| 568 |
+
apply_router_weight_on_input: bool = False,
|
| 569 |
+
) -> torch.Tensor:
|
| 570 |
+
"""
|
| 571 |
+
|
| 572 |
+
Args:
|
| 573 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 574 |
+
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
| 575 |
+
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
| 576 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 577 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 578 |
+
top_k: Number of experts to select.
|
| 579 |
+
expert_map: Expert mapping of shape (num_experts,).
|
| 580 |
+
|
| 581 |
+
Returns:
|
| 582 |
+
hidden_states: Hidden states after routing.
|
| 583 |
+
"""
|
| 584 |
+
ep_size = get_ep_group().world_size
|
| 585 |
+
local_num_experts = global_num_experts // ep_size
|
| 586 |
+
local_num_group = top_k // ep_size
|
| 587 |
+
|
| 588 |
+
if apply_router_weight_on_input:
|
| 589 |
+
assert (topk_weights.dim() == 2
|
| 590 |
+
), "`topk_weights` should be in shape (num_tokens, topk)"
|
| 591 |
+
_, topk = topk_weights.shape
|
| 592 |
+
assert (
|
| 593 |
+
topk == 1
|
| 594 |
+
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
| 595 |
+
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
| 596 |
+
|
| 597 |
+
bsz, _ = hidden_states.shape
|
| 598 |
+
flatten_topk_ids = topk_ids.view(-1)
|
| 599 |
+
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
| 600 |
+
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
| 601 |
+
sorted_hidden_states = hidden_states.index_select(
|
| 602 |
+
0, sorted_topk_ids // local_num_group)
|
| 603 |
+
|
| 604 |
+
experts_id = torch.arange(0,
|
| 605 |
+
local_num_experts,
|
| 606 |
+
dtype=topk_ids.dtype,
|
| 607 |
+
device=topk_ids.device)
|
| 608 |
+
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
| 609 |
+
torch.float32).sum(0)
|
| 610 |
+
topk_scales = topk_weights.view(-1).index_select(
|
| 611 |
+
0, sorted_topk_ids).unsqueeze(-1)
|
| 612 |
+
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
| 613 |
+
|
| 614 |
+
w1 = w1.transpose(1, 2)
|
| 615 |
+
gate_up_out = torch_npu.npu_grouped_matmul(
|
| 616 |
+
x=[sorted_hidden_states],
|
| 617 |
+
weight=[w1],
|
| 618 |
+
split_item=2,
|
| 619 |
+
group_list_type=0,
|
| 620 |
+
group_type=0,
|
| 621 |
+
group_list=group_list,
|
| 622 |
+
)[0]
|
| 623 |
+
|
| 624 |
+
if is_310p():
|
| 625 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
| 626 |
+
torch.float16)
|
| 627 |
+
else:
|
| 628 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 629 |
+
gate_up_out *= topk_scales
|
| 630 |
+
|
| 631 |
+
w2 = w2.transpose(1, 2)
|
| 632 |
+
down_out_list = torch_npu.npu_grouped_matmul(
|
| 633 |
+
x=[gate_up_out],
|
| 634 |
+
weight=[w2],
|
| 635 |
+
split_item=2,
|
| 636 |
+
group_list_type=0,
|
| 637 |
+
group_type=0,
|
| 638 |
+
group_list=group_list,
|
| 639 |
+
)[0]
|
| 640 |
+
|
| 641 |
+
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
| 642 |
+
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
|
| 643 |
+
final_hidden_states = unsorted_hidden_states.reshape(
|
| 644 |
+
bsz, top_k // ep_size, -1).sum(1)
|
| 645 |
+
|
| 646 |
+
return final_hidden_states
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
def fused_experts(
|
| 650 |
+
hidden_states: torch.Tensor,
|
| 651 |
+
w1: torch.Tensor,
|
| 652 |
+
w2: torch.Tensor,
|
| 653 |
+
topk_weights: torch.Tensor,
|
| 654 |
+
topk_ids: torch.Tensor,
|
| 655 |
+
top_k: int,
|
| 656 |
+
expert_map: torch.Tensor = None,
|
| 657 |
+
apply_router_weight_on_input: bool = False,
|
| 658 |
+
max_num_tokens: Optional[int] = None,
|
| 659 |
+
) -> torch.Tensor:
|
| 660 |
+
"""
|
| 661 |
+
Fused experts with top-k routing.
|
| 662 |
+
|
| 663 |
+
Args:
|
| 664 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 665 |
+
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
| 666 |
+
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
| 667 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 668 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 669 |
+
top_k: Number of experts to select.
|
| 670 |
+
expert_map: Expert mapping of shape (num_experts,).
|
| 671 |
+
|
| 672 |
+
Returns:
|
| 673 |
+
hidden_states: Hidden states after routing.
|
| 674 |
+
"""
|
| 675 |
+
"""
|
| 676 |
+
# Check constraints.
|
| 677 |
+
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
| 678 |
+
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
| 679 |
+
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
| 680 |
+
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
| 681 |
+
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
| 682 |
+
"""
|
| 683 |
+
# if torch.distributed.get_rank() == 0:
|
| 684 |
+
# print(w1.shape)
|
| 685 |
+
# print(hidden_states.shape)
|
| 686 |
+
|
| 687 |
+
original_shape = hidden_states.shape
|
| 688 |
+
# assert len(original_shape) == 2
|
| 689 |
+
|
| 690 |
+
num_tokens = hidden_states.shape[:-1].numel()
|
| 691 |
+
num_experts = w1.shape[0]
|
| 692 |
+
dtype = hidden_states.dtype
|
| 693 |
+
device = hidden_states.device
|
| 694 |
+
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
|
| 695 |
+
# ], "Only float32, float16, and bfloat16 are supported"
|
| 696 |
+
|
| 697 |
+
if apply_router_weight_on_input:
|
| 698 |
+
assert (topk_weights.dim() == 2
|
| 699 |
+
), "`topk_weights` should be in shape (num_tokens, topk)"
|
| 700 |
+
_, topk = topk_weights.shape
|
| 701 |
+
assert (
|
| 702 |
+
topk == 1
|
| 703 |
+
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
| 704 |
+
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
| 705 |
+
|
| 706 |
+
if expert_map is not None:
|
| 707 |
+
# Generate token indices and flatten
|
| 708 |
+
token_indices = (torch.arange(num_tokens,
|
| 709 |
+
device=device,
|
| 710 |
+
dtype=torch.int64).unsqueeze(1).expand(
|
| 711 |
+
-1, top_k).reshape(-1))
|
| 712 |
+
|
| 713 |
+
# Flatten token-to-expert mappings and map to local experts
|
| 714 |
+
weights_flat = topk_weights.view(-1)
|
| 715 |
+
experts_flat = topk_ids.view(-1)
|
| 716 |
+
local_experts_flat = expert_map[experts_flat]
|
| 717 |
+
|
| 718 |
+
# Filter valid token-expert pairs
|
| 719 |
+
mask = local_experts_flat != -1
|
| 720 |
+
filtered_weights = torch.where(
|
| 721 |
+
mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
|
| 722 |
+
filtered_experts = torch.where(
|
| 723 |
+
mask, local_experts_flat,
|
| 724 |
+
torch.full_like(local_experts_flat,
|
| 725 |
+
num_experts)).to(topk_ids.dtype)
|
| 726 |
+
|
| 727 |
+
# Sort by local expert IDs
|
| 728 |
+
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
|
| 729 |
+
sorted_token_indices = token_indices[sort_indices]
|
| 730 |
+
sorted_weights = filtered_weights[sort_indices]
|
| 731 |
+
|
| 732 |
+
# Compute token counts with minlength of num_experts
|
| 733 |
+
# This is equivalent to but faster than:
|
| 734 |
+
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
| 735 |
+
token_counts = torch.zeros(num_experts + 1,
|
| 736 |
+
device=device,
|
| 737 |
+
dtype=torch.int64)
|
| 738 |
+
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
| 739 |
+
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
| 740 |
+
token_counts = token_counts[:num_experts]
|
| 741 |
+
expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)
|
| 742 |
+
|
| 743 |
+
# Rearrange hidden_states
|
| 744 |
+
sorted_hidden_states = hidden_states[sorted_token_indices]
|
| 745 |
+
else:
|
| 746 |
+
row_idx_len = num_tokens * top_k
|
| 747 |
+
row_idx = (torch.arange(0,
|
| 748 |
+
row_idx_len,
|
| 749 |
+
dtype=torch.int32,
|
| 750 |
+
device=device).view(top_k, -1).permute(
|
| 751 |
+
1, 0).contiguous())
|
| 752 |
+
active_num = max_num_tokens if max_num_tokens is not None else num_tokens
|
| 753 |
+
sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 754 |
+
hidden_states,
|
| 755 |
+
row_idx=row_idx,
|
| 756 |
+
expert_idx=topk_ids,
|
| 757 |
+
active_num=active_num)
|
| 758 |
+
|
| 759 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 760 |
+
expanded_expert_idx, num_experts)
|
| 761 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 762 |
+
|
| 763 |
+
w1 = w1.transpose(1, 2)
|
| 764 |
+
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
| 765 |
+
x=[sorted_hidden_states],
|
| 766 |
+
weight=[w1],
|
| 767 |
+
split_item=2,
|
| 768 |
+
group_list_type=0,
|
| 769 |
+
group_type=0,
|
| 770 |
+
group_list=expert_tokens,
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
# TODO: Remove this in the future.
|
| 774 |
+
gate_up_out = torch.cat(gate_up_out_list, dim=0)
|
| 775 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 776 |
+
|
| 777 |
+
w2 = w2.transpose(1, 2)
|
| 778 |
+
down_out_list = torch_npu.npu_grouped_matmul(
|
| 779 |
+
x=[gate_up_out],
|
| 780 |
+
weight=[w2],
|
| 781 |
+
split_item=2,
|
| 782 |
+
group_list_type=0,
|
| 783 |
+
group_type=0,
|
| 784 |
+
group_list=expert_tokens,
|
| 785 |
+
)
|
| 786 |
+
|
| 787 |
+
down_out_list = torch.cat(down_out_list, dim=0)
|
| 788 |
+
|
| 789 |
+
if expert_map is not None:
|
| 790 |
+
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
|
| 791 |
+
|
| 792 |
+
final_hidden_states = torch.zeros(*original_shape,
|
| 793 |
+
device=hidden_states.device,
|
| 794 |
+
dtype=dtype)
|
| 795 |
+
|
| 796 |
+
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
|
| 797 |
+
# This created multiple NaN and index_add_ will mix them up which harms accuracy
|
| 798 |
+
# remove this mask and filter after it being fixed
|
| 799 |
+
num_valid_tokens = mask.sum()
|
| 800 |
+
valid_token_mask = torch.arange(
|
| 801 |
+
0, sorted_token_indices.shape[0],
|
| 802 |
+
device=device).unsqueeze(1) < num_valid_tokens
|
| 803 |
+
valid_output = torch.where(
|
| 804 |
+
valid_token_mask, weighted_down_out,
|
| 805 |
+
torch.zeros_like(weighted_down_out)).to(dtype)
|
| 806 |
+
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
|
| 807 |
+
else:
|
| 808 |
+
scales = torch.ones_like(
|
| 809 |
+
topk_weights) if apply_router_weight_on_input else topk_weights
|
| 810 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 811 |
+
# implementation here when suitable operators become available.
|
| 812 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 813 |
+
down_out_list,
|
| 814 |
+
skip1=None,
|
| 815 |
+
skip2=None,
|
| 816 |
+
bias=None,
|
| 817 |
+
scales=scales,
|
| 818 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 819 |
+
export_for_source_row=topk_ids,
|
| 820 |
+
)
|
| 821 |
+
|
| 822 |
+
return final_hidden_states
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
def fused_experts_allgather_ep(
|
| 826 |
+
hidden_states: torch.Tensor,
|
| 827 |
+
w1: torch.Tensor,
|
| 828 |
+
w2: torch.Tensor,
|
| 829 |
+
topk_weights: torch.Tensor,
|
| 830 |
+
topk_ids: torch.Tensor,
|
| 831 |
+
is_prefill: bool
|
| 832 |
+
):
|
| 833 |
+
local_rank = torch.distributed.get_rank(group=get_ep_group().device_group)
|
| 834 |
+
num_experts_per_ep = w1.shape[0]
|
| 835 |
+
local_expert_indices_offset = local_rank * num_experts_per_ep
|
| 836 |
+
global_local_mask = (topk_ids >= local_expert_indices_offset) & \
|
| 837 |
+
(topk_ids <= local_expert_indices_offset + num_experts_per_ep - 1)
|
| 838 |
+
non_global_local_mask = (~global_local_mask).to(torch.int32)
|
| 839 |
+
global_local_mask = global_local_mask.to(torch.int32)
|
| 840 |
+
row_idx = torch.arange(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32).view(
|
| 841 |
+
-1, topk_ids.shape[0]).transpose(0, 1).contiguous()
|
| 842 |
+
|
| 843 |
+
topk_ids -= local_expert_indices_offset
|
| 844 |
+
local_topk_ids_mask_with_max = topk_ids * global_local_mask + non_global_local_mask * num_experts_per_ep
|
| 845 |
+
sorted_tokens, expanded_src_to_dst_row, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 846 |
+
x=hidden_states,
|
| 847 |
+
row_idx=row_idx,
|
| 848 |
+
expert_idx=local_topk_ids_mask_with_max,
|
| 849 |
+
active_num=topk_ids.shape[0]*topk_ids.shape[1]
|
| 850 |
+
)
|
| 851 |
+
if expanded_expert_idx.shape[0] > 8192:
|
| 852 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(expanded_expert_idx, num_experts_per_ep + 1)
|
| 853 |
+
expert_tokens = expert_tokens[:-1]
|
| 854 |
+
else:
|
| 855 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(expanded_expert_idx, num_experts_per_ep)
|
| 856 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 857 |
+
|
| 858 |
+
w1 = w1.transpose(1, 2)
|
| 859 |
+
gate_up_out = torch_npu.npu_grouped_matmul(
|
| 860 |
+
x=[sorted_tokens],
|
| 861 |
+
weight=[w1],
|
| 862 |
+
group_list=expert_tokens,
|
| 863 |
+
split_item=3,
|
| 864 |
+
group_type=0
|
| 865 |
+
)[0]
|
| 866 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 867 |
+
|
| 868 |
+
w2 = w2.transpose(1, 2)
|
| 869 |
+
down_out = torch_npu.npu_grouped_matmul(
|
| 870 |
+
x=[gate_up_out],
|
| 871 |
+
weight=[w2],
|
| 872 |
+
group_list=expert_tokens,
|
| 873 |
+
split_item=3,
|
| 874 |
+
group_type=0
|
| 875 |
+
)[0]
|
| 876 |
+
|
| 877 |
+
if is_prefill:
|
| 878 |
+
down_out[expert_tokens[-1]:] = 0
|
| 879 |
+
else:
|
| 880 |
+
sorted_tokens_mask = expanded_expert_idx != num_experts_per_ep
|
| 881 |
+
down_out *= sorted_tokens_mask.unsqueeze(1)
|
| 882 |
+
|
| 883 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 884 |
+
expanded_permuted_rows=down_out,
|
| 885 |
+
skip1=None,
|
| 886 |
+
skip2=None,
|
| 887 |
+
bias=None,
|
| 888 |
+
scales=topk_weights.to(down_out.dtype),
|
| 889 |
+
expanded_src_to_dst_row=expanded_src_to_dst_row,
|
| 890 |
+
export_for_source_row=topk_ids
|
| 891 |
+
)
|
| 892 |
+
return final_hidden_states
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
def select_gating_top_k_softmax_experts(
|
| 896 |
+
hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
|
| 897 |
+
renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
|
| 898 |
+
"""
|
| 899 |
+
Select top-k experts based on router logits.
|
| 900 |
+
only supports float16、bfloat16、float32
|
| 901 |
+
|
| 902 |
+
Args:
|
| 903 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 904 |
+
router_logits: Router logits of shape (num_tokens, num_experts).
|
| 905 |
+
top_k: Number of experts to select.
|
| 906 |
+
renormalize: Whether to renormalize the routing weights.
|
| 907 |
+
|
| 908 |
+
Returns:
|
| 909 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 910 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 911 |
+
|
| 912 |
+
Raises:
|
| 913 |
+
ValueError: If an unsupported scoring function is provided.
|
| 914 |
+
"""
|
| 915 |
+
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
|
| 916 |
+
router_logits, None, k=top_k)
|
| 917 |
+
|
| 918 |
+
# # Required by npu_moe_init_routing
|
| 919 |
+
# topk_weights = topk_weights.to(hidden_states.dtype)
|
| 920 |
+
# topk_ids = topk_ids.to(torch.int32)
|
| 921 |
+
|
| 922 |
+
if renormalize:
|
| 923 |
+
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
| 924 |
+
|
| 925 |
+
return topk_weights, topk_ids
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
def native_grouped_topk(
|
| 929 |
+
topk_weights: torch.Tensor,
|
| 930 |
+
num_expert_group: Optional[int],
|
| 931 |
+
topk_group: Optional[int],
|
| 932 |
+
):
|
| 933 |
+
topk_group = 0 if topk_group is None else topk_group
|
| 934 |
+
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
| 935 |
+
|
| 936 |
+
num_token = topk_weights.shape[0]
|
| 937 |
+
grouped_weights = topk_weights.view(num_token, num_expert_group,
|
| 938 |
+
-1).max(dim=-1).values
|
| 939 |
+
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
|
| 940 |
+
k=topk_group,
|
| 941 |
+
dim=-1,
|
| 942 |
+
sorted=False)[1]
|
| 943 |
+
topk_group_mask = torch.zeros_like(grouped_weights)
|
| 944 |
+
topk_group_mask.scatter_(1, topk_group_indices, 1)
|
| 945 |
+
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
|
| 946 |
+
num_token, num_expert_group,
|
| 947 |
+
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
|
| 948 |
+
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
|
| 949 |
+
|
| 950 |
+
return topk_weights
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
def select_experts(
|
| 954 |
+
hidden_states: torch.Tensor,
|
| 955 |
+
router_logits: torch.Tensor,
|
| 956 |
+
top_k: int,
|
| 957 |
+
use_grouped_topk: bool,
|
| 958 |
+
renormalize: bool,
|
| 959 |
+
topk_group: Optional[int] = None,
|
| 960 |
+
num_expert_group: Optional[int] = None,
|
| 961 |
+
custom_routing_function: Optional[Callable] = None,
|
| 962 |
+
scoring_func: str = "softmax",
|
| 963 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 964 |
+
global_num_experts: Optional[torch.Tensor] = None
|
| 965 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 966 |
+
"""
|
| 967 |
+
Select top-k experts based on router logits.
|
| 968 |
+
|
| 969 |
+
Args:
|
| 970 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 971 |
+
router_logits: Router logits of shape (num_tokens, num_experts).
|
| 972 |
+
top_k: Number of experts to select.
|
| 973 |
+
use_grouped_topk: Whether to group experts before selecting top-k.
|
| 974 |
+
renormalize: Whether to renormalize the routing weights.
|
| 975 |
+
topk_group: Number of expert groups to select from.
|
| 976 |
+
num_expert_group: Number of experts in each group.
|
| 977 |
+
custom_routing_function: Custom routing function.
|
| 978 |
+
scoring_func: Scoring function to use.
|
| 979 |
+
e_score_correction_bias: Correction bias to apply to expert scores.
|
| 980 |
+
|
| 981 |
+
Returns:
|
| 982 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 983 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 984 |
+
|
| 985 |
+
Raises:
|
| 986 |
+
ValueError: If an unsupported scoring function is provided.
|
| 987 |
+
"""
|
| 988 |
+
|
| 989 |
+
if scoring_func == "softmax":
|
| 990 |
+
# NOTE: vLLM use dtype=torch.float here
|
| 991 |
+
topk_weights = router_logits.softmax(dim=-1)
|
| 992 |
+
elif scoring_func == "sigmoid":
|
| 993 |
+
topk_weights = router_logits.sigmoid()
|
| 994 |
+
else:
|
| 995 |
+
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
| 996 |
+
|
| 997 |
+
if use_grouped_topk:
|
| 998 |
+
assert topk_group is not None
|
| 999 |
+
assert num_expert_group is not None
|
| 1000 |
+
|
| 1001 |
+
if e_score_correction_bias is not None:
|
| 1002 |
+
# Store original scores before applying correction bias. We use biased
|
| 1003 |
+
# scores for expert selection but original scores for routing weights
|
| 1004 |
+
original_weights = topk_weights
|
| 1005 |
+
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
|
| 1006 |
+
|
| 1007 |
+
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
|
| 1008 |
+
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
| 1009 |
+
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
|
| 1010 |
+
topk_group)
|
| 1011 |
+
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
| 1012 |
+
if e_score_correction_bias is not None:
|
| 1013 |
+
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
| 1014 |
+
k=top_k,
|
| 1015 |
+
dim=-1,
|
| 1016 |
+
sorted=False)[1]
|
| 1017 |
+
# Use original unbiased scores for the routing weights
|
| 1018 |
+
topk_weights = original_weights.gather(1, topk_ids)
|
| 1019 |
+
else:
|
| 1020 |
+
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
| 1021 |
+
k=top_k,
|
| 1022 |
+
dim=-1,
|
| 1023 |
+
sorted=False)
|
| 1024 |
+
elif custom_routing_function is None:
|
| 1025 |
+
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
| 1026 |
+
else:
|
| 1027 |
+
topk_weights, topk_ids = custom_routing_function(
|
| 1028 |
+
hidden_states=hidden_states,
|
| 1029 |
+
gating_output=router_logits,
|
| 1030 |
+
topk=top_k,
|
| 1031 |
+
renormalize=renormalize,
|
| 1032 |
+
global_num_experts=global_num_experts)
|
| 1033 |
+
# Required by npu_moe_init_routing
|
| 1034 |
+
topk_ids = topk_ids.to(torch.int32)
|
| 1035 |
+
return topk_weights, topk_ids
|
| 1036 |
+
|
| 1037 |
+
# Required by npu_moe_init_routing
|
| 1038 |
+
topk_ids = topk_ids.to(torch.int32)
|
| 1039 |
+
|
| 1040 |
+
if renormalize:
|
| 1041 |
+
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
| 1042 |
+
|
| 1043 |
+
return topk_weights, topk_ids
|
| 1044 |
+
|
| 1045 |
+
|
| 1046 |
+
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
| 1047 |
+
|
| 1048 |
+
def __init__(self, moe: FusedMoEConfig = None):
|
| 1049 |
+
|
| 1050 |
+
super().__init__(moe=moe)
|
| 1051 |
+
vllm_config = get_current_vllm_config()
|
| 1052 |
+
|
| 1053 |
+
self.ep_group = get_ep_group()
|
| 1054 |
+
self.ep_size = self.ep_group.world_size
|
| 1055 |
+
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
| 1056 |
+
self.local_batch_size = self.global_batch_size // self.ep_size
|
| 1057 |
+
self.max_model_len = vllm_config.model_config.max_model_len
|
| 1058 |
+
|
| 1059 |
+
ascend_config = get_ascend_config()
|
| 1060 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 1061 |
+
|
| 1062 |
+
try:
|
| 1063 |
+
device_group = self.ep_group.device_group
|
| 1064 |
+
# TODO: Try local_rank = ep_group.rank_in_group
|
| 1065 |
+
local_rank = torch.distributed.get_rank(group=device_group)
|
| 1066 |
+
backend = device_group._get_backend(torch.device("npu"))
|
| 1067 |
+
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
| 1068 |
+
local_rank)
|
| 1069 |
+
except AttributeError:
|
| 1070 |
+
self.moe_all_to_all_group_name = None
|
| 1071 |
+
|
| 1072 |
+
def process_weights_after_loading(self, layer):
|
| 1073 |
+
super(UnquantizedFusedMoEMethod,
|
| 1074 |
+
self).process_weights_after_loading(layer)
|
| 1075 |
+
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
| 1076 |
+
layer.w13_weight.data),
|
| 1077 |
+
requires_grad=False)
|
| 1078 |
+
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
| 1079 |
+
layer.w2_weight.data),
|
| 1080 |
+
requires_grad=False)
|
| 1081 |
+
|
| 1082 |
+
def apply(
|
| 1083 |
+
self,
|
| 1084 |
+
layer: torch.nn.Module,
|
| 1085 |
+
x: torch.Tensor,
|
| 1086 |
+
router_logits: torch.Tensor,
|
| 1087 |
+
top_k: int,
|
| 1088 |
+
renormalize: bool,
|
| 1089 |
+
use_grouped_topk: bool = False,
|
| 1090 |
+
global_num_experts: int = -1,
|
| 1091 |
+
expert_map: Optional[torch.Tensor] = None,
|
| 1092 |
+
topk_group: Optional[int] = None,
|
| 1093 |
+
num_expert_group: Optional[int] = None,
|
| 1094 |
+
custom_routing_function: Optional[Callable] = None,
|
| 1095 |
+
scoring_func: str = "softmax",
|
| 1096 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 1097 |
+
is_prefill: bool = False,
|
| 1098 |
+
enable_force_load_balance: bool = False,
|
| 1099 |
+
shared_experts: Optional[Any] = None,
|
| 1100 |
+
**kwargs,
|
| 1101 |
+
) -> torch.Tensor:
|
| 1102 |
+
use_grouped_topk = (topk_group > 1 or num_expert_group > 1)
|
| 1103 |
+
is_deepseek_v3_r1 = global_num_experts == 256
|
| 1104 |
+
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
| 1105 |
+
if use_grouped_topk and is_deepseek_v3_r1:
|
| 1106 |
+
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
| 1107 |
+
router_logits,
|
| 1108 |
+
k=top_k, # topk当前写8
|
| 1109 |
+
bias=e_score_correction_bias,
|
| 1110 |
+
k_group=topk_group, # fix: 4
|
| 1111 |
+
group_count=num_expert_group, # fix 8
|
| 1112 |
+
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
| 1113 |
+
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
| 1114 |
+
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
| 1115 |
+
# out_flag=False, # todo new api; 第三个输出是否输出
|
| 1116 |
+
# y2_flag=False, # old api; 第三个输出是否输出
|
| 1117 |
+
routed_scaling_factor=1,
|
| 1118 |
+
eps=float(1e-20))
|
| 1119 |
+
elif use_grouped_topk and SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
|
| 1120 |
+
topk_weights, topk_ids = select_gating_top_k_softmax_experts(
|
| 1121 |
+
hidden_states=x,
|
| 1122 |
+
router_logits=router_logits,
|
| 1123 |
+
top_k=top_k,
|
| 1124 |
+
renormalize=renormalize)
|
| 1125 |
+
else:
|
| 1126 |
+
topk_weights, topk_ids = select_experts(
|
| 1127 |
+
hidden_states=x,
|
| 1128 |
+
router_logits=router_logits,
|
| 1129 |
+
top_k=top_k,
|
| 1130 |
+
use_grouped_topk=use_grouped_topk,
|
| 1131 |
+
renormalize=renormalize,
|
| 1132 |
+
topk_group=topk_group,
|
| 1133 |
+
num_expert_group=num_expert_group,
|
| 1134 |
+
custom_routing_function=custom_routing_function,
|
| 1135 |
+
scoring_func=scoring_func,
|
| 1136 |
+
e_score_correction_bias=e_score_correction_bias,
|
| 1137 |
+
)
|
| 1138 |
+
|
| 1139 |
+
topk_weights = topk_weights.to(x.dtype)
|
| 1140 |
+
# this is a naive implementation for experts load balance so as
|
| 1141 |
+
# to avoid accumulating too much tokens on a single rank.
|
| 1142 |
+
# currently it is only activated when doing profile runs.
|
| 1143 |
+
if enable_force_load_balance:
|
| 1144 |
+
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
| 1145 |
+
|
| 1146 |
+
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
|
| 1147 |
+
is_prefill, is_deepseek_v3_r1)
|
| 1148 |
+
if fused_moe_state == FusedMoEState.MC2:
|
| 1149 |
+
return fused_experts_with_mc2(
|
| 1150 |
+
hidden_states=x,
|
| 1151 |
+
w1=layer.w13_weight,
|
| 1152 |
+
w2=layer.w2_weight,
|
| 1153 |
+
topk_weights=topk_weights,
|
| 1154 |
+
topk_ids=topk_ids,
|
| 1155 |
+
top_k=top_k,
|
| 1156 |
+
expert_map=expert_map,
|
| 1157 |
+
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
| 1158 |
+
shared_experts=shared_experts)
|
| 1159 |
+
elif fused_moe_state == FusedMoEState.AllGatherEP:
|
| 1160 |
+
return fused_experts_allgather_ep(
|
| 1161 |
+
hidden_states=x,
|
| 1162 |
+
w1=layer.w13_weight,
|
| 1163 |
+
w2=layer.w2_weight,
|
| 1164 |
+
topk_weights=topk_weights,
|
| 1165 |
+
topk_ids=topk_ids,
|
| 1166 |
+
is_prefill=is_prefill)
|
| 1167 |
+
elif fused_moe_state in [
|
| 1168 |
+
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
| 1169 |
+
]:
|
| 1170 |
+
return fused_experts(hidden_states=x,
|
| 1171 |
+
w1=layer.w13_weight,
|
| 1172 |
+
w2=layer.w2_weight,
|
| 1173 |
+
topk_weights=topk_weights,
|
| 1174 |
+
topk_ids=topk_ids,
|
| 1175 |
+
top_k=top_k,
|
| 1176 |
+
expert_map=expert_map)
|
| 1177 |
+
elif MOE_ALL2ALL_BUFFER:
|
| 1178 |
+
return fused_experts_with_all2all_buffer(
|
| 1179 |
+
hidden_states=x,
|
| 1180 |
+
w1=layer.w13_weight,
|
| 1181 |
+
w2=layer.w2_weight,
|
| 1182 |
+
topk_weights=topk_weights,
|
| 1183 |
+
topk_ids=topk_ids,
|
| 1184 |
+
top_k=top_k,
|
| 1185 |
+
max_model_len=self.max_model_len,
|
| 1186 |
+
global_batch_size=self.global_batch_size,
|
| 1187 |
+
expert_map=expert_map,
|
| 1188 |
+
ep_group=get_ep_group())
|
| 1189 |
+
else:
|
| 1190 |
+
return fused_experts_with_all2all(hidden_states=x,
|
| 1191 |
+
w1=layer.w13_weight,
|
| 1192 |
+
w2=layer.w2_weight,
|
| 1193 |
+
topk_weights=topk_weights,
|
| 1194 |
+
topk_ids=topk_ids,
|
| 1195 |
+
top_k=top_k,
|
| 1196 |
+
expert_map=expert_map,
|
| 1197 |
+
ep_group=get_ep_group())
|
| 1198 |
+
|
| 1199 |
+
|
| 1200 |
+
class AscendFusedMoE(FusedMoE):
|
| 1201 |
+
|
| 1202 |
+
# The moe_counter parameter is required during the initialization of EPLB
|
| 1203 |
+
# to identify the current layer index within the MOE model.
|
| 1204 |
+
moe_counter = -1
|
| 1205 |
+
|
| 1206 |
+
def __init__(
|
| 1207 |
+
self,
|
| 1208 |
+
num_experts: int, # Global number of experts
|
| 1209 |
+
top_k: int,
|
| 1210 |
+
hidden_size: int,
|
| 1211 |
+
intermediate_size: int,
|
| 1212 |
+
params_dtype: Optional[torch.dtype] = None,
|
| 1213 |
+
reduce_results: bool = False,
|
| 1214 |
+
renormalize: bool = True,
|
| 1215 |
+
use_grouped_topk: bool = False,
|
| 1216 |
+
num_expert_group: Optional[int] = None,
|
| 1217 |
+
topk_group: Optional[int] = None,
|
| 1218 |
+
quant_config: Optional[QuantizationConfig] = None,
|
| 1219 |
+
tp_size: Optional[int] = None,
|
| 1220 |
+
ep_size: Optional[int] = None,
|
| 1221 |
+
dp_size: Optional[int] = None,
|
| 1222 |
+
prefix: str = "",
|
| 1223 |
+
custom_routing_function: Optional[Callable] = None,
|
| 1224 |
+
scoring_func: str = "softmax",
|
| 1225 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 1226 |
+
activation: str = "silu",
|
| 1227 |
+
apply_router_weight_on_input: bool = False,
|
| 1228 |
+
):
|
| 1229 |
+
# TODO: This could not initialize FusedMoE baseclass,
|
| 1230 |
+
# fixme and make __init__() of AscendFusedMoE more clear
|
| 1231 |
+
super(FusedMoE, self).__init__()
|
| 1232 |
+
|
| 1233 |
+
AscendFusedMoE.moe_counter += 1
|
| 1234 |
+
self.moe_instance_id = AscendFusedMoE.moe_counter
|
| 1235 |
+
|
| 1236 |
+
if params_dtype is None:
|
| 1237 |
+
params_dtype = torch.get_default_dtype()
|
| 1238 |
+
|
| 1239 |
+
vllm_config = get_current_vllm_config()
|
| 1240 |
+
|
| 1241 |
+
self.moe_parallel_config = FusedMoEParallelConfig.make(
|
| 1242 |
+
tp_size_=(tp_size if tp_size is not None else
|
| 1243 |
+
get_tensor_model_parallel_world_size()),
|
| 1244 |
+
dp_size_=(dp_size
|
| 1245 |
+
if dp_size is not None else get_dp_group().world_size),
|
| 1246 |
+
vllm_parallel_config=vllm_config.parallel_config)
|
| 1247 |
+
|
| 1248 |
+
self.top_k = top_k
|
| 1249 |
+
self.num_experts = num_experts
|
| 1250 |
+
self.global_num_experts = num_experts
|
| 1251 |
+
assert intermediate_size % self.tp_size == 0
|
| 1252 |
+
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
| 1253 |
+
self.reduce_results = reduce_results
|
| 1254 |
+
self.renormalize = renormalize
|
| 1255 |
+
self.use_grouped_topk = use_grouped_topk
|
| 1256 |
+
if self.use_grouped_topk:
|
| 1257 |
+
assert num_expert_group is not None and topk_group is not None
|
| 1258 |
+
self.num_expert_group = num_expert_group
|
| 1259 |
+
self.topk_group = topk_group
|
| 1260 |
+
self.custom_routing_function = custom_routing_function
|
| 1261 |
+
self.scoring_func = scoring_func
|
| 1262 |
+
self.e_score_correction_bias = e_score_correction_bias
|
| 1263 |
+
self.expert_map = None
|
| 1264 |
+
self.activation = activation
|
| 1265 |
+
self.log2phy = None
|
| 1266 |
+
self.global_redundant_expert_num = 0
|
| 1267 |
+
|
| 1268 |
+
is_deepseek_v3_r1 = self.global_num_experts == 256
|
| 1269 |
+
self.rm_router_logits = get_rm_router_logits_state(
|
| 1270 |
+
self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
|
| 1271 |
+
self.all_reduce_merge = get_all_reduce_merge_state(
|
| 1272 |
+
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
|
| 1273 |
+
|
| 1274 |
+
ascend_config = get_ascend_config()
|
| 1275 |
+
expert_map_path = ascend_config.expert_map_path
|
| 1276 |
+
if expert_map_path and os.path.exists(expert_map_path):
|
| 1277 |
+
# moe expert load balance
|
| 1278 |
+
expert_load_balancer = ExpertLoadBalancer(expert_map_path,
|
| 1279 |
+
self.global_num_experts)
|
| 1280 |
+
self.local_num_experts, self.expert_map = \
|
| 1281 |
+
expert_load_balancer.get_rank_placement_map(
|
| 1282 |
+
self.moe_instance_id,
|
| 1283 |
+
get_ep_group().rank_in_group)
|
| 1284 |
+
self.log2phy = expert_load_balancer.get_rank_log2phy_map(
|
| 1285 |
+
self.moe_instance_id,
|
| 1286 |
+
get_ep_group().rank_in_group)
|
| 1287 |
+
self.global_redundant_expert_num = \
|
| 1288 |
+
expert_load_balancer.get_global_redundant_expert_num()
|
| 1289 |
+
else:
|
| 1290 |
+
# Create a tensor of size num_experts filled with -1
|
| 1291 |
+
self.local_num_experts, self.expert_map = determine_expert_map(
|
| 1292 |
+
self.ep_size,
|
| 1293 |
+
get_ep_group().rank_in_group, self.global_num_experts)
|
| 1294 |
+
|
| 1295 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 1296 |
+
self.enable_multistream_moe = \
|
| 1297 |
+
ascend_config.torchair_graph_config.enable_multistream_moe
|
| 1298 |
+
|
| 1299 |
+
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
| 1300 |
+
raise ValueError("Only softmax scoring function is supported for "
|
| 1301 |
+
"non-grouped topk.")
|
| 1302 |
+
moe = FusedMoEConfig.make(
|
| 1303 |
+
num_experts=self.global_num_experts,
|
| 1304 |
+
experts_per_token=top_k,
|
| 1305 |
+
hidden_dim=hidden_size,
|
| 1306 |
+
num_local_experts=self.local_num_experts,
|
| 1307 |
+
moe_parallel_config=self.moe_parallel_config,
|
| 1308 |
+
# TODO (bnell): this needs to be fixed for quantized types.
|
| 1309 |
+
in_dtype=params_dtype,
|
| 1310 |
+
quant_config=quant_config)
|
| 1311 |
+
|
| 1312 |
+
if quant_config is None:
|
| 1313 |
+
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
|
| 1314 |
+
else:
|
| 1315 |
+
self.quant_method = quant_config.get_quant_method(self, prefix)
|
| 1316 |
+
|
| 1317 |
+
assert self.quant_method is not None
|
| 1318 |
+
|
| 1319 |
+
local_num_experts = torch.sum(self.expert_map != -1) \
|
| 1320 |
+
if self.expert_map is not None else num_experts
|
| 1321 |
+
|
| 1322 |
+
moe_quant_params = {
|
| 1323 |
+
"num_experts": local_num_experts,
|
| 1324 |
+
"hidden_size": hidden_size,
|
| 1325 |
+
"intermediate_size_per_partition":
|
| 1326 |
+
self.intermediate_size_per_partition,
|
| 1327 |
+
"params_dtype": params_dtype,
|
| 1328 |
+
"weight_loader": self.weight_loader,
|
| 1329 |
+
}
|
| 1330 |
+
# need full intermediate size pre-sharding for WNA16 act order
|
| 1331 |
+
if (self.quant_method.__class__.__name__
|
| 1332 |
+
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
| 1333 |
+
moe_quant_params["intermediate_size_full"] = intermediate_size
|
| 1334 |
+
|
| 1335 |
+
self.ep_group = get_ep_group()
|
| 1336 |
+
# NOTE: self.tp_group is not expert_tp_group
|
| 1337 |
+
self.tp_group = get_tp_group().device_group
|
| 1338 |
+
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
| 1339 |
+
|
| 1340 |
+
def naive_multicast(self, x: torch.Tensor,
|
| 1341 |
+
cu_tokens_across_dp_cpu: torch.Tensor):
|
| 1342 |
+
assert (len(x.shape) == 2)
|
| 1343 |
+
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
| 1344 |
+
device=x.device,
|
| 1345 |
+
dtype=x.dtype)
|
| 1346 |
+
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
| 1347 |
+
self.dp_rank - 1]
|
| 1348 |
+
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
| 1349 |
+
buffer[start:end, :].copy_(x)
|
| 1350 |
+
for idx in range(self.dp_size):
|
| 1351 |
+
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
| 1352 |
+
end = cu_tokens_across_dp_cpu[idx]
|
| 1353 |
+
get_dp_group().broadcast(buffer[start:end, :], idx)
|
| 1354 |
+
return buffer
|
| 1355 |
+
|
| 1356 |
+
def forward(self,
|
| 1357 |
+
hidden_states: torch.Tensor,
|
| 1358 |
+
router_logits: torch.Tensor,
|
| 1359 |
+
is_prefill: bool,
|
| 1360 |
+
enable_force_load_balance: bool = False,
|
| 1361 |
+
top_k: Optional[int] = None,
|
| 1362 |
+
shared_experts: Optional[Any] = None,
|
| 1363 |
+
gate=None,
|
| 1364 |
+
replace_allreduce: bool = False):
|
| 1365 |
+
|
| 1366 |
+
assert self.quant_method is not None
|
| 1367 |
+
|
| 1368 |
+
if top_k:
|
| 1369 |
+
real_top_k = top_k
|
| 1370 |
+
else:
|
| 1371 |
+
real_top_k = self.top_k
|
| 1372 |
+
|
| 1373 |
+
num_tokens, hidden_size = hidden_states.shape
|
| 1374 |
+
is_deepseek_v3_r1 = self.global_num_experts == 256
|
| 1375 |
+
|
| 1376 |
+
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
|
| 1377 |
+
is_prefill, is_deepseek_v3_r1)
|
| 1378 |
+
if shared_experts:
|
| 1379 |
+
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
|
| 1380 |
+
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
| 1381 |
+
shared_hidden_states = shared_experts(hidden_states)
|
| 1382 |
+
|
| 1383 |
+
tp_size = get_tensor_model_parallel_world_size()
|
| 1384 |
+
if (tp_size > 1 and fused_moe_state not in [
|
| 1385 |
+
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
| 1386 |
+
FusedMoEState.NaiveMulticast
|
| 1387 |
+
] and not replace_allreduce):
|
| 1388 |
+
if num_tokens < tp_size:
|
| 1389 |
+
hidden_states = nn.functional.pad(
|
| 1390 |
+
hidden_states, (0, 0, 0, tp_size - num_tokens))
|
| 1391 |
+
router_logits = nn.functional.pad(
|
| 1392 |
+
router_logits, (0, 0, 0, tp_size - num_tokens))
|
| 1393 |
+
chunk_hidden_states = torch.tensor_split(hidden_states,
|
| 1394 |
+
tp_size,
|
| 1395 |
+
dim=0)
|
| 1396 |
+
chunk_router_logits = torch.tensor_split(router_logits,
|
| 1397 |
+
tp_size,
|
| 1398 |
+
dim=0)
|
| 1399 |
+
tp_rank = get_tensor_model_parallel_rank()
|
| 1400 |
+
hidden_states = chunk_hidden_states[tp_rank]
|
| 1401 |
+
router_logits = chunk_router_logits[tp_rank]
|
| 1402 |
+
|
| 1403 |
+
if self.dp_size > 1:
|
| 1404 |
+
if fused_moe_state in (FusedMoEState.AllGather, FusedMoEState.AllGatherEP):
|
| 1405 |
+
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
| 1406 |
+
if not self.torchair_graph_enabled or is_prefill:
|
| 1407 |
+
attn_metadata = get_forward_context().attn_metadata
|
| 1408 |
+
if attn_metadata is not None:
|
| 1409 |
+
max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
|
| 1410 |
+
if num_tokens < max_num_tokens_across_dp:
|
| 1411 |
+
hidden_states = nn.functional.pad(
|
| 1412 |
+
hidden_states,
|
| 1413 |
+
(0, 0, 0,
|
| 1414 |
+
max_num_tokens_across_dp - num_tokens))
|
| 1415 |
+
if not self.rm_router_logits:
|
| 1416 |
+
router_logits = nn.functional.pad(
|
| 1417 |
+
router_logits,
|
| 1418 |
+
(0, 0, 0,
|
| 1419 |
+
max_num_tokens_across_dp - num_tokens))
|
| 1420 |
+
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
| 1421 |
+
if self.rm_router_logits:
|
| 1422 |
+
router_logits, _ = gate(hidden_states.float())
|
| 1423 |
+
else:
|
| 1424 |
+
router_logits = get_dp_group().all_gather(router_logits, 0)
|
| 1425 |
+
|
| 1426 |
+
elif fused_moe_state == FusedMoEState.NaiveMulticast:
|
| 1427 |
+
cu_tokens_across_dp_cpu = get_forward_context(
|
| 1428 |
+
).dp_metadata.cu_tokens_across_dp_cpu
|
| 1429 |
+
hidden_states = self.naive_multicast(hidden_states,
|
| 1430 |
+
cu_tokens_across_dp_cpu)
|
| 1431 |
+
if self.rm_router_logits:
|
| 1432 |
+
router_logits, _ = gate(hidden_states.float())
|
| 1433 |
+
else:
|
| 1434 |
+
router_logits = self.naive_multicast(
|
| 1435 |
+
router_logits, cu_tokens_across_dp_cpu)
|
| 1436 |
+
|
| 1437 |
+
# Matrix multiply.
|
| 1438 |
+
e_hidden_states = self.quant_method.apply(
|
| 1439 |
+
layer=self,
|
| 1440 |
+
x=hidden_states,
|
| 1441 |
+
router_logits=router_logits,
|
| 1442 |
+
top_k=real_top_k,
|
| 1443 |
+
renormalize=self.renormalize,
|
| 1444 |
+
use_grouped_topk=self.use_grouped_topk,
|
| 1445 |
+
global_num_experts=self.global_num_experts,
|
| 1446 |
+
expert_map=self.expert_map,
|
| 1447 |
+
topk_group=self.topk_group,
|
| 1448 |
+
num_expert_group=self.num_expert_group,
|
| 1449 |
+
custom_routing_function=self.custom_routing_function,
|
| 1450 |
+
scoring_func=self.scoring_func,
|
| 1451 |
+
e_score_correction_bias=self.e_score_correction_bias,
|
| 1452 |
+
is_prefill=is_prefill,
|
| 1453 |
+
enable_force_load_balance=enable_force_load_balance,
|
| 1454 |
+
log2phy=self.log2phy,
|
| 1455 |
+
global_redundant_expert_num=self.global_redundant_expert_num,
|
| 1456 |
+
shared_experts=shared_experts if self.torchair_graph_enabled
|
| 1457 |
+
and self.enable_multistream_moe and not is_prefill else None,
|
| 1458 |
+
)
|
| 1459 |
+
|
| 1460 |
+
if shared_experts:
|
| 1461 |
+
if isinstance(e_hidden_states, tuple):
|
| 1462 |
+
e_hidden_states, shared_hidden_states = e_hidden_states
|
| 1463 |
+
|
| 1464 |
+
if (tp_size > 1 and fused_moe_state not in [
|
| 1465 |
+
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
| 1466 |
+
FusedMoEState.NaiveMulticast
|
| 1467 |
+
] and not replace_allreduce):
|
| 1468 |
+
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
| 1469 |
+
self.tp_group)
|
| 1470 |
+
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
| 1471 |
+
if num_tokens < tp_size:
|
| 1472 |
+
final_hidden_states = final_hidden_states[:num_tokens]
|
| 1473 |
+
dispose_tensor(e_hidden_states)
|
| 1474 |
+
elif self.dp_size > 1:
|
| 1475 |
+
if fused_moe_state == FusedMoEState.NaiveMulticast:
|
| 1476 |
+
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
| 1477 |
+
self.dp_rank - 1]
|
| 1478 |
+
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
| 1479 |
+
final_hidden_states = get_dp_group().all_reduce(
|
| 1480 |
+
e_hidden_states)
|
| 1481 |
+
final_hidden_states = final_hidden_states[start:end, :]
|
| 1482 |
+
dispose_tensor(e_hidden_states)
|
| 1483 |
+
elif fused_moe_state in (FusedMoEState.AllGather, FusedMoEState.AllGatherEP):
|
| 1484 |
+
final_hidden_states = data_parallel_reduce_scatter(
|
| 1485 |
+
e_hidden_states, dim=0)
|
| 1486 |
+
final_hidden_states = final_hidden_states[:num_tokens]
|
| 1487 |
+
dispose_tensor(e_hidden_states)
|
| 1488 |
+
else:
|
| 1489 |
+
final_hidden_states = e_hidden_states
|
| 1490 |
+
|
| 1491 |
+
if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [
|
| 1492 |
+
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
| 1493 |
+
FusedMoEState.NaiveMulticast
|
| 1494 |
+
]:
|
| 1495 |
+
final_hidden_states = tensor_model_parallel_all_reduce(
|
| 1496 |
+
final_hidden_states)
|
| 1497 |
+
|
| 1498 |
+
if shared_experts:
|
| 1499 |
+
return final_hidden_states, shared_hidden_states
|
| 1500 |
+
else:
|
| 1501 |
+
return final_hidden_states
|
| 1502 |
+
|
| 1503 |
+
# ----------------------------------------- TBO-related --------------------------------------------
|
| 1504 |
+
|
| 1505 |
+
def _forward_ms_fused_moe_comp(
|
| 1506 |
+
self,
|
| 1507 |
+
hidden_states: torch.Tensor,
|
| 1508 |
+
router_logits: torch.Tensor,
|
| 1509 |
+
is_prefill: bool,
|
| 1510 |
+
real_top_k,
|
| 1511 |
+
enable_force_load_balance: bool = False,
|
| 1512 |
+
):
|
| 1513 |
+
hidden_states = self.quant_method.apply(
|
| 1514 |
+
layer=self,
|
| 1515 |
+
x=hidden_states,
|
| 1516 |
+
router_logits=router_logits,
|
| 1517 |
+
top_k=real_top_k,
|
| 1518 |
+
renormalize=self.renormalize,
|
| 1519 |
+
use_grouped_topk=self.use_grouped_topk,
|
| 1520 |
+
global_num_experts=self.global_num_experts,
|
| 1521 |
+
expert_map=self.expert_map,
|
| 1522 |
+
topk_group=self.topk_group,
|
| 1523 |
+
num_expert_group=self.num_expert_group,
|
| 1524 |
+
custom_routing_function=self.custom_routing_function,
|
| 1525 |
+
scoring_func=self.scoring_func,
|
| 1526 |
+
e_score_correction_bias=self.e_score_correction_bias,
|
| 1527 |
+
is_prefill=is_prefill,
|
| 1528 |
+
enable_force_load_balance=enable_force_load_balance)
|
| 1529 |
+
|
| 1530 |
+
return hidden_states
|
inference/vllm_ascend/patch/worker/patch_common/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
# patch_utils should be the first import, because it will be used by other
|
| 19 |
+
# patch files.
|
| 20 |
+
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
|
| 21 |
+
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
| 22 |
+
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
| 23 |
+
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
|
| 24 |
+
import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
|
| 25 |
+
import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa
|
| 26 |
+
import vllm_ascend.patch.worker.patch_common.patch_config # noqa
|
| 27 |
+
import vllm_ascend.patch.worker.patch_common.patch_parsers # noqa
|
inference/vllm_ascend/patch/worker/patch_common/patch_config.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
from vllm.config import ModelConfig
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_attr_by_names(src_config, attrs, default_value):
|
| 21 |
+
for attr in attrs:
|
| 22 |
+
value = getattr(src_config, attr, 0)
|
| 23 |
+
if value > 0:
|
| 24 |
+
return value
|
| 25 |
+
return default_value
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _verify_with_expert_parallelism(self) -> None:
|
| 29 |
+
num_expert_names = [
|
| 30 |
+
"moe_num_experts", # Dbrx
|
| 31 |
+
"num_experts", # Jamba
|
| 32 |
+
"n_routed_experts", # DeepSeek
|
| 33 |
+
"num_local_experts", # Mixtral
|
| 34 |
+
"num_routed_experts", # Pangu
|
| 35 |
+
]
|
| 36 |
+
num_experts = 0
|
| 37 |
+
for name in num_expert_names:
|
| 38 |
+
num_experts = getattr(self.hf_text_config, name, 0)
|
| 39 |
+
if num_experts > 0:
|
| 40 |
+
break
|
| 41 |
+
if num_experts < 1:
|
| 42 |
+
raise ValueError(
|
| 43 |
+
"Number of experts in the model must be greater than 0 "
|
| 44 |
+
"when expert parallelism is enabled.")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@property
|
| 48 |
+
def is_deepseek_mla(self) -> bool:
|
| 49 |
+
kv_lora_dim_names = ['attention_kv_lora_dim', 'kv_lora_rank']
|
| 50 |
+
kv_lora_dim = get_attr_by_names(self.hf_text_config, kv_lora_dim_names, None)
|
| 51 |
+
if not hasattr(self.hf_text_config, "model_type"):
|
| 52 |
+
return False
|
| 53 |
+
elif self.hf_text_config.model_type in \
|
| 54 |
+
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'pangu_ultra_moe'):
|
| 55 |
+
return kv_lora_dim is not None
|
| 56 |
+
elif self.hf_text_config.model_type == 'eagle':
|
| 57 |
+
# if the model is an EAGLE module, check for the
|
| 58 |
+
# underlying architecture
|
| 59 |
+
return self.hf_text_config.model.model_type in \
|
| 60 |
+
('deepseek_v2', 'deepseek_v3', 'pangu_ultra_moe') \
|
| 61 |
+
and kv_lora_dim is not None
|
| 62 |
+
return False
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_head_size(self) -> int:
|
| 66 |
+
if self.is_deepseek_mla:
|
| 67 |
+
qk_rope_dim_names = ['attention_qk_rope_dim', 'qk_rope_head_dim']
|
| 68 |
+
kv_lora_dim_names = ['attention_kv_lora_dim', 'kv_lora_rank']
|
| 69 |
+
qk_rope_dim = get_attr_by_names(self.hf_text_config, qk_rope_dim_names, 0)
|
| 70 |
+
kv_lora_dim = get_attr_by_names(self.hf_text_config, kv_lora_dim_names, 0)
|
| 71 |
+
if self.use_mla:
|
| 72 |
+
return kv_lora_dim + qk_rope_dim
|
| 73 |
+
else:
|
| 74 |
+
qk_dim_names = ['attention_qk_dim', 'qk_nope_head_dim']
|
| 75 |
+
qk_dim = get_attr_by_names(self.hf_text_config, qk_dim_names, 0)
|
| 76 |
+
if qk_rope_dim and qk_dim:
|
| 77 |
+
return qk_rope_dim + qk_dim
|
| 78 |
+
if hasattr(self.hf_text_config,
|
| 79 |
+
"model_type") and (self.hf_text_config.model_type
|
| 80 |
+
== "zamba2"):
|
| 81 |
+
return self.hf_text_config.attention_head_dim
|
| 82 |
+
|
| 83 |
+
if self.is_attention_free:
|
| 84 |
+
return 0
|
| 85 |
+
|
| 86 |
+
# NOTE: Some configs may set head_dim=None in the config
|
| 87 |
+
if getattr(self.hf_text_config, "head_dim", None) is not None:
|
| 88 |
+
return self.hf_text_config.head_dim
|
| 89 |
+
|
| 90 |
+
# FIXME(woosuk): This may not be true for all models.
|
| 91 |
+
return (self.hf_text_config.hidden_size //
|
| 92 |
+
self.hf_text_config.num_attention_heads)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
ModelConfig._verify_with_expert_parallelism = _verify_with_expert_parallelism
|
| 96 |
+
ModelConfig.is_deepseek_mla = is_deepseek_mla
|
| 97 |
+
ModelConfig.get_head_size = get_head_size
|
inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from vllm.entrypoints.openai import tool_parsers
|
| 20 |
+
from vllm_ascend.entrypoints.openai.tool_parsers import PanguToolParser
|
| 21 |
+
tool_parsers.__all__.append("PanguToolParser")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
from vllm import reasoning
|
| 25 |
+
from vllm_ascend.entrypoints.openai.reasoning_parsers import PanguReasoningParser
|
| 26 |
+
reasoning.__all__.append("PanguReasoningParser")
|
inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
# This file is a part of the vllm-ascend project.
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
#
|
| 18 |
+
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch_npu
|
| 23 |
+
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
|
| 24 |
+
from vllm.v1.sample.sampler import Sampler, _SAMPLING_EPS
|
| 25 |
+
from vllm.v1.sample.metadata import SamplingMetadata
|
| 26 |
+
from vllm_ascend import envs
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def apply_top_k_top_p(
|
| 30 |
+
logits: torch.Tensor,
|
| 31 |
+
k: torch.Tensor,
|
| 32 |
+
p: torch.Tensor,
|
| 33 |
+
) -> torch.Tensor:
|
| 34 |
+
if p is not None and k is not None:
|
| 35 |
+
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
|
| 36 |
+
return torch_npu.npu_top_k_top_p(logits, p, k)
|
| 37 |
+
|
| 38 |
+
probs = logits.softmax(dim=-1)
|
| 39 |
+
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
| 40 |
+
|
| 41 |
+
if k is not None:
|
| 42 |
+
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
|
| 43 |
+
top_k_count = top_k_count.unsqueeze(dim=1)
|
| 44 |
+
top_k_cutoff = probs_sort.gather(-1, top_k_count)
|
| 45 |
+
|
| 46 |
+
# Make sure the no top-k rows are no-op.
|
| 47 |
+
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
|
| 48 |
+
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
|
| 49 |
+
|
| 50 |
+
elements_to_discard = probs < top_k_cutoff
|
| 51 |
+
logits.masked_fill_(elements_to_discard, -float("inf"))
|
| 52 |
+
|
| 53 |
+
if p is not None:
|
| 54 |
+
cumprob = torch.cumsum(probs_sort, dim=-1)
|
| 55 |
+
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
| 56 |
+
top_p_mask[:, -1] = False # at least one
|
| 57 |
+
|
| 58 |
+
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
|
| 59 |
+
top_p_cutoff = probs_sort.gather(-1, top_p_count)
|
| 60 |
+
elements_to_discard = probs < top_p_cutoff
|
| 61 |
+
logits.masked_fill_(elements_to_discard, -float("inf"))
|
| 62 |
+
|
| 63 |
+
return logits
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def topk_topp_forward_native(
|
| 67 |
+
self,
|
| 68 |
+
logits: torch.Tensor,
|
| 69 |
+
generators: dict[int, torch.Generator],
|
| 70 |
+
k: Optional[torch.Tensor],
|
| 71 |
+
p: Optional[torch.Tensor],
|
| 72 |
+
) -> torch.Tensor:
|
| 73 |
+
"""
|
| 74 |
+
PyTorch-native implementation of top-k and top-p sampling.
|
| 75 |
+
|
| 76 |
+
The logits tensor may be updated in-place.
|
| 77 |
+
"""
|
| 78 |
+
logits = apply_top_k_top_p(logits, k, p)
|
| 79 |
+
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
| 80 |
+
return random_sample(probs, generators)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def apply_top_n_sigma(
|
| 84 |
+
logits: torch.Tensor,
|
| 85 |
+
sampling_metadata: SamplingMetadata,
|
| 86 |
+
):
|
| 87 |
+
if sampling_metadata.no_top_n_sigma:
|
| 88 |
+
return logits
|
| 89 |
+
|
| 90 |
+
top_n_sigma = sampling_metadata.top_n_sigma[:, None]
|
| 91 |
+
top_n_sigma_mask = (top_n_sigma != -1)
|
| 92 |
+
filter_value = -3.4028e+38
|
| 93 |
+
max_vals, _ = logits.max(dim=-1, keepdim=True)
|
| 94 |
+
std_vals = logits.std(dim=-1, keepdim=True)
|
| 95 |
+
threshold = max_vals - top_n_sigma * std_vals
|
| 96 |
+
threshold[~top_n_sigma_mask] = filter_value
|
| 97 |
+
mask = (logits < threshold)
|
| 98 |
+
logits = torch.where(mask, filter_value, logits)
|
| 99 |
+
return logits
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def sample(
|
| 103 |
+
self,
|
| 104 |
+
logits: torch.Tensor,
|
| 105 |
+
sampling_metadata: SamplingMetadata,
|
| 106 |
+
) -> torch.Tensor:
|
| 107 |
+
"""Sample logits based on sampling metadata.
|
| 108 |
+
|
| 109 |
+
The various logits processing functions called in this method
|
| 110 |
+
may update the logits tensor in-place.
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
assert not (sampling_metadata.all_greedy
|
| 114 |
+
and sampling_metadata.all_random)
|
| 115 |
+
if sampling_metadata.all_random:
|
| 116 |
+
greedy_sampled = None
|
| 117 |
+
else:
|
| 118 |
+
greedy_sampled = self.greedy_sample(logits)
|
| 119 |
+
if sampling_metadata.all_greedy:
|
| 120 |
+
return greedy_sampled
|
| 121 |
+
|
| 122 |
+
assert sampling_metadata.temperature is not None
|
| 123 |
+
|
| 124 |
+
# Apply temperature.
|
| 125 |
+
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
| 126 |
+
|
| 127 |
+
# Apply logits processors that only apply to random sampling
|
| 128 |
+
# (argmax invariant)
|
| 129 |
+
for processor in sampling_metadata.logitsprocs.argmax_invariant:
|
| 130 |
+
logits = processor.apply(logits)
|
| 131 |
+
|
| 132 |
+
# Apply top_n_sigma
|
| 133 |
+
logits = apply_top_n_sigma(logits, sampling_metadata)
|
| 134 |
+
|
| 135 |
+
# Apply top_k and/or top_p.
|
| 136 |
+
random_sampled = self.topk_topp_sampler(
|
| 137 |
+
logits,
|
| 138 |
+
sampling_metadata.generators,
|
| 139 |
+
sampling_metadata.top_k,
|
| 140 |
+
sampling_metadata.top_p,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
if greedy_sampled is None:
|
| 144 |
+
return random_sampled
|
| 145 |
+
|
| 146 |
+
sampled = torch.where(
|
| 147 |
+
sampling_metadata.temperature < _SAMPLING_EPS,
|
| 148 |
+
greedy_sampled,
|
| 149 |
+
random_sampled,
|
| 150 |
+
out=greedy_sampled, # Reuse tensor
|
| 151 |
+
)
|
| 152 |
+
return sampled
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
|
| 156 |
+
TopKTopPSampler.forward_native = topk_topp_forward_native
|
| 157 |
+
|
| 158 |
+
if envs.VLLM_ASCEND_ENABLE_TOP_N_SIGMA:
|
| 159 |
+
Sampler.sample = sample
|
inference/vllm_ascend/quantization/w8a8.py
ADDED
|
@@ -0,0 +1,757 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
from typing import Any, Callable, Dict, Optional
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch_npu
|
| 22 |
+
from vllm.attention.backends.abstract import AttentionType
|
| 23 |
+
|
| 24 |
+
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
| 25 |
+
from vllm_ascend.distributed.parallel_state import get_ep_group
|
| 26 |
+
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def quant_per_tensor(in_tensor: torch.Tensor,
|
| 30 |
+
input_scale: torch.Tensor,
|
| 31 |
+
input_offset: torch.Tensor,
|
| 32 |
+
function=False):
|
| 33 |
+
return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
|
| 34 |
+
torch.qint8, -1, function)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class AscendW8A8LinearMethod:
|
| 38 |
+
"""Linear method for Ascend W8A8.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
w_sym: whether the linear weight is symmetrically quantized.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(self) -> None:
|
| 45 |
+
# aclnn quant matmul requires to transpose matrix B, set to true by default.
|
| 46 |
+
self.transpose_weight = not is_310p()
|
| 47 |
+
|
| 48 |
+
@staticmethod
|
| 49 |
+
def get_weight(
|
| 50 |
+
input_size: int,
|
| 51 |
+
output_size: int,
|
| 52 |
+
params_dtype: torch.dtype = torch.bfloat16,
|
| 53 |
+
) -> Dict[str, Any]:
|
| 54 |
+
params_dict = {
|
| 55 |
+
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
| 56 |
+
}
|
| 57 |
+
return params_dict
|
| 58 |
+
|
| 59 |
+
@staticmethod
|
| 60 |
+
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 61 |
+
params_dict = {}
|
| 62 |
+
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
|
| 63 |
+
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
|
| 64 |
+
return params_dict
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def get_perchannel_param(
|
| 68 |
+
output_size: int,
|
| 69 |
+
params_dtype: torch.dtype,
|
| 70 |
+
) -> Dict[str, Any]:
|
| 71 |
+
params_dict = {}
|
| 72 |
+
params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
|
| 73 |
+
if params_dtype == torch.bfloat16:
|
| 74 |
+
params_dict["deq_scale"] = torch.empty(output_size,
|
| 75 |
+
dtype=torch.float32)
|
| 76 |
+
elif params_dtype == torch.float16:
|
| 77 |
+
params_dict["deq_scale"] = torch.empty(output_size,
|
| 78 |
+
dtype=torch.int64)
|
| 79 |
+
params_dict["weight_scale"] = torch.empty(output_size,
|
| 80 |
+
1,
|
| 81 |
+
dtype=params_dtype)
|
| 82 |
+
params_dict["weight_offset"] = torch.empty(output_size,
|
| 83 |
+
1,
|
| 84 |
+
dtype=params_dtype)
|
| 85 |
+
return params_dict
|
| 86 |
+
|
| 87 |
+
@staticmethod
|
| 88 |
+
def apply(
|
| 89 |
+
layer: torch.nn.Module,
|
| 90 |
+
x: torch.Tensor,
|
| 91 |
+
bias: Optional[torch.Tensor] = None,
|
| 92 |
+
tp_rank: Optional[int] = 0,
|
| 93 |
+
) -> torch.Tensor:
|
| 94 |
+
original_dtype = x.dtype
|
| 95 |
+
if original_dtype != torch.int8:
|
| 96 |
+
x = quant_per_tensor(x, layer.aclnn_input_scale,
|
| 97 |
+
layer.aclnn_input_offset)
|
| 98 |
+
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
| 99 |
+
if is_310p():
|
| 100 |
+
# On 300I Duo platform, we need transpose again if
|
| 101 |
+
# using nz. This transpose can be skipped in torchair.
|
| 102 |
+
output = torch_npu.npu_quant_matmul(
|
| 103 |
+
x,
|
| 104 |
+
layer.weight.data.transpose(1, 0),
|
| 105 |
+
layer.deq_scale,
|
| 106 |
+
bias=quant_bias,
|
| 107 |
+
output_dtype=original_dtype,
|
| 108 |
+
)
|
| 109 |
+
else:
|
| 110 |
+
output = torch_npu.npu_quant_matmul(
|
| 111 |
+
x,
|
| 112 |
+
layer.weight,
|
| 113 |
+
layer.deq_scale,
|
| 114 |
+
bias=quant_bias,
|
| 115 |
+
output_dtype=original_dtype,
|
| 116 |
+
)
|
| 117 |
+
return output
|
| 118 |
+
|
| 119 |
+
def process_weights_after_loading(self, layer):
|
| 120 |
+
expanding_factor = layer.weight.data.shape[1]
|
| 121 |
+
layer.aclnn_input_scale = 1 / torch.nn.Parameter(
|
| 122 |
+
layer.input_scale.data.repeat(expanding_factor),
|
| 123 |
+
requires_grad=False)
|
| 124 |
+
layer.aclnn_input_offset = torch.nn.Parameter(
|
| 125 |
+
layer.input_offset.data.repeat(expanding_factor),
|
| 126 |
+
requires_grad=False).to(layer.aclnn_input_scale.dtype)
|
| 127 |
+
if self.transpose_weight:
|
| 128 |
+
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
| 129 |
+
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
|
| 130 |
+
ACL_FORMAT_FRACTAL_NZ)
|
| 131 |
+
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
| 132 |
+
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class AscendW8A8FusedMoEMethod:
|
| 136 |
+
"""FusedMoe method for Ascend W8A8.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
def __init__(self):
|
| 140 |
+
self.transpose_weight = True
|
| 141 |
+
|
| 142 |
+
@staticmethod
|
| 143 |
+
def get_weight(num_experts: int, intermediate_size_per_partition: int,
|
| 144 |
+
hidden_sizes: int,
|
| 145 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 146 |
+
param_dict = {}
|
| 147 |
+
param_dict["w13_weight"] = torch.empty(num_experts,
|
| 148 |
+
2 *
|
| 149 |
+
intermediate_size_per_partition,
|
| 150 |
+
hidden_sizes,
|
| 151 |
+
dtype=torch.int8,
|
| 152 |
+
requires_grad=False)
|
| 153 |
+
param_dict["w2_weight"] = torch.empty(num_experts,
|
| 154 |
+
hidden_sizes,
|
| 155 |
+
intermediate_size_per_partition,
|
| 156 |
+
dtype=torch.int8,
|
| 157 |
+
requires_grad=False)
|
| 158 |
+
return param_dict
|
| 159 |
+
|
| 160 |
+
@staticmethod
|
| 161 |
+
def get_dynamic_quant_param(num_experts: int,
|
| 162 |
+
intermediate_size_per_partition: int,
|
| 163 |
+
hidden_sizes: int,
|
| 164 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 165 |
+
param_dict = {}
|
| 166 |
+
param_dict["w13_weight_scale"] = torch.empty(
|
| 167 |
+
num_experts,
|
| 168 |
+
2 * intermediate_size_per_partition,
|
| 169 |
+
1,
|
| 170 |
+
dtype=torch.float32)
|
| 171 |
+
param_dict["w13_weight_offset"] = torch.empty(
|
| 172 |
+
num_experts,
|
| 173 |
+
2 * intermediate_size_per_partition,
|
| 174 |
+
1,
|
| 175 |
+
dtype=torch.float16)
|
| 176 |
+
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
| 177 |
+
hidden_sizes,
|
| 178 |
+
1,
|
| 179 |
+
dtype=torch.float32)
|
| 180 |
+
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
| 181 |
+
hidden_sizes,
|
| 182 |
+
1,
|
| 183 |
+
dtype=torch.float16)
|
| 184 |
+
param_dict["w2_deq_scale"] = torch.empty(num_experts,
|
| 185 |
+
hidden_sizes,
|
| 186 |
+
dtype=torch.float32)
|
| 187 |
+
param_dict["w13_deq_scale"] = torch.empty(
|
| 188 |
+
num_experts,
|
| 189 |
+
2 * intermediate_size_per_partition,
|
| 190 |
+
dtype=torch.float32)
|
| 191 |
+
param_dict["w2_input_scale"] = torch.empty(num_experts,
|
| 192 |
+
1,
|
| 193 |
+
dtype=torch.float32)
|
| 194 |
+
param_dict["w13_input_scale"] = torch.empty(num_experts,
|
| 195 |
+
1,
|
| 196 |
+
dtype=torch.float32)
|
| 197 |
+
param_dict["w2_input_offset"] = torch.empty(num_experts,
|
| 198 |
+
1,
|
| 199 |
+
dtype=torch.int8)
|
| 200 |
+
param_dict["w13_input_offset"] = torch.empty(num_experts,
|
| 201 |
+
1,
|
| 202 |
+
dtype=torch.int8)
|
| 203 |
+
param_dict["quant_bias"] = torch.empty(num_experts,
|
| 204 |
+
hidden_sizes,
|
| 205 |
+
dtype=torch.int32)
|
| 206 |
+
|
| 207 |
+
return param_dict
|
| 208 |
+
|
| 209 |
+
def apply(
|
| 210 |
+
self,
|
| 211 |
+
layer: torch.nn.Module,
|
| 212 |
+
x: torch.Tensor,
|
| 213 |
+
router_logits: torch.Tensor,
|
| 214 |
+
top_k: int,
|
| 215 |
+
renormalize: bool,
|
| 216 |
+
use_grouped_topk: bool = False,
|
| 217 |
+
global_num_experts: int = -1,
|
| 218 |
+
expert_map: Optional[torch.Tensor] = None,
|
| 219 |
+
topk_group: Optional[int] = None,
|
| 220 |
+
num_expert_group: Optional[int] = None,
|
| 221 |
+
custom_routing_function: Optional[Callable] = None,
|
| 222 |
+
scoring_func: str = "softmax",
|
| 223 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 224 |
+
is_prefill: bool = True,
|
| 225 |
+
enable_force_load_balance: bool = False,
|
| 226 |
+
log2phy: torch.Tensor = None,
|
| 227 |
+
global_redundant_expert_num: int = 0,
|
| 228 |
+
shared_experts: Optional[Any] = None,
|
| 229 |
+
**kwargs,
|
| 230 |
+
) -> torch.Tensor:
|
| 231 |
+
assert router_logits.shape[
|
| 232 |
+
1] == global_num_experts, "Number of global experts mismatch"
|
| 233 |
+
|
| 234 |
+
topk_weights, topk_ids = select_experts(
|
| 235 |
+
hidden_states=x,
|
| 236 |
+
router_logits=router_logits,
|
| 237 |
+
top_k=top_k,
|
| 238 |
+
use_grouped_topk=use_grouped_topk,
|
| 239 |
+
renormalize=renormalize,
|
| 240 |
+
topk_group=topk_group,
|
| 241 |
+
num_expert_group=num_expert_group,
|
| 242 |
+
custom_routing_function=custom_routing_function,
|
| 243 |
+
scoring_func=scoring_func,
|
| 244 |
+
e_score_correction_bias=e_score_correction_bias,
|
| 245 |
+
global_num_experts=global_num_experts,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
if is_310p():
|
| 249 |
+
return fused_experts_310p(hidden_states=x,
|
| 250 |
+
w1=layer.w13_weight,
|
| 251 |
+
w1_scale=layer.w13_weight_scale,
|
| 252 |
+
w1_input_scale=layer.w13_input_scale,
|
| 253 |
+
w2=layer.w2_weight,
|
| 254 |
+
w2_scale=layer.w2_weight_scale,
|
| 255 |
+
w2_input_scale=layer.w2_input_scale,
|
| 256 |
+
topk_weights=topk_weights,
|
| 257 |
+
topk_ids=topk_ids,
|
| 258 |
+
top_k=top_k,
|
| 259 |
+
global_num_experts=global_num_experts,
|
| 260 |
+
expert_map=expert_map)
|
| 261 |
+
return fused_experts(hidden_states=x,
|
| 262 |
+
w1=layer.w13_weight,
|
| 263 |
+
w1_scale=layer.w13_weight_scale,
|
| 264 |
+
w1_input_scale=layer.w13_input_scale,
|
| 265 |
+
w1_input_offset=layer.w13_input_offset,
|
| 266 |
+
w2=layer.w2_weight,
|
| 267 |
+
w2_scale=layer.w2_weight_scale,
|
| 268 |
+
w2_input_scale=layer.w2_input_scale,
|
| 269 |
+
w2_input_offset=layer.w2_input_offset,
|
| 270 |
+
topk_weights=topk_weights,
|
| 271 |
+
topk_ids=topk_ids,
|
| 272 |
+
top_k=top_k,
|
| 273 |
+
global_num_experts=global_num_experts,
|
| 274 |
+
expert_map=expert_map)
|
| 275 |
+
|
| 276 |
+
def process_weights_after_loading(self, layer):
|
| 277 |
+
if not is_310p():
|
| 278 |
+
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
| 279 |
+
1, 2).contiguous()
|
| 280 |
+
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
| 281 |
+
1, 2).contiguous()
|
| 282 |
+
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
| 283 |
+
layer.w13_weight_scale.data.shape[0], -1)
|
| 284 |
+
|
| 285 |
+
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
| 286 |
+
layer.w13_weight_offset.data.shape[0], -1)
|
| 287 |
+
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
| 288 |
+
layer.w2_weight_scale.data.shape[0], -1)
|
| 289 |
+
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
| 290 |
+
layer.w2_weight_offset.data.shape[0], -1)
|
| 291 |
+
expanding_factor_w13 = layer.w13_weight.data.shape[1]
|
| 292 |
+
expanding_factor_w2 = layer.w2_weight.data.shape[1]
|
| 293 |
+
|
| 294 |
+
if is_310p():
|
| 295 |
+
layer.w13_input_scale.data = torch.nn.Parameter(
|
| 296 |
+
layer.w13_input_scale.data.max())
|
| 297 |
+
layer.w2_input_scale.data = torch.nn.Parameter(
|
| 298 |
+
layer.w2_input_scale.data.max())
|
| 299 |
+
else:
|
| 300 |
+
layer.w13_input_scale.data = torch.nn.Parameter(
|
| 301 |
+
layer.w13_input_scale.data.repeat(1,
|
| 302 |
+
expanding_factor_w13)[0:1])
|
| 303 |
+
layer.w2_input_scale.data = torch.nn.Parameter(
|
| 304 |
+
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
|
| 305 |
+
|
| 306 |
+
layer.w13_input_offset.data = torch.nn.Parameter(
|
| 307 |
+
layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1])
|
| 308 |
+
layer.w2_input_offset.data = torch.nn.Parameter(
|
| 309 |
+
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
|
| 310 |
+
|
| 311 |
+
# converting ACL_FORMAT_FRACTAL_NZ.
|
| 312 |
+
# npu_quant_grouped_matmul_dequant in eager mode does not accept
|
| 313 |
+
# ACL_FORMAT_FRACTAL_NZ.
|
| 314 |
+
if not is_310p():
|
| 315 |
+
layer.w13_weight.data = torch_npu.npu_format_cast(
|
| 316 |
+
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
|
| 317 |
+
layer.w2_weight.data = torch_npu.npu_format_cast(
|
| 318 |
+
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class AscendC8KVCacheMethod:
|
| 322 |
+
|
| 323 |
+
def __init__(self) -> None:
|
| 324 |
+
self.antiquant_scale_comb = None
|
| 325 |
+
|
| 326 |
+
@staticmethod
|
| 327 |
+
def create_weights(layer) -> None:
|
| 328 |
+
param_dict = {} # num_kv_heads * head_size
|
| 329 |
+
param_dict["key_antiquant_scale"] = torch.empty(layer.num_kv_heads *
|
| 330 |
+
layer.head_size,
|
| 331 |
+
dtype=torch.float16,
|
| 332 |
+
requires_grad=False)
|
| 333 |
+
param_dict["value_antiquant_scale"] = torch.empty(layer.num_kv_heads *
|
| 334 |
+
layer.head_size,
|
| 335 |
+
dtype=torch.float16,
|
| 336 |
+
requires_grad=False)
|
| 337 |
+
for weight_name, weight_param in param_dict.items():
|
| 338 |
+
param = torch.nn.Parameter(weight_param, requires_grad=False)
|
| 339 |
+
layer.register_parameter(weight_name, param)
|
| 340 |
+
|
| 341 |
+
def process_weights_after_loading(self, layer):
|
| 342 |
+
self.antiquant_scale_comb = torch.cat(
|
| 343 |
+
(layer.key_antiquant_scale.data.unsqueeze(0),
|
| 344 |
+
layer.value_antiquant_scale.data.unsqueeze(0)),
|
| 345 |
+
dim=0).to(torch.float16).contiguous()
|
| 346 |
+
|
| 347 |
+
def apply(self, layer, query, key, value, kv_cache, attn_metadata,
|
| 348 |
+
attn_type, scale, output) -> torch.Tensor:
|
| 349 |
+
num_tokens = query.shape[0]
|
| 350 |
+
if attn_metadata is None:
|
| 351 |
+
return output.view(num_tokens, layer.num_heads * layer.head_size)
|
| 352 |
+
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
| 353 |
+
if attn_type != AttentionType.DECODER:
|
| 354 |
+
raise NotImplementedError("Encoder self-attention and "
|
| 355 |
+
"encoder/decoder cross-attention "
|
| 356 |
+
"are not implemented for "
|
| 357 |
+
"PallasAttentionBackendImpl")
|
| 358 |
+
|
| 359 |
+
# C8
|
| 360 |
+
quant_key = quant_per_tensor(
|
| 361 |
+
key.view(-1, layer.num_kv_heads * layer.head_size),
|
| 362 |
+
layer.key_antiquant_scale.data.view(-1), None, True)
|
| 363 |
+
quant_value = quant_per_tensor(
|
| 364 |
+
value.view(-1, layer.num_kv_heads * layer.head_size),
|
| 365 |
+
layer.value_antiquant_scale.data.view(-1), None, True)
|
| 366 |
+
|
| 367 |
+
# View q k v to BSH.
|
| 368 |
+
query = query.view(-1, layer.num_heads, layer.head_size)
|
| 369 |
+
key = key.view(-1, layer.num_kv_heads, layer.head_size)
|
| 370 |
+
value = value.view(-1, layer.num_kv_heads, layer.head_size)
|
| 371 |
+
# TODO: Remove this contiguous in the future.
|
| 372 |
+
value = value.contiguous()
|
| 373 |
+
|
| 374 |
+
if kv_cache[0].numel() > 0:
|
| 375 |
+
# if key_cache is None:
|
| 376 |
+
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
| 377 |
+
slots = attn_metadata.slot_mapping
|
| 378 |
+
|
| 379 |
+
block_size = key_cache.shape[1]
|
| 380 |
+
slots_indices = slots.reshape(-1, 1)
|
| 381 |
+
block_indices = slots_indices // block_size
|
| 382 |
+
slots_indices = slots_indices % block_size
|
| 383 |
+
indices = torch.cat((block_indices, slots_indices), dim=1)
|
| 384 |
+
|
| 385 |
+
# C8
|
| 386 |
+
torch_npu.npu_scatter_nd_update_(key_cache, indices, quant_key)
|
| 387 |
+
torch_npu.npu_scatter_nd_update_(value_cache, indices, quant_value)
|
| 388 |
+
|
| 389 |
+
# V0-Style scheduler situation.
|
| 390 |
+
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
| 391 |
+
assert attn_metadata is not None
|
| 392 |
+
assert attn_metadata.attn_mask is not None
|
| 393 |
+
mask = attn_metadata.attn_mask
|
| 394 |
+
torch_npu._npu_flash_attention(query=query,
|
| 395 |
+
key=key,
|
| 396 |
+
value=value,
|
| 397 |
+
mask=mask,
|
| 398 |
+
seq_len=attn_metadata.seq_lens,
|
| 399 |
+
scale_value=scale,
|
| 400 |
+
num_heads=layer.num_heads,
|
| 401 |
+
num_kv_heads=layer.num_kv_heads,
|
| 402 |
+
out=output.reshape(query.shape))
|
| 403 |
+
|
| 404 |
+
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
| 405 |
+
raise NotImplementedError("kv cache int8 are not "
|
| 406 |
+
"implemented for "
|
| 407 |
+
"PrefillCacheHit")
|
| 408 |
+
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # changed attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
| 409 |
+
if hasattr(attn_metadata, "decode"):
|
| 410 |
+
# torch_air
|
| 411 |
+
decode_meta = attn_metadata.decode
|
| 412 |
+
seq_lens = decode_meta.seq_lens_list
|
| 413 |
+
else:
|
| 414 |
+
seq_lens = attn_metadata.seq_lens
|
| 415 |
+
block_size = key_cache.shape[1]
|
| 416 |
+
query = query.view(num_tokens, 1, layer.num_heads *
|
| 417 |
+
layer.head_size).contiguous() # changed
|
| 418 |
+
|
| 419 |
+
# [num_blocks, block_size, N, D] --> [num_blocks, N, block_size, D]
|
| 420 |
+
key = key_cache
|
| 421 |
+
value = value_cache
|
| 422 |
+
|
| 423 |
+
output = torch_npu.npu_incre_flash_attention(
|
| 424 |
+
query,
|
| 425 |
+
key,
|
| 426 |
+
value,
|
| 427 |
+
num_key_value_heads=layer.num_kv_heads,
|
| 428 |
+
num_heads=layer.num_heads,
|
| 429 |
+
actual_seq_lengths=seq_lens,
|
| 430 |
+
scale_value=scale,
|
| 431 |
+
input_layout='BSH',
|
| 432 |
+
block_size=block_size,
|
| 433 |
+
block_table=attn_metadata.block_tables,
|
| 434 |
+
antiquant_scale=self.antiquant_scale_comb,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# Normal V1 situation.
|
| 438 |
+
else:
|
| 439 |
+
raise NotImplementedError("kv cache int8 are not "
|
| 440 |
+
"implemented for "
|
| 441 |
+
"other case")
|
| 442 |
+
return output
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def fused_experts_310p(
|
| 446 |
+
hidden_states: torch.Tensor,
|
| 447 |
+
w1: torch.Tensor,
|
| 448 |
+
w1_scale: torch.Tensor,
|
| 449 |
+
w1_input_scale: torch.Tensor,
|
| 450 |
+
w2: torch.Tensor,
|
| 451 |
+
w2_scale: torch.Tensor,
|
| 452 |
+
w2_input_scale: torch.Tensor,
|
| 453 |
+
topk_weights: torch.Tensor,
|
| 454 |
+
topk_ids: torch.Tensor,
|
| 455 |
+
top_k: int,
|
| 456 |
+
global_num_experts: int,
|
| 457 |
+
expert_map: torch.Tensor = None,
|
| 458 |
+
) -> torch.Tensor:
|
| 459 |
+
ep_size = get_ep_group().world_size
|
| 460 |
+
local_num_experts = global_num_experts // ep_size
|
| 461 |
+
local_num_group = top_k // ep_size
|
| 462 |
+
|
| 463 |
+
bsz, _ = hidden_states.shape
|
| 464 |
+
flatten_topk_ids = topk_ids.view(-1)
|
| 465 |
+
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
| 466 |
+
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
| 467 |
+
sorted_hidden_states = hidden_states.index_select(
|
| 468 |
+
0, sorted_topk_ids // local_num_group)
|
| 469 |
+
|
| 470 |
+
experts_id = torch.arange(0,
|
| 471 |
+
local_num_experts,
|
| 472 |
+
dtype=topk_ids.dtype,
|
| 473 |
+
device=topk_ids.device)
|
| 474 |
+
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
| 475 |
+
torch.float32).sum(0)
|
| 476 |
+
topk_scales = topk_weights.view(-1).index_select(
|
| 477 |
+
0, sorted_topk_ids).unsqueeze(-1)
|
| 478 |
+
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
| 479 |
+
|
| 480 |
+
gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant(
|
| 481 |
+
x=sorted_hidden_states,
|
| 482 |
+
quantized_weight=w1,
|
| 483 |
+
weight_scale=w1_scale,
|
| 484 |
+
group_list=group_list,
|
| 485 |
+
x_scale=w1_input_scale,
|
| 486 |
+
quant_mode="pertensor")
|
| 487 |
+
|
| 488 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
| 489 |
+
torch.float16)
|
| 490 |
+
gate_up_out *= topk_scales
|
| 491 |
+
|
| 492 |
+
down_out = torch_npu.npu_quant_grouped_matmul_dequant(
|
| 493 |
+
x=gate_up_out,
|
| 494 |
+
quantized_weight=w2,
|
| 495 |
+
weight_scale=w2_scale,
|
| 496 |
+
group_list=group_list,
|
| 497 |
+
x_scale=w2_input_scale,
|
| 498 |
+
quant_mode="pertensor")
|
| 499 |
+
|
| 500 |
+
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
| 501 |
+
unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids)
|
| 502 |
+
final_hidden_states = unsorted_hidden_states.reshape(
|
| 503 |
+
bsz, top_k // ep_size, -1).sum(1)
|
| 504 |
+
|
| 505 |
+
return final_hidden_states
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def fused_experts(
|
| 509 |
+
hidden_states: torch.Tensor,
|
| 510 |
+
w1: torch.Tensor,
|
| 511 |
+
w1_scale: torch.Tensor,
|
| 512 |
+
w1_input_scale: torch.Tensor,
|
| 513 |
+
w1_input_offset: torch.Tensor,
|
| 514 |
+
w2: torch.Tensor,
|
| 515 |
+
w2_scale: torch.Tensor,
|
| 516 |
+
w2_input_scale: torch.Tensor,
|
| 517 |
+
w2_input_offset: torch.Tensor,
|
| 518 |
+
topk_weights: torch.Tensor,
|
| 519 |
+
topk_ids: torch.Tensor,
|
| 520 |
+
top_k: int,
|
| 521 |
+
global_num_experts: int,
|
| 522 |
+
expert_map: torch.Tensor = None,
|
| 523 |
+
) -> torch.Tensor:
|
| 524 |
+
"""
|
| 525 |
+
Fused experts with top-k routing.
|
| 526 |
+
|
| 527 |
+
Args:
|
| 528 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 529 |
+
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
| 530 |
+
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
| 531 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 532 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 533 |
+
top_k: Number of experts to select.
|
| 534 |
+
expert_map: Expert mapping of shape (num_experts,).
|
| 535 |
+
|
| 536 |
+
Returns:
|
| 537 |
+
hidden_states: Hidden states after routing.
|
| 538 |
+
"""
|
| 539 |
+
"""
|
| 540 |
+
# Check constraints.
|
| 541 |
+
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
| 542 |
+
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
| 543 |
+
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
| 544 |
+
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
| 545 |
+
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
| 546 |
+
"""
|
| 547 |
+
|
| 548 |
+
original_dtype = hidden_states.dtype
|
| 549 |
+
ep_size = get_ep_group().world_size
|
| 550 |
+
local_num_experts = global_num_experts // ep_size
|
| 551 |
+
w1_input_scale, _ = w1_input_scale.max(0)
|
| 552 |
+
quant_sorted_hidden_states = quant_per_tensor(
|
| 553 |
+
hidden_states,
|
| 554 |
+
w1_input_scale,
|
| 555 |
+
None,
|
| 556 |
+
True,
|
| 557 |
+
)
|
| 558 |
+
if expert_map is not None:
|
| 559 |
+
expanded_x, expanded_row_idx, expert_token_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
|
| 560 |
+
quant_sorted_hidden_states,
|
| 561 |
+
topk_ids,
|
| 562 |
+
scale=None,
|
| 563 |
+
active_num=topk_ids.numel(),
|
| 564 |
+
expert_capacity=-1,
|
| 565 |
+
expert_num=local_num_experts,
|
| 566 |
+
drop_pad_mode=0,
|
| 567 |
+
expert_tokens_num_type=1,
|
| 568 |
+
expert_tokens_num_flag=True,
|
| 569 |
+
quant_mode=-1,
|
| 570 |
+
active_expert_range=[0, local_num_experts],
|
| 571 |
+
row_idx_type=0,
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
else:
|
| 575 |
+
raise NotImplementedError(
|
| 576 |
+
"The quantified version of MOE class models "
|
| 577 |
+
"currently does not support tensor parallelism")
|
| 578 |
+
if expanded_x.dtype != w1.dtype:
|
| 579 |
+
w1_input_scale, _ = w1_input_scale.max(0)
|
| 580 |
+
quant_sorted_hidden_states = quant_per_tensor(
|
| 581 |
+
expanded_x,
|
| 582 |
+
w1_input_scale,
|
| 583 |
+
None,
|
| 584 |
+
True,
|
| 585 |
+
)
|
| 586 |
+
else:
|
| 587 |
+
quant_sorted_hidden_states = expanded_x
|
| 588 |
+
gate_up_out = torch_npu.npu_grouped_matmul(
|
| 589 |
+
x=[quant_sorted_hidden_states],
|
| 590 |
+
weight=[w1],
|
| 591 |
+
scale=[w1_scale * w1_input_scale[0]],
|
| 592 |
+
split_item=2,
|
| 593 |
+
group_list_type=1,
|
| 594 |
+
group_type=0,
|
| 595 |
+
group_list=expert_token_count,
|
| 596 |
+
output_dtype=original_dtype,
|
| 597 |
+
)[0]
|
| 598 |
+
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
| 599 |
+
|
| 600 |
+
if gate_up_out.dtype != w2.dtype:
|
| 601 |
+
w2_input_scale, _ = w2_input_scale.max(0)
|
| 602 |
+
quant_gate_up_out = quant_per_tensor(
|
| 603 |
+
gate_up_out,
|
| 604 |
+
w2_input_scale,
|
| 605 |
+
None,
|
| 606 |
+
True,
|
| 607 |
+
)
|
| 608 |
+
else:
|
| 609 |
+
quant_gate_up_out = gate_up_out
|
| 610 |
+
|
| 611 |
+
down_out = torch_npu.npu_grouped_matmul(
|
| 612 |
+
x=[quant_gate_up_out],
|
| 613 |
+
weight=[w2],
|
| 614 |
+
scale=[w2_scale * w2_input_scale[0]],
|
| 615 |
+
split_item=2,
|
| 616 |
+
group_list_type=1,
|
| 617 |
+
group_type=0,
|
| 618 |
+
group_list=expert_token_count,
|
| 619 |
+
output_dtype=original_dtype,
|
| 620 |
+
)[0]
|
| 621 |
+
|
| 622 |
+
if expert_map is not None:
|
| 623 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 624 |
+
down_out,
|
| 625 |
+
skip1=None,
|
| 626 |
+
skip2=None,
|
| 627 |
+
bias=None,
|
| 628 |
+
scales=topk_weights.to(down_out.dtype),
|
| 629 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 630 |
+
export_for_source_row=topk_ids,
|
| 631 |
+
drop_pad_mode=2,
|
| 632 |
+
)
|
| 633 |
+
else:
|
| 634 |
+
raise NotImplementedError(
|
| 635 |
+
"The quantified version of MOE class models "
|
| 636 |
+
"currently does not support tensor parallelism")
|
| 637 |
+
|
| 638 |
+
return final_hidden_states
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def select_experts(
|
| 642 |
+
hidden_states: torch.Tensor,
|
| 643 |
+
router_logits: torch.Tensor,
|
| 644 |
+
top_k: int,
|
| 645 |
+
use_grouped_topk: bool,
|
| 646 |
+
renormalize: bool,
|
| 647 |
+
topk_group: Optional[int] = None,
|
| 648 |
+
num_expert_group: Optional[int] = None,
|
| 649 |
+
custom_routing_function: Optional[Callable] = None,
|
| 650 |
+
scoring_func: str = "softmax",
|
| 651 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 652 |
+
global_num_experts=-1,
|
| 653 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 654 |
+
"""
|
| 655 |
+
Select top-k experts based on router logits.
|
| 656 |
+
|
| 657 |
+
Args:
|
| 658 |
+
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
| 659 |
+
router_logits: Router logits of shape (num_tokens, num_experts).
|
| 660 |
+
top_k: Number of experts to select.
|
| 661 |
+
use_grouped_topk: Whether to group experts before selecting top-k.
|
| 662 |
+
renormalize: Whether to renormalize the routing weights.
|
| 663 |
+
topk_group: Number of expert groups to select from.
|
| 664 |
+
num_expert_group: Number of experts in each group.
|
| 665 |
+
custom_routing_function: Custom routing function.
|
| 666 |
+
scoring_func: Scoring function to use.
|
| 667 |
+
e_score_correction_bias: Correction bias to apply to expert scores.
|
| 668 |
+
|
| 669 |
+
Returns:
|
| 670 |
+
topk_weights: Routing weights of shape (num_tokens, top_k).
|
| 671 |
+
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
| 672 |
+
|
| 673 |
+
Raises:
|
| 674 |
+
ValueError: If an unsupported scoring function is provided.
|
| 675 |
+
"""
|
| 676 |
+
|
| 677 |
+
if scoring_func == "softmax":
|
| 678 |
+
# NOTE: vLLM use dtype=torch.float here
|
| 679 |
+
topk_weights = router_logits.softmax(dim=-1)
|
| 680 |
+
elif scoring_func == "sigmoid":
|
| 681 |
+
topk_weights = router_logits.sigmoid()
|
| 682 |
+
else:
|
| 683 |
+
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
| 684 |
+
|
| 685 |
+
if use_grouped_topk:
|
| 686 |
+
assert topk_group is not None
|
| 687 |
+
assert num_expert_group is not None
|
| 688 |
+
|
| 689 |
+
if e_score_correction_bias is not None:
|
| 690 |
+
# Store original scores before applying correction bias. We use biased
|
| 691 |
+
# scores for expert selection but original scores for routing weights
|
| 692 |
+
original_weights = topk_weights
|
| 693 |
+
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
|
| 694 |
+
|
| 695 |
+
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
|
| 696 |
+
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
| 697 |
+
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
|
| 698 |
+
topk_group)
|
| 699 |
+
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
| 700 |
+
if e_score_correction_bias is not None:
|
| 701 |
+
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
| 702 |
+
k=top_k,
|
| 703 |
+
dim=-1,
|
| 704 |
+
sorted=False)[1]
|
| 705 |
+
# Use original unbiased scores for the routing weights
|
| 706 |
+
topk_weights = original_weights.gather(1, topk_ids)
|
| 707 |
+
else:
|
| 708 |
+
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
| 709 |
+
k=top_k,
|
| 710 |
+
dim=-1,
|
| 711 |
+
sorted=False)
|
| 712 |
+
elif custom_routing_function is None:
|
| 713 |
+
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
| 714 |
+
else:
|
| 715 |
+
topk_weights, topk_ids = custom_routing_function(
|
| 716 |
+
hidden_states=hidden_states,
|
| 717 |
+
gating_output=router_logits,
|
| 718 |
+
topk=top_k,
|
| 719 |
+
renormalize=renormalize,
|
| 720 |
+
global_num_experts=global_num_experts,
|
| 721 |
+
)
|
| 722 |
+
# Required by npu_moe_init_routing
|
| 723 |
+
topk_ids = topk_ids.to(torch.int32)
|
| 724 |
+
return topk_weights, topk_ids
|
| 725 |
+
|
| 726 |
+
# Required by npu_moe_init_routing
|
| 727 |
+
topk_ids = topk_ids.to(torch.int32)
|
| 728 |
+
|
| 729 |
+
if renormalize:
|
| 730 |
+
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
| 731 |
+
|
| 732 |
+
return topk_weights, topk_ids
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def native_grouped_topk(
|
| 736 |
+
topk_weights: torch.Tensor,
|
| 737 |
+
num_expert_group: Optional[int],
|
| 738 |
+
topk_group: Optional[int],
|
| 739 |
+
):
|
| 740 |
+
topk_group = 0 if topk_group is None else topk_group
|
| 741 |
+
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
| 742 |
+
|
| 743 |
+
num_token = topk_weights.shape[0]
|
| 744 |
+
grouped_weights = topk_weights.view(num_token, num_expert_group,
|
| 745 |
+
-1).max(dim=-1).values
|
| 746 |
+
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
|
| 747 |
+
k=topk_group,
|
| 748 |
+
dim=-1,
|
| 749 |
+
sorted=False)[1]
|
| 750 |
+
topk_group_mask = torch.zeros_like(grouped_weights)
|
| 751 |
+
topk_group_mask.scatter_(1, topk_group_indices, 1)
|
| 752 |
+
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
|
| 753 |
+
num_token, num_expert_group,
|
| 754 |
+
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
|
| 755 |
+
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
|
| 756 |
+
|
| 757 |
+
return topk_weights
|
inference/vllm_ascend/quantization/w8a8_dynamic.py
ADDED
|
@@ -0,0 +1,831 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# This file is a part of the vllm-ascend project.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
#
|
| 17 |
+
|
| 18 |
+
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.distributed as dist
|
| 22 |
+
import torch_npu
|
| 23 |
+
from vllm.distributed import GroupCoordinator
|
| 24 |
+
|
| 25 |
+
import vllm_ascend.envs as envs
|
| 26 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 27 |
+
from vllm_ascend.distributed.parallel_state import get_ep_group
|
| 28 |
+
from vllm_ascend.ops.fused_moe import select_experts
|
| 29 |
+
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState,
|
| 30 |
+
dispose_tensor, get_fused_moe_state,
|
| 31 |
+
npu_stream_switch, npu_wait_tensor)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def apply_mlp(hidden_states: torch.Tensor,
|
| 35 |
+
w1: torch.Tensor,
|
| 36 |
+
w1_scale: torch.Tensor,
|
| 37 |
+
w2: torch.Tensor,
|
| 38 |
+
w2_scale: torch.Tensor,
|
| 39 |
+
group_list: torch.Tensor,
|
| 40 |
+
dynamic_scale: torch.Tensor = None,
|
| 41 |
+
group_list_type: int = 1) -> torch.Tensor:
|
| 42 |
+
"""
|
| 43 |
+
apply MLP: gate_up_proj -> swiglu -> down_proj
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
hidden_states: input hidden states with shape (num_tokens, hidden_size).
|
| 47 |
+
w1: expert weights1 with shape
|
| 48 |
+
(num_experts, hidden_size, intermediate_size * 2)
|
| 49 |
+
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
|
| 50 |
+
w2: expert weights2 with shape
|
| 51 |
+
(num_experts, intermediate_size, hidden_size)
|
| 52 |
+
w2_scale: weights2 scale with shape (num_experts, hidden_size)
|
| 53 |
+
group_list: number of tokens for each expert, follow cumsum mode, and
|
| 54 |
+
with shape (num_experts).
|
| 55 |
+
transpose_weight:
|
| 56 |
+
w1: (num_experts, intermediate_size * 2, hidden_size) ->
|
| 57 |
+
(num_experts, hidden_size, intermediate_size * 2)
|
| 58 |
+
w2: (num_experts, hidden_size, intermediate_size) ->
|
| 59 |
+
(num_experts, intermediate_size, hidden_size)
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
hidden_states: output hidden states after MLP.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
if dynamic_scale is None:
|
| 66 |
+
unquantized_hidden_states = hidden_states
|
| 67 |
+
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
| 68 |
+
hidden_states)
|
| 69 |
+
# Dispose the original unquantized hidden states
|
| 70 |
+
# to save npu memory because they're no longer used.
|
| 71 |
+
dispose_tensor(unquantized_hidden_states)
|
| 72 |
+
else:
|
| 73 |
+
pertoken_scale = dynamic_scale
|
| 74 |
+
|
| 75 |
+
# gmm1: gate_up_proj
|
| 76 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 77 |
+
x=[hidden_states],
|
| 78 |
+
weight=[w1],
|
| 79 |
+
scale=[w1_scale],
|
| 80 |
+
per_token_scale=[pertoken_scale],
|
| 81 |
+
split_item=2,
|
| 82 |
+
group_list_type=group_list_type,
|
| 83 |
+
group_type=0,
|
| 84 |
+
group_list=group_list,
|
| 85 |
+
output_dtype=w2_scale.dtype)[0]
|
| 86 |
+
|
| 87 |
+
# act_fn: swiglu
|
| 88 |
+
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
| 89 |
+
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
| 90 |
+
hidden_states)
|
| 91 |
+
|
| 92 |
+
# gmm2: down_proj
|
| 93 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 94 |
+
x=[hidden_states],
|
| 95 |
+
weight=[w2],
|
| 96 |
+
scale=[w2_scale],
|
| 97 |
+
per_token_scale=[swiglu_out_scale],
|
| 98 |
+
split_item=2,
|
| 99 |
+
group_list_type=group_list_type,
|
| 100 |
+
group_type=0,
|
| 101 |
+
group_list=group_list,
|
| 102 |
+
output_dtype=w2_scale.dtype)[0]
|
| 103 |
+
|
| 104 |
+
return hidden_states
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def fused_experts_with_mc2(
|
| 108 |
+
hidden_states: torch.Tensor,
|
| 109 |
+
w1: torch.Tensor,
|
| 110 |
+
w2: torch.Tensor,
|
| 111 |
+
w1_scale: torch.Tensor,
|
| 112 |
+
w2_scale: torch.Tensor,
|
| 113 |
+
topk_weights: torch.Tensor,
|
| 114 |
+
topk_ids: torch.Tensor,
|
| 115 |
+
top_k: int,
|
| 116 |
+
expert_map: torch.Tensor = None,
|
| 117 |
+
moe_all_to_all_group_name: str = "",
|
| 118 |
+
log2phy: torch.Tensor = None,
|
| 119 |
+
global_redundant_expert_num: int = 0,
|
| 120 |
+
shared_experts: Optional[Any] = None,
|
| 121 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 122 |
+
if log2phy is not None:
|
| 123 |
+
topk_ids = log2phy[topk_ids]
|
| 124 |
+
global_bs = 0
|
| 125 |
+
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
| 126 |
+
# hidden_states = hidden_states.bfloat16()
|
| 127 |
+
kwargs_mc2 = {
|
| 128 |
+
"x": hidden_states,
|
| 129 |
+
"expert_ids": topk_ids,
|
| 130 |
+
"expert_shard_type": 0,
|
| 131 |
+
"shared_expert_rank_num": 0,
|
| 132 |
+
"moe_expert_num": moe_expert_num,
|
| 133 |
+
"global_bs": global_bs,
|
| 134 |
+
"expert_scales": topk_weights.to(torch.float32),
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
rank = torch.distributed.get_rank()
|
| 138 |
+
|
| 139 |
+
quant_mode = 2
|
| 140 |
+
ep_group = get_ep_group().device_group
|
| 141 |
+
local_rank = torch.distributed.get_rank(group=ep_group)
|
| 142 |
+
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
|
| 143 |
+
|
| 144 |
+
world_size = torch.distributed.get_world_size()
|
| 145 |
+
tp_size = world_size // all_to_all_group_size
|
| 146 |
+
tp_rank = rank % tp_size
|
| 147 |
+
|
| 148 |
+
stage1_kwargs = {
|
| 149 |
+
"scales": None,
|
| 150 |
+
"quant_mode": quant_mode,
|
| 151 |
+
"group_ep": moe_all_to_all_group_name,
|
| 152 |
+
"ep_world_size": all_to_all_group_size,
|
| 153 |
+
"ep_rank_id": local_rank,
|
| 154 |
+
# "group_tp": self.moe_rs_group_name,
|
| 155 |
+
"group_tp": moe_all_to_all_group_name,
|
| 156 |
+
"tp_world_size": tp_size,
|
| 157 |
+
"tp_rank_id": tp_rank,
|
| 158 |
+
}
|
| 159 |
+
kwargs_mc2.update(stage1_kwargs)
|
| 160 |
+
|
| 161 |
+
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
| 162 |
+
# comm_stream.wait_stream(torch.npu.current_stream())
|
| 163 |
+
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[
|
| 164 |
+
0:7]
|
| 165 |
+
|
| 166 |
+
if shared_experts is not None:
|
| 167 |
+
with npu_stream_switch("moe_secondary", 0):
|
| 168 |
+
npu_wait_tensor(hidden_states, topk_weights)
|
| 169 |
+
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
| 170 |
+
npu_wait_tensor(shared_gate_up[0], expand_x)
|
| 171 |
+
shared_act = shared_experts.act_fn(shared_gate_up)
|
| 172 |
+
|
| 173 |
+
# `expand_x` will be disposed in the `apply_mlp` function
|
| 174 |
+
down_out_list = apply_mlp(expand_x,
|
| 175 |
+
w1,
|
| 176 |
+
w1_scale,
|
| 177 |
+
w2,
|
| 178 |
+
w2_scale,
|
| 179 |
+
expert_token_nums,
|
| 180 |
+
dynamic_scale=dynamic_scale)
|
| 181 |
+
|
| 182 |
+
# moeCombine
|
| 183 |
+
kwargs_mc2 = {
|
| 184 |
+
"expand_x": down_out_list,
|
| 185 |
+
"expert_ids": topk_ids,
|
| 186 |
+
"expand_idx": expand_idx,
|
| 187 |
+
"expert_scales": topk_weights.to(torch.float32),
|
| 188 |
+
"expert_shard_type": 0,
|
| 189 |
+
"shared_expert_rank_num": 0,
|
| 190 |
+
"moe_expert_num": moe_expert_num,
|
| 191 |
+
"global_bs": 0,
|
| 192 |
+
"expand_scales": expand_scales,
|
| 193 |
+
}
|
| 194 |
+
tp_recv_counts = torch.empty(1,
|
| 195 |
+
dtype=torch.int32,
|
| 196 |
+
device=hidden_states.device)
|
| 197 |
+
stage3_kwargs = {
|
| 198 |
+
"ep_send_counts": ep_recv_counts,
|
| 199 |
+
"group_ep": moe_all_to_all_group_name,
|
| 200 |
+
"ep_world_size": all_to_all_group_size,
|
| 201 |
+
"ep_rank_id": local_rank,
|
| 202 |
+
"tp_send_counts": tp_recv_counts,
|
| 203 |
+
# "group_tp": self.moe_rs_group_name,
|
| 204 |
+
"group_tp": moe_all_to_all_group_name,
|
| 205 |
+
"tp_world_size": tp_size,
|
| 206 |
+
"tp_rank_id": tp_rank,
|
| 207 |
+
}
|
| 208 |
+
kwargs_mc2.update(stage3_kwargs)
|
| 209 |
+
|
| 210 |
+
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
| 211 |
+
|
| 212 |
+
if shared_experts is None:
|
| 213 |
+
return hidden_states
|
| 214 |
+
else:
|
| 215 |
+
with npu_stream_switch("moe_secondary", 0):
|
| 216 |
+
npu_wait_tensor(shared_act[0], down_out_list)
|
| 217 |
+
shared_output, _ = shared_experts.down_proj(shared_act)
|
| 218 |
+
return hidden_states, shared_output
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# currently expert parallelism implemented with all2all
|
| 222 |
+
# is under-optimized.
|
| 223 |
+
def fused_experts_with_all2all(
|
| 224 |
+
hidden_states: torch.Tensor,
|
| 225 |
+
w1: torch.Tensor,
|
| 226 |
+
w1_scale: torch.Tensor,
|
| 227 |
+
w2: torch.Tensor,
|
| 228 |
+
w2_scale: torch.Tensor,
|
| 229 |
+
topk_weights: torch.Tensor,
|
| 230 |
+
topk_ids: torch.Tensor,
|
| 231 |
+
top_k: int,
|
| 232 |
+
expert_map: torch.Tensor = None,
|
| 233 |
+
ep_group: GroupCoordinator = None,
|
| 234 |
+
log2phy: torch.Tensor = None,
|
| 235 |
+
global_redundant_expert_num: int = 0,
|
| 236 |
+
):
|
| 237 |
+
if log2phy is not None:
|
| 238 |
+
topk_ids = log2phy[topk_ids]
|
| 239 |
+
original_shape = hidden_states.shape
|
| 240 |
+
if len(original_shape) == 3:
|
| 241 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 242 |
+
|
| 243 |
+
num_tokens, _ = hidden_states.shape
|
| 244 |
+
num_experts = w1.shape[0]
|
| 245 |
+
device = hidden_states.device
|
| 246 |
+
|
| 247 |
+
if expert_map is not None:
|
| 248 |
+
global_num_experts = len(expert_map) + global_redundant_expert_num
|
| 249 |
+
local_num_experts = global_num_experts // ep_group.world_size
|
| 250 |
+
row_idx_len = num_tokens * top_k
|
| 251 |
+
row_idx = (torch.arange(0,
|
| 252 |
+
row_idx_len,
|
| 253 |
+
dtype=torch.int32,
|
| 254 |
+
device=device).view(top_k, -1).permute(
|
| 255 |
+
1, 0).contiguous())
|
| 256 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 257 |
+
hidden_states,
|
| 258 |
+
row_idx=row_idx,
|
| 259 |
+
expert_idx=topk_ids,
|
| 260 |
+
active_num=num_tokens)
|
| 261 |
+
|
| 262 |
+
global_expert_tokens = torch.bincount(expanded_expert_idx,
|
| 263 |
+
minlength=global_num_experts)
|
| 264 |
+
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
|
| 265 |
+
-1).sum(-1)
|
| 266 |
+
|
| 267 |
+
gather_sizes = torch.empty_like(scatter_sizes)
|
| 268 |
+
dist.all_to_all_single(gather_sizes,
|
| 269 |
+
scatter_sizes,
|
| 270 |
+
group=ep_group.device_group)
|
| 271 |
+
scatter_size_list = scatter_sizes.cpu().tolist()
|
| 272 |
+
gather_size_list = gather_sizes.cpu().tolist()
|
| 273 |
+
|
| 274 |
+
expanded_expert_idx = expanded_expert_idx % local_num_experts
|
| 275 |
+
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
| 276 |
+
scatter_size_list,
|
| 277 |
+
gather_size_list)
|
| 278 |
+
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
|
| 279 |
+
scatter_size_list,
|
| 280 |
+
gather_size_list)
|
| 281 |
+
|
| 282 |
+
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
|
| 283 |
+
|
| 284 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 285 |
+
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
| 286 |
+
|
| 287 |
+
hidden_states = hidden_states[sorted_idx]
|
| 288 |
+
group_list_type = 0
|
| 289 |
+
else:
|
| 290 |
+
row_idx_len = num_tokens * top_k
|
| 291 |
+
row_idx = torch.arange(0,
|
| 292 |
+
row_idx_len,
|
| 293 |
+
dtype=torch.int32,
|
| 294 |
+
device=topk_weights.device).view(
|
| 295 |
+
top_k, -1).permute(1, 0).contiguous()
|
| 296 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 297 |
+
hidden_states,
|
| 298 |
+
row_idx=row_idx,
|
| 299 |
+
expert_idx=topk_ids,
|
| 300 |
+
active_num=num_tokens)
|
| 301 |
+
|
| 302 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 303 |
+
expanded_expert_idx, num_experts)
|
| 304 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 305 |
+
group_list_type = 0
|
| 306 |
+
|
| 307 |
+
# `hidden_states` will be disposed in the `apply_mlp` function
|
| 308 |
+
hidden_states = apply_mlp(
|
| 309 |
+
hidden_states,
|
| 310 |
+
w1,
|
| 311 |
+
w1_scale, #17
|
| 312 |
+
w2,
|
| 313 |
+
w2_scale,
|
| 314 |
+
expert_tokens, #16
|
| 315 |
+
group_list_type=group_list_type)
|
| 316 |
+
|
| 317 |
+
if expert_map is not None:
|
| 318 |
+
resorted_idx = torch.argsort(sorted_idx)
|
| 319 |
+
hidden_states = hidden_states[resorted_idx]
|
| 320 |
+
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
| 321 |
+
gather_size_list,
|
| 322 |
+
scatter_size_list)
|
| 323 |
+
|
| 324 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 325 |
+
hidden_states,
|
| 326 |
+
skip1=None,
|
| 327 |
+
skip2=None,
|
| 328 |
+
bias=None,
|
| 329 |
+
scales=topk_weights,
|
| 330 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 331 |
+
export_for_source_row=topk_ids,
|
| 332 |
+
)
|
| 333 |
+
else:
|
| 334 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 335 |
+
# implementation here when suitable operators become available.
|
| 336 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 337 |
+
hidden_states,
|
| 338 |
+
skip1=None,
|
| 339 |
+
skip2=None,
|
| 340 |
+
bias=None,
|
| 341 |
+
scales=topk_weights,
|
| 342 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 343 |
+
export_for_source_row=topk_ids,
|
| 344 |
+
)
|
| 345 |
+
if len(original_shape) == 3:
|
| 346 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 347 |
+
return final_hidden_states
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def fused_experts_with_allgather(hidden_states: torch.Tensor,
|
| 351 |
+
w1: torch.Tensor,
|
| 352 |
+
w1_scale: torch.Tensor,
|
| 353 |
+
w2: torch.Tensor,
|
| 354 |
+
w2_scale: torch.Tensor,
|
| 355 |
+
topk_weights: torch.Tensor,
|
| 356 |
+
topk_ids: torch.Tensor,
|
| 357 |
+
top_k: int,
|
| 358 |
+
expert_map: torch.Tensor = None):
|
| 359 |
+
original_shape = hidden_states.shape
|
| 360 |
+
if len(original_shape) == 3:
|
| 361 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 362 |
+
num_tokens = hidden_states.shape[0]
|
| 363 |
+
batch_size, hidden_size = hidden_states.shape
|
| 364 |
+
|
| 365 |
+
ep_group = get_ep_group().device_group
|
| 366 |
+
ep_rank = torch.distributed.get_rank(group=ep_group)
|
| 367 |
+
ep_size = torch.distributed.get_world_size(ep_group)
|
| 368 |
+
|
| 369 |
+
global_num_experts = len(expert_map)
|
| 370 |
+
local_num_experts = global_num_experts // ep_size
|
| 371 |
+
|
| 372 |
+
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
| 373 |
+
|
| 374 |
+
hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
|
| 375 |
+
hidden_states,
|
| 376 |
+
topk_ids,
|
| 377 |
+
scale=pertoken_scale,
|
| 378 |
+
offset=None,
|
| 379 |
+
active_num=num_tokens * top_k,
|
| 380 |
+
expert_num=global_num_experts,
|
| 381 |
+
expert_tokens_num_type=1,
|
| 382 |
+
expert_tokens_num_flag=True,
|
| 383 |
+
active_expert_range=[
|
| 384 |
+
ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts
|
| 385 |
+
],
|
| 386 |
+
quant_mode=-1,
|
| 387 |
+
row_idx_type=0)
|
| 388 |
+
group_list_type = 1
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 392 |
+
x=[hidden_states],
|
| 393 |
+
weight=[w1],
|
| 394 |
+
split_item=3,
|
| 395 |
+
group_list_type=group_list_type,
|
| 396 |
+
group_type=0,
|
| 397 |
+
group_list=expert_tokens,
|
| 398 |
+
output_dtype=torch.int32)[0]
|
| 399 |
+
|
| 400 |
+
# act_fn: swiglu
|
| 401 |
+
hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
|
| 402 |
+
x=hidden_states,
|
| 403 |
+
weight_scale=w1_scale.to(torch.float32),
|
| 404 |
+
activation_scale=pertoken_scale,
|
| 405 |
+
bias=None,
|
| 406 |
+
quant_scale=None,
|
| 407 |
+
quant_offset=None,
|
| 408 |
+
group_index=expert_tokens,
|
| 409 |
+
activate_left=True,
|
| 410 |
+
quant_mode=1,
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
hidden_states = torch_npu.npu_grouped_matmul(
|
| 414 |
+
x=[hidden_states],
|
| 415 |
+
weight=[w2],
|
| 416 |
+
scale=[w2_scale.to(torch.bfloat16)],
|
| 417 |
+
per_token_scale=[pertoken_scale.view(-1)],
|
| 418 |
+
split_item=3,
|
| 419 |
+
group_list_type=group_list_type,
|
| 420 |
+
group_type=0,
|
| 421 |
+
group_list=expert_tokens,
|
| 422 |
+
output_dtype=torch.bfloat16)[0]
|
| 423 |
+
|
| 424 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 425 |
+
expanded_permuted_rows=hidden_states.unsqueeze(1),
|
| 426 |
+
skip1=None,
|
| 427 |
+
skip2=None,
|
| 428 |
+
bias=None,
|
| 429 |
+
scales=topk_weights.to(torch.bfloat16),
|
| 430 |
+
expanded_src_to_dst_row=expanded_x_idx.to(torch.int32),
|
| 431 |
+
export_for_source_row=topk_ids,
|
| 432 |
+
drop_pad_mode=3
|
| 433 |
+
).to(torch.bfloat16)
|
| 434 |
+
|
| 435 |
+
if len(original_shape) == 3:
|
| 436 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 437 |
+
|
| 438 |
+
return final_hidden_states
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def fused_experts(hidden_states: torch.Tensor,
|
| 442 |
+
w1: torch.Tensor,
|
| 443 |
+
w1_scale: torch.Tensor,
|
| 444 |
+
w2: torch.Tensor,
|
| 445 |
+
w2_scale: torch.Tensor,
|
| 446 |
+
topk_weights: torch.Tensor,
|
| 447 |
+
topk_ids: torch.Tensor,
|
| 448 |
+
top_k: int,
|
| 449 |
+
expert_map: torch.Tensor = None):
|
| 450 |
+
original_shape = hidden_states.shape
|
| 451 |
+
if len(original_shape) == 3:
|
| 452 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 453 |
+
|
| 454 |
+
num_tokens, _ = hidden_states.shape
|
| 455 |
+
num_experts = w1.shape[0]
|
| 456 |
+
dtype = hidden_states.dtype
|
| 457 |
+
device = hidden_states.device
|
| 458 |
+
|
| 459 |
+
if expert_map is not None:
|
| 460 |
+
# Generate token indices and flatten
|
| 461 |
+
token_indices = (torch.arange(num_tokens,
|
| 462 |
+
device=device,
|
| 463 |
+
dtype=torch.int64).unsqueeze(1).expand(
|
| 464 |
+
-1, top_k).reshape(-1))
|
| 465 |
+
|
| 466 |
+
# Flatten token-to-expert mappings and map to local experts
|
| 467 |
+
weights_flat = topk_weights.view(-1)
|
| 468 |
+
experts_flat = topk_ids.view(-1)
|
| 469 |
+
local_experts_flat = expert_map[experts_flat]
|
| 470 |
+
|
| 471 |
+
# Filter valid token-expert pairs
|
| 472 |
+
mask = local_experts_flat != -1
|
| 473 |
+
filtered_weights = torch.where(
|
| 474 |
+
mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
|
| 475 |
+
filtered_experts = torch.where(
|
| 476 |
+
mask, local_experts_flat,
|
| 477 |
+
torch.full_like(local_experts_flat,
|
| 478 |
+
num_experts)).to(topk_ids.dtype)
|
| 479 |
+
|
| 480 |
+
# Sort by local expert IDs
|
| 481 |
+
sort_indices = torch.argsort(filtered_experts)
|
| 482 |
+
sorted_token_indices = token_indices[sort_indices]
|
| 483 |
+
sorted_weights = filtered_weights[sort_indices]
|
| 484 |
+
|
| 485 |
+
# Compute token counts with minlength of num_experts
|
| 486 |
+
# This is equivalent to but faster than:
|
| 487 |
+
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
| 488 |
+
token_counts = torch.zeros(num_experts + 1,
|
| 489 |
+
device=device,
|
| 490 |
+
dtype=torch.int64)
|
| 491 |
+
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
| 492 |
+
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
| 493 |
+
expert_tokens = token_counts[:num_experts]
|
| 494 |
+
# Rearrange hidden_states
|
| 495 |
+
hidden_states = hidden_states[sorted_token_indices]
|
| 496 |
+
group_list_type = 1
|
| 497 |
+
else:
|
| 498 |
+
row_idx_len = num_tokens * top_k
|
| 499 |
+
row_idx = torch.arange(0,
|
| 500 |
+
row_idx_len,
|
| 501 |
+
dtype=torch.int32,
|
| 502 |
+
device=topk_weights.device).view(
|
| 503 |
+
top_k, -1).permute(1, 0).contiguous()
|
| 504 |
+
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
| 505 |
+
hidden_states,
|
| 506 |
+
row_idx=row_idx,
|
| 507 |
+
expert_idx=topk_ids,
|
| 508 |
+
active_num=num_tokens)
|
| 509 |
+
|
| 510 |
+
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
| 511 |
+
expanded_expert_idx, num_experts)
|
| 512 |
+
expert_tokens = expert_tokens.to(torch.int64)
|
| 513 |
+
group_list_type = 0
|
| 514 |
+
|
| 515 |
+
# `hidden_states` will be disposed in the `apply_mlp` function
|
| 516 |
+
hidden_states = apply_mlp(hidden_states,
|
| 517 |
+
w1,
|
| 518 |
+
w1_scale,
|
| 519 |
+
w2,
|
| 520 |
+
w2_scale,
|
| 521 |
+
expert_tokens,
|
| 522 |
+
group_list_type=group_list_type)
|
| 523 |
+
|
| 524 |
+
if expert_map is not None:
|
| 525 |
+
hidden_states.mul_(sorted_weights.unsqueeze(1))
|
| 526 |
+
final_hidden_states = torch.zeros(*original_shape,
|
| 527 |
+
device=device,
|
| 528 |
+
dtype=dtype)
|
| 529 |
+
|
| 530 |
+
num_valid_tokens = mask.sum()
|
| 531 |
+
valid_token_mask = torch.arange(
|
| 532 |
+
0, sorted_token_indices.shape[0],
|
| 533 |
+
device=device).unsqueeze(1) < num_valid_tokens
|
| 534 |
+
hidden_states = hidden_states.masked_fill_(~valid_token_mask,
|
| 535 |
+
0).to(dtype)
|
| 536 |
+
final_hidden_states.index_add_(0, sorted_token_indices, hidden_states)
|
| 537 |
+
else:
|
| 538 |
+
# TODO: Reorder device memory 2 times here, replace the current
|
| 539 |
+
# implementation here when suitable operators become available.
|
| 540 |
+
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
| 541 |
+
hidden_states,
|
| 542 |
+
skip1=None,
|
| 543 |
+
skip2=None,
|
| 544 |
+
bias=None,
|
| 545 |
+
scales=topk_weights,
|
| 546 |
+
expanded_src_to_dst_row=expanded_row_idx,
|
| 547 |
+
export_for_source_row=topk_ids,
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
if len(original_shape) == 3:
|
| 551 |
+
final_hidden_states = final_hidden_states.view(original_shape)
|
| 552 |
+
return final_hidden_states
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
class AscendW8A8DynamicLinearMethod:
|
| 556 |
+
"""Linear method for Ascend W8A8_DYNAMIC.
|
| 557 |
+
"""
|
| 558 |
+
|
| 559 |
+
def __init__(self):
|
| 560 |
+
self.transpose_weight = True
|
| 561 |
+
|
| 562 |
+
@staticmethod
|
| 563 |
+
def get_weight(input_size: int, output_size: int,
|
| 564 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 565 |
+
params_dict = {
|
| 566 |
+
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
| 567 |
+
}
|
| 568 |
+
return params_dict
|
| 569 |
+
|
| 570 |
+
@staticmethod
|
| 571 |
+
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 572 |
+
return {}
|
| 573 |
+
|
| 574 |
+
@staticmethod
|
| 575 |
+
def get_perchannel_param(
|
| 576 |
+
output_size: int,
|
| 577 |
+
params_dtype: torch.dtype,
|
| 578 |
+
) -> Dict[str, Any]:
|
| 579 |
+
params_dict = {}
|
| 580 |
+
params_dict["weight_scale"] = torch.empty(output_size,
|
| 581 |
+
1,
|
| 582 |
+
dtype=params_dtype)
|
| 583 |
+
params_dict["weight_offset"] = torch.empty(output_size,
|
| 584 |
+
1,
|
| 585 |
+
dtype=params_dtype)
|
| 586 |
+
return params_dict
|
| 587 |
+
|
| 588 |
+
@staticmethod
|
| 589 |
+
def apply(
|
| 590 |
+
layer: torch.nn.Module,
|
| 591 |
+
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
| 592 |
+
bias: Optional[torch.Tensor] = None,
|
| 593 |
+
tp_rank: Optional[int] = 0,
|
| 594 |
+
) -> torch.Tensor:
|
| 595 |
+
config = getattr(layer, "_ascend_quant_config", {})
|
| 596 |
+
if not isinstance(x, tuple):
|
| 597 |
+
output_dtype = config.get("output_dtype", x.dtype)
|
| 598 |
+
quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
|
| 599 |
+
else:
|
| 600 |
+
assert "output_dtype" in config.keys(), (
|
| 601 |
+
f"DynamicLinearMethod needs explicitly specified `output_dtype`"
|
| 602 |
+
f"for pre-quantized input, got config [{config}]")
|
| 603 |
+
output_dtype = config["output_dtype"]
|
| 604 |
+
quantized_x, dynamic_scale = x
|
| 605 |
+
pertoken_scale = (dynamic_scale
|
| 606 |
+
if config.get("pertoken_scale", True) else None)
|
| 607 |
+
|
| 608 |
+
output = torch_npu.npu_quant_matmul(
|
| 609 |
+
quantized_x,
|
| 610 |
+
layer.weight,
|
| 611 |
+
layer.weight_scale,
|
| 612 |
+
pertoken_scale=pertoken_scale,
|
| 613 |
+
bias=bias,
|
| 614 |
+
output_dtype=output_dtype,
|
| 615 |
+
)
|
| 616 |
+
return ((output, dynamic_scale)
|
| 617 |
+
if config.get("return_scale", False) else output)
|
| 618 |
+
|
| 619 |
+
def process_weights_after_loading(self, layer):
|
| 620 |
+
if self.transpose_weight:
|
| 621 |
+
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
| 622 |
+
# cast quantized weight tensors in NZ format (29) for higher inference speed
|
| 623 |
+
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
| 624 |
+
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
| 625 |
+
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
| 626 |
+
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
class AscendW8A8DynamicFusedMoEMethod:
|
| 630 |
+
"""FusedMoe method for Ascend W8A8_DYNAMIC.
|
| 631 |
+
"""
|
| 632 |
+
|
| 633 |
+
def __init__(self):
|
| 634 |
+
self.transpose_weight = True
|
| 635 |
+
|
| 636 |
+
self.ep_group = get_ep_group()
|
| 637 |
+
|
| 638 |
+
ascend_config = get_ascend_config()
|
| 639 |
+
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
| 640 |
+
|
| 641 |
+
try:
|
| 642 |
+
device_group = self.ep_group.device_group
|
| 643 |
+
# TODO: Try local_rank = ep_group.rank_in_group
|
| 644 |
+
local_rank = torch.distributed.get_rank(group=device_group)
|
| 645 |
+
backend = device_group._get_backend(torch.device("npu"))
|
| 646 |
+
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
| 647 |
+
local_rank)
|
| 648 |
+
except AttributeError:
|
| 649 |
+
self.moe_all_to_all_group_name = ""
|
| 650 |
+
|
| 651 |
+
@staticmethod
|
| 652 |
+
def get_weight(num_experts: int, intermediate_size_per_partition: int,
|
| 653 |
+
hidden_sizes: int,
|
| 654 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 655 |
+
param_dict = {}
|
| 656 |
+
param_dict["w13_weight"] = torch.empty(num_experts,
|
| 657 |
+
2 *
|
| 658 |
+
intermediate_size_per_partition,
|
| 659 |
+
hidden_sizes,
|
| 660 |
+
dtype=torch.int8)
|
| 661 |
+
param_dict["w2_weight"] = torch.empty(num_experts,
|
| 662 |
+
hidden_sizes,
|
| 663 |
+
intermediate_size_per_partition,
|
| 664 |
+
dtype=torch.int8)
|
| 665 |
+
return param_dict
|
| 666 |
+
|
| 667 |
+
@staticmethod
|
| 668 |
+
def get_dynamic_quant_param(num_experts: int,
|
| 669 |
+
intermediate_size_per_partition: int,
|
| 670 |
+
hidden_sizes: int,
|
| 671 |
+
params_dtype: torch.dtype) -> Dict[str, Any]:
|
| 672 |
+
param_dict = {}
|
| 673 |
+
param_dict["w13_weight_scale"] = torch.empty(
|
| 674 |
+
num_experts,
|
| 675 |
+
2 * intermediate_size_per_partition,
|
| 676 |
+
1,
|
| 677 |
+
dtype=params_dtype)
|
| 678 |
+
param_dict["w13_weight_offset"] = torch.empty(
|
| 679 |
+
num_experts,
|
| 680 |
+
2 * intermediate_size_per_partition,
|
| 681 |
+
1,
|
| 682 |
+
dtype=params_dtype)
|
| 683 |
+
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
| 684 |
+
hidden_sizes,
|
| 685 |
+
1,
|
| 686 |
+
dtype=params_dtype)
|
| 687 |
+
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
| 688 |
+
hidden_sizes,
|
| 689 |
+
1,
|
| 690 |
+
dtype=params_dtype)
|
| 691 |
+
return param_dict
|
| 692 |
+
|
| 693 |
+
def apply(
|
| 694 |
+
self,
|
| 695 |
+
layer: torch.nn.Module,
|
| 696 |
+
x: torch.Tensor,
|
| 697 |
+
router_logits: torch.Tensor,
|
| 698 |
+
top_k: int,
|
| 699 |
+
renormalize: bool,
|
| 700 |
+
use_grouped_topk: bool = False,
|
| 701 |
+
global_num_experts: int = -1,
|
| 702 |
+
expert_map: Optional[torch.Tensor] = None,
|
| 703 |
+
topk_group: Optional[int] = None,
|
| 704 |
+
num_expert_group: Optional[int] = None,
|
| 705 |
+
custom_routing_function: Optional[Callable] = None,
|
| 706 |
+
scoring_func: str = "softmax",
|
| 707 |
+
e_score_correction_bias: Optional[torch.Tensor] = None,
|
| 708 |
+
is_prefill: bool = True,
|
| 709 |
+
enable_force_load_balance: bool = True,
|
| 710 |
+
log2phy: torch.Tensor = None,
|
| 711 |
+
global_redundant_expert_num: int = 0,
|
| 712 |
+
shared_experts: Optional[Any] = None,
|
| 713 |
+
**kwargs,
|
| 714 |
+
) -> torch.Tensor:
|
| 715 |
+
assert router_logits.shape[
|
| 716 |
+
1] == global_num_experts, "Number of global experts mismatch"
|
| 717 |
+
|
| 718 |
+
is_deepseek_v3_r1 = global_num_experts == 256
|
| 719 |
+
use_grouped_topk = (topk_group > 1 or num_expert_group > 1)
|
| 720 |
+
|
| 721 |
+
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
| 722 |
+
if use_grouped_topk and is_deepseek_v3_r1:
|
| 723 |
+
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
| 724 |
+
router_logits,
|
| 725 |
+
k=top_k, # topk当前写8
|
| 726 |
+
bias=e_score_correction_bias,
|
| 727 |
+
k_group=topk_group, # fix: 4
|
| 728 |
+
group_count=num_expert_group, # fix 8
|
| 729 |
+
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
| 730 |
+
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
| 731 |
+
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
| 732 |
+
# out_flag=False, # todo new api; 第三个输出是否输出
|
| 733 |
+
# y2_flag=False, # old api; 第三个输出是否输出
|
| 734 |
+
routed_scaling_factor=1,
|
| 735 |
+
eps=float(1e-20))
|
| 736 |
+
else:
|
| 737 |
+
topk_weights, topk_ids = select_experts(
|
| 738 |
+
hidden_states=x,
|
| 739 |
+
router_logits=router_logits,
|
| 740 |
+
top_k=top_k,
|
| 741 |
+
use_grouped_topk=use_grouped_topk,
|
| 742 |
+
renormalize=renormalize,
|
| 743 |
+
topk_group=topk_group,
|
| 744 |
+
num_expert_group=num_expert_group,
|
| 745 |
+
custom_routing_function=custom_routing_function,
|
| 746 |
+
scoring_func=scoring_func,
|
| 747 |
+
e_score_correction_bias=e_score_correction_bias,
|
| 748 |
+
)
|
| 749 |
+
|
| 750 |
+
# this is a naive implementation for experts load balance so as
|
| 751 |
+
# to avoid accumulating too much tokens on a single rank.
|
| 752 |
+
# currently it is only activated when doing profile runs.
|
| 753 |
+
if enable_force_load_balance:
|
| 754 |
+
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
| 755 |
+
|
| 756 |
+
topk_weights = topk_weights.to(x.dtype)
|
| 757 |
+
|
| 758 |
+
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
|
| 759 |
+
is_prefill, is_deepseek_v3_r1)
|
| 760 |
+
if fused_moe_state == FusedMoEState.AllGatherEP:
|
| 761 |
+
return fused_experts_with_allgather(
|
| 762 |
+
hidden_states=x,
|
| 763 |
+
w1=layer.w13_weight,
|
| 764 |
+
w1_scale=layer.w13_weight_scale,
|
| 765 |
+
w2=layer.w2_weight,
|
| 766 |
+
w2_scale=layer.w2_weight_scale,
|
| 767 |
+
topk_weights=topk_weights,
|
| 768 |
+
topk_ids=topk_ids,
|
| 769 |
+
top_k=top_k,
|
| 770 |
+
expert_map=expert_map)
|
| 771 |
+
elif fused_moe_state == FusedMoEState.MC2:
|
| 772 |
+
return fused_experts_with_mc2(
|
| 773 |
+
hidden_states=x,
|
| 774 |
+
w1=layer.w13_weight,
|
| 775 |
+
w2=layer.w2_weight,
|
| 776 |
+
w1_scale=layer.w13_weight_scale,
|
| 777 |
+
w2_scale=layer.w2_weight_scale,
|
| 778 |
+
topk_weights=topk_weights,
|
| 779 |
+
topk_ids=topk_ids,
|
| 780 |
+
top_k=top_k,
|
| 781 |
+
expert_map=expert_map,
|
| 782 |
+
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
| 783 |
+
log2phy=log2phy,
|
| 784 |
+
global_redundant_expert_num=global_redundant_expert_num,
|
| 785 |
+
shared_experts=shared_experts)
|
| 786 |
+
elif fused_moe_state in [
|
| 787 |
+
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
| 788 |
+
]:
|
| 789 |
+
return fused_experts(hidden_states=x,
|
| 790 |
+
w1=layer.w13_weight,
|
| 791 |
+
w1_scale=layer.w13_weight_scale,
|
| 792 |
+
w2=layer.w2_weight,
|
| 793 |
+
w2_scale=layer.w2_weight_scale,
|
| 794 |
+
topk_weights=topk_weights,
|
| 795 |
+
topk_ids=topk_ids,
|
| 796 |
+
top_k=top_k,
|
| 797 |
+
expert_map=expert_map)
|
| 798 |
+
else:
|
| 799 |
+
# The current implementation of deepseek moe splits hidden_states
|
| 800 |
+
# according to tp_size before they are feed into fused_moe module.
|
| 801 |
+
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
| 802 |
+
# dispatch/combine tokens.
|
| 803 |
+
return fused_experts_with_all2all(
|
| 804 |
+
hidden_states=x,
|
| 805 |
+
w1=layer.w13_weight,
|
| 806 |
+
w1_scale=layer.w13_weight_scale,
|
| 807 |
+
w2=layer.w2_weight,
|
| 808 |
+
w2_scale=layer.w2_weight_scale,
|
| 809 |
+
topk_weights=topk_weights,
|
| 810 |
+
topk_ids=topk_ids,
|
| 811 |
+
top_k=top_k,
|
| 812 |
+
expert_map=expert_map,
|
| 813 |
+
ep_group=self.ep_group,
|
| 814 |
+
log2phy=log2phy,
|
| 815 |
+
global_redundant_expert_num=global_redundant_expert_num,
|
| 816 |
+
)
|
| 817 |
+
|
| 818 |
+
def process_weights_after_loading(self, layer):
|
| 819 |
+
if self.transpose_weight:
|
| 820 |
+
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
| 821 |
+
1, 2).contiguous()
|
| 822 |
+
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
| 823 |
+
1, 2).contiguous()
|
| 824 |
+
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
| 825 |
+
layer.w13_weight_scale.data.shape[0], -1)
|
| 826 |
+
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
| 827 |
+
layer.w13_weight_offset.data.shape[0], -1)
|
| 828 |
+
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
| 829 |
+
layer.w2_weight_scale.data.shape[0], -1)
|
| 830 |
+
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
| 831 |
+
layer.w2_weight_offset.data.shape[0], -1)
|
inference/vllm_ascend/utils.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# Copyright 2023 The vLLM team.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
# This file is a part of the vllm-ascend project.
|
| 17 |
+
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
| 18 |
+
#
|
| 19 |
+
|
| 20 |
+
import atexit
|
| 21 |
+
import fcntl
|
| 22 |
+
import math
|
| 23 |
+
import os
|
| 24 |
+
import shutil
|
| 25 |
+
from contextlib import contextmanager, nullcontext
|
| 26 |
+
from enum import Enum
|
| 27 |
+
from threading import Lock
|
| 28 |
+
from typing import TYPE_CHECKING, List, Tuple
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch_npu # noqa: F401 # noqa: F401
|
| 32 |
+
from packaging.version import InvalidVersion, Version
|
| 33 |
+
from torch_npu.npu.streams import Event
|
| 34 |
+
from vllm.logger import logger
|
| 35 |
+
|
| 36 |
+
import vllm_ascend.envs as envs
|
| 37 |
+
from vllm_ascend.ascend_config import get_ascend_config
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
# Recent release of torchair has moved these ops to `.scope`.
|
| 41 |
+
from torchair.scope import npu_stream_switch as _npu_stream_switch
|
| 42 |
+
from torchair.scope import npu_wait_tensor as _npu_wait_tensor
|
| 43 |
+
except ImportError:
|
| 44 |
+
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
|
| 45 |
+
from torchair.ops import npu_wait_tensor as _npu_wait_tensor
|
| 46 |
+
|
| 47 |
+
if TYPE_CHECKING:
|
| 48 |
+
from vllm.config import VllmConfig
|
| 49 |
+
else:
|
| 50 |
+
VllmConfig = None
|
| 51 |
+
|
| 52 |
+
# NOTE: Currently, we can only capture 1920 graphs at most,
|
| 53 |
+
# due to the limitation of ACL graph. This number is bounded by
|
| 54 |
+
# the number of streams, which is 2048, we save 128 streams
|
| 55 |
+
# as a buffer.
|
| 56 |
+
# Maximum number of graphs that can be captured by ACL Graph
|
| 57 |
+
MAX_CAPTURE_SIZE = 1920
|
| 58 |
+
|
| 59 |
+
ASCEND_QUATIZATION_METHOD = "ascend"
|
| 60 |
+
SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
|
| 61 |
+
|
| 62 |
+
ACL_FORMAT_FRACTAL_ND = 2
|
| 63 |
+
ACL_FORMAT_FRACTAL_NZ = 29
|
| 64 |
+
|
| 65 |
+
_CUSTOM_OP_ENABLED = None
|
| 66 |
+
_IS_310P = None
|
| 67 |
+
_SLEEP_MODE_ENABLED = None
|
| 68 |
+
_CURRENT_STREAM = None
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def is_310p():
|
| 72 |
+
global _IS_310P
|
| 73 |
+
if _IS_310P is None:
|
| 74 |
+
from vllm_ascend import _build_info # type: ignore
|
| 75 |
+
_IS_310P = _build_info.__soc_version__.lower().startswith("ascend310p")
|
| 76 |
+
return _IS_310P
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def sleep_mode_enabled():
|
| 80 |
+
global _SLEEP_MODE_ENABLED
|
| 81 |
+
if _SLEEP_MODE_ENABLED is None:
|
| 82 |
+
from vllm_ascend import _build_info # type: ignore
|
| 83 |
+
_SLEEP_MODE_ENABLED = _build_info.__sleep_mode_enabled__
|
| 84 |
+
return _SLEEP_MODE_ENABLED
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _round_up(x: int, align: int):
|
| 88 |
+
# round up x to align, for example, if align is 16, x will be rounded up to 16, 32, 48, etc.
|
| 89 |
+
# input: 15, 16 -> output: 16
|
| 90 |
+
# input: 17, 16 -> output: 32
|
| 91 |
+
# input: 30, 16 -> output: 32
|
| 92 |
+
# input: 33, 16 -> output: 48
|
| 93 |
+
# ...
|
| 94 |
+
return (x + align - 1) // align * align
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _custom_pad(x, pad_dims):
|
| 98 |
+
# pad the input tensor to the shape of pad_dims
|
| 99 |
+
# input: (13, 30), pad_dims: [0, 2, 0, 3]
|
| 100 |
+
# output: (16, 32)
|
| 101 |
+
return torch.nn.functional.pad(x, pad_dims)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _custom_reshape(x, target_shape):
|
| 105 |
+
# reshape the input tensor to the shape of target_shape
|
| 106 |
+
# input: (16, 32), target_shape: [1, 16, 2, 16]
|
| 107 |
+
# output: (1, 16, 2, 16)
|
| 108 |
+
return x.reshape(target_shape)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _custom_transpose(x, dim1, dim2):
|
| 112 |
+
# transpose the input tensor
|
| 113 |
+
# input: (1, 16, 2, 16), dim1: 1, dim2: 2
|
| 114 |
+
# output: (1, 2, 16, 16)
|
| 115 |
+
return x.transpose(dim1, dim2)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
|
| 119 |
+
# in_tensor: (13, 30)
|
| 120 |
+
aux_dims = [1, 0, 0, 16]
|
| 121 |
+
# aux_dims[1]: 16
|
| 122 |
+
aux_dims[1] = _round_up(in_tensor.size(0), 16)
|
| 123 |
+
# aux_dims[2]: 2
|
| 124 |
+
aux_dims[2] = _round_up(in_tensor.size(1), 16) // 16
|
| 125 |
+
|
| 126 |
+
# after: aux_dims: [1, 16, 2, 16]
|
| 127 |
+
|
| 128 |
+
pad_dims = [0, 0, 0, 0]
|
| 129 |
+
# pad_dims[1]: 2
|
| 130 |
+
pad_dims[1] = _round_up(in_tensor.size(1), 16) - in_tensor.size(1)
|
| 131 |
+
# pad_dims[3]: 3
|
| 132 |
+
pad_dims[3] = _round_up(in_tensor.size(0), 16) - in_tensor.size(0)
|
| 133 |
+
|
| 134 |
+
# after: pad_dims: [0, 2, 0, 3]
|
| 135 |
+
|
| 136 |
+
# return: (1, 2, 16, 16)
|
| 137 |
+
return _custom_transpose(
|
| 138 |
+
_custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1,
|
| 139 |
+
2).contiguous()
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
|
| 143 |
+
num_tokens = mask_tensor.shape[0]
|
| 144 |
+
max_seq_len = mask_tensor.shape[1]
|
| 145 |
+
|
| 146 |
+
tokens_pad = (num_tokens + 15) // 16 * 16
|
| 147 |
+
max_seq_len_pad = (max_seq_len + 15) // 16 * 16
|
| 148 |
+
|
| 149 |
+
mask_tensor_pad = \
|
| 150 |
+
torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
|
| 151 |
+
mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
|
| 152 |
+
mask = mask_tensor_pad.reshape(
|
| 153 |
+
(1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
|
| 154 |
+
return mask
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def aligned_16(tensor: torch.Tensor):
|
| 158 |
+
"""Aligned tensor for 310P"""
|
| 159 |
+
|
| 160 |
+
# Get the size of the current 0th dimension
|
| 161 |
+
n = tensor.size(0)
|
| 162 |
+
|
| 163 |
+
# Calculate the aligned size
|
| 164 |
+
n_aligned = ((n + 15) // 16) * 16
|
| 165 |
+
|
| 166 |
+
# If already aligned, return the original tensor
|
| 167 |
+
if n == n_aligned:
|
| 168 |
+
return tensor
|
| 169 |
+
|
| 170 |
+
# Create a new tensor with shape (n_aligned, H, W) and fill it with zeros
|
| 171 |
+
new_tensor = torch.zeros(n_aligned,
|
| 172 |
+
*tensor.shape[1:],
|
| 173 |
+
dtype=tensor.dtype,
|
| 174 |
+
device=tensor.device)
|
| 175 |
+
|
| 176 |
+
# Copy the original tensor to the first N positions of the new tensor
|
| 177 |
+
new_tensor[:n] = tensor
|
| 178 |
+
|
| 179 |
+
return new_tensor
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def maybe_converting_weight_acl_format(model, format=ACL_FORMAT_FRACTAL_NZ):
|
| 183 |
+
# currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ
|
| 184 |
+
# in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ
|
| 185 |
+
# is much more preferred than ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
|
| 186 |
+
# conversion when using torchair graph mode on 300I Duo platform.
|
| 187 |
+
# TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant
|
| 188 |
+
# accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode.
|
| 189 |
+
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
| 190 |
+
|
| 191 |
+
use_torchair = get_ascend_config().torchair_graph_config.enabled
|
| 192 |
+
if not is_310p() or not use_torchair:
|
| 193 |
+
return
|
| 194 |
+
for module in model.modules():
|
| 195 |
+
if isinstance(module, FusedMoE):
|
| 196 |
+
if torch_npu.get_npu_format(module.w13_weight.data) == format:
|
| 197 |
+
return
|
| 198 |
+
module.w13_weight.data = torch_npu.npu_format_cast(
|
| 199 |
+
module.w13_weight.data, format)
|
| 200 |
+
module.w2_weight.data = torch_npu.npu_format_cast(
|
| 201 |
+
module.w2_weight.data, format)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def try_register_lib(lib_name: str, lib_info: str = ""):
|
| 205 |
+
import importlib
|
| 206 |
+
import importlib.util
|
| 207 |
+
try:
|
| 208 |
+
module_spec = importlib.util.find_spec(lib_name)
|
| 209 |
+
if module_spec is not None:
|
| 210 |
+
importlib.import_module(lib_name)
|
| 211 |
+
if lib_info:
|
| 212 |
+
logger.info(lib_info)
|
| 213 |
+
except Exception:
|
| 214 |
+
pass
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def enable_custom_op():
|
| 218 |
+
"""
|
| 219 |
+
Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component.
|
| 220 |
+
Ensure that ASCEND_RT_VISIBLE_DEVICES can be dynamically modified before torch.npu.set_device().
|
| 221 |
+
"""
|
| 222 |
+
global _CUSTOM_OP_ENABLED
|
| 223 |
+
if _CUSTOM_OP_ENABLED is not None:
|
| 224 |
+
return _CUSTOM_OP_ENABLED
|
| 225 |
+
try:
|
| 226 |
+
# register custom ops into torch_library here
|
| 227 |
+
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
|
| 228 |
+
_CUSTOM_OP_ENABLED = True
|
| 229 |
+
except ImportError:
|
| 230 |
+
_CUSTOM_OP_ENABLED = False
|
| 231 |
+
logger.warning(
|
| 232 |
+
"Warning: Failed to register custom ops, all custom ops will be disabled"
|
| 233 |
+
)
|
| 234 |
+
return _CUSTOM_OP_ENABLED
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def find_hccl_library() -> str:
|
| 238 |
+
"""
|
| 239 |
+
We either use the library file specified by the `HCCL_SO_PATH`
|
| 240 |
+
environment variable, or we find the library file brought by PyTorch.
|
| 241 |
+
After importing `torch`, `libhccl.so` can be
|
| 242 |
+
found by `ctypes` automatically.
|
| 243 |
+
"""
|
| 244 |
+
so_file = envs.HCCL_SO_PATH
|
| 245 |
+
|
| 246 |
+
# manually load the hccl library
|
| 247 |
+
if so_file:
|
| 248 |
+
logger.info("Found hccl from environment variable HCCL_SO_PATH=%s",
|
| 249 |
+
so_file)
|
| 250 |
+
else:
|
| 251 |
+
if torch.version.cann is not None:
|
| 252 |
+
so_file = "libhccl.so"
|
| 253 |
+
else:
|
| 254 |
+
raise ValueError("HCCL only supports Ascend NPU backends.")
|
| 255 |
+
logger.info("Found hccl from library %s", so_file)
|
| 256 |
+
return so_file
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def current_stream() -> torch.npu.Stream:
|
| 260 |
+
"""
|
| 261 |
+
replace `torch.npu.current_stream()` with `vllm.utils.current_stream()`.
|
| 262 |
+
it turns out that `torch.npu.current_stream()` is quite expensive,
|
| 263 |
+
as it will construct a new stream object at each call.
|
| 264 |
+
here we patch `torch.npu.set_stream` to keep track of the current stream
|
| 265 |
+
directly, so that we can avoid calling `torch.npu.current_stream()`.
|
| 266 |
+
|
| 267 |
+
"""
|
| 268 |
+
global _CURRENT_STREAM
|
| 269 |
+
if _CURRENT_STREAM is None:
|
| 270 |
+
# when this function is called before any stream is set,
|
| 271 |
+
# we return the default stream.
|
| 272 |
+
_CURRENT_STREAM = torch.npu.current_stream()
|
| 273 |
+
return _CURRENT_STREAM
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def adapt_patch(is_global_patch: bool = False):
|
| 277 |
+
if is_global_patch:
|
| 278 |
+
from vllm_ascend.patch import platform # noqa: F401
|
| 279 |
+
else:
|
| 280 |
+
from vllm_ascend.patch import worker # noqa: F401
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def vllm_version_is(target_vllm_version: str):
|
| 284 |
+
if envs.VLLM_VERSION is not None:
|
| 285 |
+
vllm_version = envs.VLLM_VERSION
|
| 286 |
+
else:
|
| 287 |
+
import vllm
|
| 288 |
+
vllm_version = vllm.__version__
|
| 289 |
+
try:
|
| 290 |
+
return Version(vllm_version) == Version(target_vllm_version)
|
| 291 |
+
except InvalidVersion:
|
| 292 |
+
raise ValueError(
|
| 293 |
+
f"Invalid vllm version {vllm_version} found. A dev version of vllm "
|
| 294 |
+
"is installed probably. Set the environment variable VLLM_VERSION "
|
| 295 |
+
"to control it by hand. And please make sure the value follows the "
|
| 296 |
+
"format of x.y.z.")
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
| 300 |
+
"""Update ACL graph capture sizes based on hardware limitations"""
|
| 301 |
+
# Store original configuration and temporarily clear it
|
| 302 |
+
compilation_config = vllm_config.compilation_config
|
| 303 |
+
original_sizes, compilation_config.cudagraph_capture_sizes = \
|
| 304 |
+
compilation_config.cudagraph_capture_sizes, None
|
| 305 |
+
|
| 306 |
+
# Calculate parallel configuration factor
|
| 307 |
+
num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
| 308 |
+
parallel_config = vllm_config.parallel_config
|
| 309 |
+
|
| 310 |
+
# TODO: Find out whether we need to take into account the pp_size
|
| 311 |
+
parallel_factor = 1 + sum(size > 1 for size in [
|
| 312 |
+
parallel_config.data_parallel_size_local,
|
| 313 |
+
parallel_config.tensor_parallel_size,
|
| 314 |
+
parallel_config.expert_parallel_size,
|
| 315 |
+
parallel_config.expert_tensor_parallel_size,
|
| 316 |
+
])
|
| 317 |
+
|
| 318 |
+
# Calculate maximum supported batch sizes considering model architecture
|
| 319 |
+
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
|
| 320 |
+
(num_hidden_layers + 1) / parallel_factor)
|
| 321 |
+
logger.info("Calculated maximum supported batch sizes for ACL graph: %s",
|
| 322 |
+
max_num_batch_sizes)
|
| 323 |
+
|
| 324 |
+
# If original sizes exceed maximum, sample a representative subset
|
| 325 |
+
if max_num_batch_sizes < len(original_sizes):
|
| 326 |
+
# Sample uniformly from original sizes
|
| 327 |
+
step = (len(original_sizes) - 1) / (max_num_batch_sizes - 1)
|
| 328 |
+
indices = [round(i * step) for i in range(max_num_batch_sizes)]
|
| 329 |
+
|
| 330 |
+
# Ensure first and last elements are preserved
|
| 331 |
+
indices[0], indices[-1] = 0, len(original_sizes) - 1
|
| 332 |
+
|
| 333 |
+
sampled_sizes = [original_sizes[i] for i in indices]
|
| 334 |
+
compilation_config.init_with_cudagraph_sizes(sampled_sizes)
|
| 335 |
+
|
| 336 |
+
logger.info(
|
| 337 |
+
"Adjusted ACL graph batch sizes for %s model (layers: %d): %d → %d sizes",
|
| 338 |
+
vllm_config.model_config.architectures[0],
|
| 339 |
+
num_hidden_layers,
|
| 340 |
+
len(original_sizes),
|
| 341 |
+
len(compilation_config.
|
| 342 |
+
cudagraph_capture_sizes # type: ignore[arg-type]
|
| 343 |
+
))
|
| 344 |
+
else:
|
| 345 |
+
# No adjustment needed
|
| 346 |
+
compilation_config.cudagraph_capture_sizes = original_sizes
|
| 347 |
+
logger.info(
|
| 348 |
+
"No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
|
| 349 |
+
vllm_config.model_config.architectures[0], num_hidden_layers,
|
| 350 |
+
len(original_sizes))
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
# TODO(wxy): Move to ops module
|
| 354 |
+
def dispose_tensor(x: torch.Tensor):
|
| 355 |
+
x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
class ProfileExecuteDuration:
|
| 359 |
+
_instance = None
|
| 360 |
+
_observations: List[Tuple[str, Event, Event]] = []
|
| 361 |
+
_lock = Lock()
|
| 362 |
+
|
| 363 |
+
def __new__(cls):
|
| 364 |
+
with cls._lock:
|
| 365 |
+
if cls._instance is None:
|
| 366 |
+
cls._instance = super().__new__(cls)
|
| 367 |
+
atexit.register(cls._instance.destroy)
|
| 368 |
+
return cls._instance
|
| 369 |
+
|
| 370 |
+
def destroy(self):
|
| 371 |
+
with self._lock:
|
| 372 |
+
self._observations.clear()
|
| 373 |
+
|
| 374 |
+
@contextmanager
|
| 375 |
+
def capture_async(self, duration_tag: str):
|
| 376 |
+
if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
|
| 377 |
+
yield
|
| 378 |
+
return
|
| 379 |
+
|
| 380 |
+
observe_start = Event(enable_timing=True)
|
| 381 |
+
observe_start.record()
|
| 382 |
+
try:
|
| 383 |
+
yield
|
| 384 |
+
finally:
|
| 385 |
+
observe_end = Event(enable_timing=True)
|
| 386 |
+
observe_end.record()
|
| 387 |
+
with self._lock:
|
| 388 |
+
self._observations.append(
|
| 389 |
+
(duration_tag, observe_start, observe_end))
|
| 390 |
+
|
| 391 |
+
def pop_captured_sync(self) -> dict:
|
| 392 |
+
"""Pop and synchronize all events in the observation list"""
|
| 393 |
+
durations: dict[str, float] = {}
|
| 394 |
+
if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
|
| 395 |
+
return durations
|
| 396 |
+
|
| 397 |
+
while self._observations:
|
| 398 |
+
with self._lock:
|
| 399 |
+
tag, observe_start, observe_end = self._observations.pop()
|
| 400 |
+
observe_end.synchronize()
|
| 401 |
+
durations[tag] = observe_start.elapsed_time(observe_end)
|
| 402 |
+
|
| 403 |
+
return durations
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
# TODO(wxy): Move to ops module
|
| 407 |
+
def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
|
| 408 |
+
return _npu_stream_switch(tag, priority) if enabled else nullcontext()
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
# TODO(wxy): Move to ops module
|
| 412 |
+
def npu_wait_tensor(self: torch.Tensor,
|
| 413 |
+
dependency: torch.Tensor,
|
| 414 |
+
*,
|
| 415 |
+
enabled: bool = True):
|
| 416 |
+
return _npu_wait_tensor(self, dependency) if enabled else self
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
# TODO(wxy): Move to ops module
|
| 420 |
+
def npu_prefetch(input: torch.Tensor,
|
| 421 |
+
dependency: torch.Tensor,
|
| 422 |
+
max_size: int = 0,
|
| 423 |
+
*,
|
| 424 |
+
enabled: bool = True):
|
| 425 |
+
if not enabled:
|
| 426 |
+
return
|
| 427 |
+
input_size = input.element_size() * input.numel()
|
| 428 |
+
if max_size <= 0 or max_size > input_size:
|
| 429 |
+
max_size = input_size
|
| 430 |
+
torch_npu.npu_prefetch(input, dependency, max_size)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
# TODO(zzzzwwjj): move this into forward_context
|
| 434 |
+
class FusedMoEState(Enum):
|
| 435 |
+
AllGather = 0
|
| 436 |
+
All2All = 1
|
| 437 |
+
MC2 = 2
|
| 438 |
+
AllGatherEP = 3
|
| 439 |
+
NaiveMulticast = 4
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
# TODO(ttanzhiqiang): rm_router_logits
|
| 443 |
+
# dp>1 will trigger
|
| 444 |
+
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
|
| 445 |
+
def get_rm_router_logits_state(ep_size: int, dp_size: int,
|
| 446 |
+
is_deepseek_v3_r1: bool):
|
| 447 |
+
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
| 448 |
+
# only supports deepseek v3/r1
|
| 449 |
+
if dp_size > 1:
|
| 450 |
+
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
| 451 |
+
and is_deepseek_v3_r1):
|
| 452 |
+
return True
|
| 453 |
+
elif ep_size == 1 and is_deepseek_v3_r1:
|
| 454 |
+
return True
|
| 455 |
+
return False
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
# TODO(ttanzhiqiang): all_reduce merge
|
| 459 |
+
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
| 460 |
+
# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
|
| 461 |
+
def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
|
| 462 |
+
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
| 463 |
+
# only supports deepseek v3/r1
|
| 464 |
+
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
| 465 |
+
and is_deepseek_v3_r1):
|
| 466 |
+
return True
|
| 467 |
+
elif ep_size == 1 and is_deepseek_v3_r1:
|
| 468 |
+
return True
|
| 469 |
+
return False
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
# TODO(zzzzwwjj): add soc_version to choose branch
|
| 473 |
+
def get_fused_moe_state(ep_size: int, with_prefill: bool,
|
| 474 |
+
is_deepseek_v3_r1: bool):
|
| 475 |
+
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
| 476 |
+
# only supports deepseek v3/r1
|
| 477 |
+
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
| 478 |
+
and is_deepseek_v3_r1 and not with_prefill):
|
| 479 |
+
return FusedMoEState.AllGatherEP
|
| 480 |
+
elif ep_size == 1:
|
| 481 |
+
if with_prefill:
|
| 482 |
+
return FusedMoEState.NaiveMulticast
|
| 483 |
+
else:
|
| 484 |
+
return FusedMoEState.AllGather
|
| 485 |
+
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
|
| 486 |
+
elif ep_size < 16 or with_prefill:
|
| 487 |
+
return FusedMoEState.All2All
|
| 488 |
+
else:
|
| 489 |
+
return FusedMoEState.MC2
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
|
| 493 |
+
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
|
| 494 |
+
TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
|
| 495 |
+
TORCHAIR_CACHE_DIR = os.getenv(
|
| 496 |
+
'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME))
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def get_torchair_current_work_dir(file_name=None):
|
| 500 |
+
if file_name is None:
|
| 501 |
+
return TORCHAIR_CACHE_DIR
|
| 502 |
+
return os.path.join(TORCHAIR_CACHE_DIR, file_name)
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def check_torchair_cache_exist():
|
| 506 |
+
res = False
|
| 507 |
+
torch_air_abs_path = get_torchair_current_work_dir()
|
| 508 |
+
if os.path.exists(torch_air_abs_path):
|
| 509 |
+
file_list = os.listdir(torch_air_abs_path)
|
| 510 |
+
if len(file_list) != 0:
|
| 511 |
+
res = True
|
| 512 |
+
return res
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def check_kv_cache_bytes_cache_exist():
|
| 516 |
+
res = False
|
| 517 |
+
kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
|
| 518 |
+
KV_CACHE_BYTES_CACHE_PATH_NAME)
|
| 519 |
+
if os.path.exists(kv_cache_bytes_cache_abs_path):
|
| 520 |
+
file_list = os.listdir(kv_cache_bytes_cache_abs_path)
|
| 521 |
+
if len(file_list) != 0:
|
| 522 |
+
res = True
|
| 523 |
+
return res
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def read_kv_cache_bytes_from_file(rank) -> int:
|
| 527 |
+
kv_cache_bytes = -1
|
| 528 |
+
kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
|
| 529 |
+
KV_CACHE_BYTES_CACHE_PATH_NAME)
|
| 530 |
+
kv_cache_bytes_file = os.path.join(
|
| 531 |
+
kv_cache_bytes_cache_abs_path,
|
| 532 |
+
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
|
| 533 |
+
with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
|
| 534 |
+
with file_lock(f, fcntl.LOCK_SH):
|
| 535 |
+
kv_cache_bytes = int(f.readline())
|
| 536 |
+
return kv_cache_bytes
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
@contextmanager
|
| 540 |
+
def file_lock(file_descriptor, lock_type):
|
| 541 |
+
fcntl.flock(file_descriptor, lock_type)
|
| 542 |
+
try:
|
| 543 |
+
yield
|
| 544 |
+
finally:
|
| 545 |
+
fcntl.flock(file_descriptor, fcntl.LOCK_UN)
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
|
| 549 |
+
kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
|
| 550 |
+
KV_CACHE_BYTES_CACHE_PATH_NAME)
|
| 551 |
+
os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
|
| 552 |
+
kv_cache_bytes_file = os.path.join(
|
| 553 |
+
kv_cache_bytes_cache_abs_path,
|
| 554 |
+
f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
|
| 555 |
+
with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
|
| 556 |
+
with file_lock(f, fcntl.LOCK_EX):
|
| 557 |
+
f.write(f"{kv_cache_bytes}")
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def delete_torchair_cache_file():
|
| 561 |
+
torch_air_abs_path = get_torchair_current_work_dir()
|
| 562 |
+
if os.path.exists(torch_air_abs_path):
|
| 563 |
+
shutil.rmtree(torch_air_abs_path)
|
inference/vllm_ascend/worker/model_runner_v1.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
inference/vllm_ascend/worker/npu_input_batch.py
ADDED
|
@@ -0,0 +1,796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
| 3 |
+
# Copyright 2023 The vLLM team.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
# This file is a part of the vllm-ascend project.
|
| 17 |
+
# Adapted from vllm-project/vllm/vllm/worker/gpu_input_batch.py
|
| 18 |
+
#
|
| 19 |
+
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Optional, cast, Union
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
from vllm.lora.request import LoRARequest
|
| 26 |
+
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
| 27 |
+
from vllm.pooling_params import PoolingParams
|
| 28 |
+
from vllm.sampling_params import SamplingParams, SamplingType
|
| 29 |
+
from vllm.utils import swap_dict_values
|
| 30 |
+
from vllm.v1.outputs import LogprobsTensors
|
| 31 |
+
from vllm.v1.sample.logits_processor import init_builtin_logitsprocs
|
| 32 |
+
from vllm.v1.sample.metadata import SamplingMetadata
|
| 33 |
+
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
| 34 |
+
from vllm.v1.utils import copy_slice
|
| 35 |
+
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
| 36 |
+
|
| 37 |
+
from vllm_ascend.pool.metadata import PoolingMetadata
|
| 38 |
+
|
| 39 |
+
_SAMPLING_EPS = 1e-5
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class CachedRequestState:
|
| 44 |
+
|
| 45 |
+
req_id: str
|
| 46 |
+
prompt_token_ids: list[int]
|
| 47 |
+
mm_inputs: list[MultiModalKwargs]
|
| 48 |
+
mm_positions: list[PlaceholderRange]
|
| 49 |
+
sampling_params: Optional[SamplingParams]
|
| 50 |
+
pooling_params: Optional[PoolingParams]
|
| 51 |
+
generator: Optional[torch.Generator]
|
| 52 |
+
|
| 53 |
+
block_ids: tuple[list[int], ...]
|
| 54 |
+
num_computed_tokens: int
|
| 55 |
+
output_token_ids: list[int]
|
| 56 |
+
|
| 57 |
+
mrope_positions: Optional[torch.Tensor] = None
|
| 58 |
+
mrope_position_delta: Optional[int] = None
|
| 59 |
+
|
| 60 |
+
lora_request: Optional[LoRARequest] = None
|
| 61 |
+
|
| 62 |
+
def __post_init__(self):
|
| 63 |
+
self.num_prompt_tokens = len(self.prompt_token_ids)
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def num_tokens(self) -> int:
|
| 67 |
+
return self.num_prompt_tokens + len(self.output_token_ids)
|
| 68 |
+
|
| 69 |
+
def get_token_id(self, idx: int) -> int:
|
| 70 |
+
if idx < self.num_prompt_tokens:
|
| 71 |
+
return self.prompt_token_ids[idx]
|
| 72 |
+
else:
|
| 73 |
+
return self.output_token_ids[idx - self.num_prompt_tokens]
|
| 74 |
+
|
| 75 |
+
@dataclass
|
| 76 |
+
class SamplingMetadataTopNSigma(SamplingMetadata):
|
| 77 |
+
top_n_sigma: torch.Tensor
|
| 78 |
+
no_top_n_sigma: bool
|
| 79 |
+
|
| 80 |
+
class InputBatch:
|
| 81 |
+
|
| 82 |
+
def __init__(
|
| 83 |
+
self,
|
| 84 |
+
max_num_reqs: int,
|
| 85 |
+
max_model_len: int,
|
| 86 |
+
max_num_batched_tokens: int,
|
| 87 |
+
device: torch.device,
|
| 88 |
+
pin_memory: bool,
|
| 89 |
+
vocab_size: int,
|
| 90 |
+
block_sizes: list[int], # The block_size of each kv cache group
|
| 91 |
+
logits_processing_needs_token_ids: bool = False,
|
| 92 |
+
is_spec_decode: bool = False,
|
| 93 |
+
):
|
| 94 |
+
self.is_spec_decode = is_spec_decode
|
| 95 |
+
self.max_num_reqs = max_num_reqs
|
| 96 |
+
self.max_model_len = max_model_len
|
| 97 |
+
self.max_num_batched_tokens = max_num_batched_tokens
|
| 98 |
+
self.device = device
|
| 99 |
+
self.pin_memory = pin_memory
|
| 100 |
+
self.vocab_size = vocab_size
|
| 101 |
+
self.logits_processing_needs_token_ids = (
|
| 102 |
+
logits_processing_needs_token_ids)
|
| 103 |
+
|
| 104 |
+
self._req_ids: list[Optional[str]] = []
|
| 105 |
+
self.req_id_to_index: dict[str, int] = {}
|
| 106 |
+
|
| 107 |
+
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
| 108 |
+
# Find a way to reduce the CPU memory usage.
|
| 109 |
+
# This buffer is not directly transferred to the NPU, so it does not
|
| 110 |
+
# need to be pinned.
|
| 111 |
+
self.token_ids_cpu_tensor = torch.zeros(
|
| 112 |
+
(max_num_reqs, max_model_len),
|
| 113 |
+
device="cpu",
|
| 114 |
+
dtype=torch.int32,
|
| 115 |
+
pin_memory=False,
|
| 116 |
+
)
|
| 117 |
+
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
| 118 |
+
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
| 119 |
+
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
| 120 |
+
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
| 121 |
+
self.num_computed_tokens_cpu_tensor = torch.zeros(
|
| 122 |
+
(max_num_reqs, ),
|
| 123 |
+
device="cpu",
|
| 124 |
+
dtype=torch.int32,
|
| 125 |
+
pin_memory=pin_memory,
|
| 126 |
+
)
|
| 127 |
+
self.num_computed_tokens_cpu = \
|
| 128 |
+
self.num_computed_tokens_cpu_tensor.numpy()
|
| 129 |
+
|
| 130 |
+
# Block table.
|
| 131 |
+
self.block_table = MultiGroupBlockTable(
|
| 132 |
+
max_num_reqs=max_num_reqs,
|
| 133 |
+
max_model_len=max_model_len,
|
| 134 |
+
max_num_batched_tokens=max_num_batched_tokens,
|
| 135 |
+
pin_memory=pin_memory,
|
| 136 |
+
device=device,
|
| 137 |
+
block_sizes=block_sizes,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
# Sampling-related.
|
| 141 |
+
self.temperature = torch.empty((max_num_reqs, ),
|
| 142 |
+
dtype=torch.float32,
|
| 143 |
+
device=device)
|
| 144 |
+
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 145 |
+
dtype=torch.float32,
|
| 146 |
+
device="cpu",
|
| 147 |
+
pin_memory=pin_memory)
|
| 148 |
+
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
|
| 149 |
+
self.greedy_reqs: set[str] = set()
|
| 150 |
+
self.random_reqs: set[str] = set()
|
| 151 |
+
|
| 152 |
+
self.top_p = torch.empty((max_num_reqs, ),
|
| 153 |
+
dtype=torch.float32,
|
| 154 |
+
device=device)
|
| 155 |
+
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 156 |
+
dtype=torch.float32,
|
| 157 |
+
device="cpu",
|
| 158 |
+
pin_memory=pin_memory)
|
| 159 |
+
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
|
| 160 |
+
self.top_p_reqs: set[str] = set()
|
| 161 |
+
|
| 162 |
+
self.top_k = torch.empty((max_num_reqs, ),
|
| 163 |
+
dtype=torch.int32,
|
| 164 |
+
device=device)
|
| 165 |
+
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 166 |
+
dtype=torch.int32,
|
| 167 |
+
device="cpu",
|
| 168 |
+
pin_memory=pin_memory)
|
| 169 |
+
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
| 170 |
+
self.top_k_reqs: set[str] = set()
|
| 171 |
+
|
| 172 |
+
# IDs of requests which do not support spec decoding
|
| 173 |
+
self.spec_decode_unsupported_reqs: set[str] = set()
|
| 174 |
+
|
| 175 |
+
self.min_p = torch.empty((max_num_reqs, ),
|
| 176 |
+
dtype=torch.float32,
|
| 177 |
+
device=device)
|
| 178 |
+
self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 179 |
+
dtype=torch.float32,
|
| 180 |
+
device="cpu",
|
| 181 |
+
pin_memory=pin_memory)
|
| 182 |
+
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
| 183 |
+
self.min_p_reqs: set[str] = set()
|
| 184 |
+
|
| 185 |
+
# topnsigma penalty
|
| 186 |
+
self.top_n_sigma = torch.empty((max_num_reqs, ),
|
| 187 |
+
dtype=torch.float,
|
| 188 |
+
device=device)
|
| 189 |
+
self.top_n_sigma_cpu_tensor = torch.empty(
|
| 190 |
+
(max_num_reqs, ),
|
| 191 |
+
dtype=torch.float,
|
| 192 |
+
device="cpu",
|
| 193 |
+
pin_memory=pin_memory)
|
| 194 |
+
self.top_n_sigma_cpu = \
|
| 195 |
+
self.top_n_sigma_cpu_tensor.numpy()
|
| 196 |
+
self.top_n_sigma_reqs: set[str] = set()
|
| 197 |
+
|
| 198 |
+
# Frequency penalty related data structures
|
| 199 |
+
self.frequency_penalties = torch.empty((max_num_reqs, ),
|
| 200 |
+
dtype=torch.float,
|
| 201 |
+
device=device)
|
| 202 |
+
self.frequency_penalties_cpu_tensor = torch.empty(
|
| 203 |
+
(max_num_reqs, ),
|
| 204 |
+
dtype=torch.float,
|
| 205 |
+
device="cpu",
|
| 206 |
+
pin_memory=pin_memory)
|
| 207 |
+
self.frequency_penalties_cpu = \
|
| 208 |
+
self.frequency_penalties_cpu_tensor.numpy()
|
| 209 |
+
self.frequency_penalties_reqs: set[str] = set()
|
| 210 |
+
|
| 211 |
+
# Presence penalty related data structures
|
| 212 |
+
self.presence_penalties = torch.empty((max_num_reqs, ),
|
| 213 |
+
dtype=torch.float,
|
| 214 |
+
device=device)
|
| 215 |
+
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
|
| 216 |
+
dtype=torch.float,
|
| 217 |
+
device="cpu",
|
| 218 |
+
pin_memory=pin_memory)
|
| 219 |
+
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
|
| 220 |
+
)
|
| 221 |
+
self.presence_penalties_reqs: set[str] = set()
|
| 222 |
+
|
| 223 |
+
# Repetition penalty related data structures
|
| 224 |
+
self.repetition_penalties = torch.empty((max_num_reqs, ),
|
| 225 |
+
dtype=torch.float,
|
| 226 |
+
device=device)
|
| 227 |
+
self.repetition_penalties_cpu_tensor = torch.empty(
|
| 228 |
+
(max_num_reqs, ),
|
| 229 |
+
dtype=torch.float,
|
| 230 |
+
device="cpu",
|
| 231 |
+
pin_memory=pin_memory)
|
| 232 |
+
self.repetition_penalties_cpu = \
|
| 233 |
+
self.repetition_penalties_cpu_tensor.numpy()
|
| 234 |
+
self.repetition_penalties_reqs: set[str] = set()
|
| 235 |
+
|
| 236 |
+
# req_index -> (min_tokens, stop_token_ids)
|
| 237 |
+
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
|
| 238 |
+
|
| 239 |
+
# lora related
|
| 240 |
+
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
| 241 |
+
dtype=np.int32)
|
| 242 |
+
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
| 243 |
+
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
| 244 |
+
|
| 245 |
+
# req_index -> generator
|
| 246 |
+
# NOTE(woosuk): The indices of the requests that do not have their own
|
| 247 |
+
# generator should not be included in the dictionary.
|
| 248 |
+
self.generators: dict[int, torch.Generator] = {}
|
| 249 |
+
|
| 250 |
+
self.num_logprobs: dict[str, int] = {}
|
| 251 |
+
# NOTE(rob): num_prompt_logprobs only includes reqs
|
| 252 |
+
# that are currently in the prefill phase.
|
| 253 |
+
self.num_prompt_logprobs: dict[str, int] = {}
|
| 254 |
+
|
| 255 |
+
# To accumulate prompt logprobs tensor chunks across prefill steps.
|
| 256 |
+
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
|
| 257 |
+
|
| 258 |
+
self.logit_bias: list[Optional[dict[int,
|
| 259 |
+
float]]] = [None] * max_num_reqs
|
| 260 |
+
self.has_allowed_token_ids: set[str] = set()
|
| 261 |
+
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
|
| 262 |
+
# the value is False. Since we use masked_fill_ to set -inf.
|
| 263 |
+
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
|
| 264 |
+
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
|
| 265 |
+
|
| 266 |
+
# req_index -> bad_words_token_ids
|
| 267 |
+
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
| 268 |
+
|
| 269 |
+
self.req_output_token_ids: list[Optional[list[int]]] = []
|
| 270 |
+
|
| 271 |
+
# Define logits processors.
|
| 272 |
+
# TODO(andy): logits processor list should be extensible via engine
|
| 273 |
+
# constructor argument; for now the list is fixed.
|
| 274 |
+
self.logitsprocs = init_builtin_logitsprocs(
|
| 275 |
+
pin_memory_available=pin_memory,
|
| 276 |
+
max_num_reqs=max_num_reqs + 1,
|
| 277 |
+
device=device)
|
| 278 |
+
|
| 279 |
+
# This is updated each time the batch constituents change.
|
| 280 |
+
self.sampling_metadata = self._make_sampling_metadata()
|
| 281 |
+
|
| 282 |
+
self.pooling_params: dict[str, PoolingParams] = {}
|
| 283 |
+
|
| 284 |
+
@property
|
| 285 |
+
def req_ids(self) -> list[str]:
|
| 286 |
+
# None elements should only be present transiently
|
| 287 |
+
# while performing state updates to the batch.
|
| 288 |
+
return cast(list[str], self._req_ids)
|
| 289 |
+
|
| 290 |
+
def add_request(
|
| 291 |
+
self,
|
| 292 |
+
request: "CachedRequestState",
|
| 293 |
+
req_index: Optional[int] = None,
|
| 294 |
+
) -> None:
|
| 295 |
+
if req_index is None:
|
| 296 |
+
req_index = self.num_reqs
|
| 297 |
+
assert req_index < self.max_num_reqs
|
| 298 |
+
|
| 299 |
+
req_id = request.req_id
|
| 300 |
+
if req_index == len(self._req_ids):
|
| 301 |
+
self._req_ids.append(req_id)
|
| 302 |
+
self.req_output_token_ids.append(request.output_token_ids)
|
| 303 |
+
else:
|
| 304 |
+
self._req_ids[req_index] = req_id
|
| 305 |
+
self.req_output_token_ids[req_index] = request.output_token_ids
|
| 306 |
+
|
| 307 |
+
self.req_id_to_index[req_id] = req_index
|
| 308 |
+
|
| 309 |
+
# Copy the prompt token ids and output token ids.
|
| 310 |
+
num_prompt_tokens = len(request.prompt_token_ids)
|
| 311 |
+
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
| 312 |
+
self.token_ids_cpu[
|
| 313 |
+
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
| 314 |
+
start_idx = num_prompt_tokens
|
| 315 |
+
end_idx = start_idx + len(request.output_token_ids)
|
| 316 |
+
self.token_ids_cpu[req_index,
|
| 317 |
+
start_idx:end_idx] = request.output_token_ids
|
| 318 |
+
# Number of token ids in token_ids_cpu.
|
| 319 |
+
# NOTE(woosuk): This may include spec decode tokens.
|
| 320 |
+
self.num_tokens[req_index] = request.num_tokens
|
| 321 |
+
# Number of tokens without spec decode tokens.
|
| 322 |
+
self.num_tokens_no_spec[req_index] = request.num_tokens
|
| 323 |
+
|
| 324 |
+
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
| 325 |
+
self.block_table.add_row(request.block_ids, req_index)
|
| 326 |
+
|
| 327 |
+
if sampling_params := request.sampling_params:
|
| 328 |
+
if self.is_spec_decode and is_spec_decode_unsupported(
|
| 329 |
+
sampling_params):
|
| 330 |
+
self.spec_decode_unsupported_reqs.add(req_id)
|
| 331 |
+
if sampling_params.sampling_type == SamplingType.GREEDY:
|
| 332 |
+
# Avoid later division by zero.
|
| 333 |
+
self.temperature_cpu[req_index] = -1.0
|
| 334 |
+
self.greedy_reqs.add(req_id)
|
| 335 |
+
else:
|
| 336 |
+
self.temperature_cpu[req_index] = sampling_params.temperature
|
| 337 |
+
self.random_reqs.add(req_id)
|
| 338 |
+
|
| 339 |
+
self.top_p_cpu[req_index] = sampling_params.top_p
|
| 340 |
+
if sampling_params.top_p < 1:
|
| 341 |
+
self.top_p_reqs.add(req_id)
|
| 342 |
+
top_k = sampling_params.top_k
|
| 343 |
+
if 0 < top_k < self.vocab_size:
|
| 344 |
+
self.top_k_reqs.add(req_id)
|
| 345 |
+
else:
|
| 346 |
+
top_k = self.vocab_size
|
| 347 |
+
self.top_k_cpu[req_index] = top_k
|
| 348 |
+
self.min_p_cpu[req_index] = sampling_params.min_p
|
| 349 |
+
self.frequency_penalties_cpu[
|
| 350 |
+
req_index] = sampling_params.frequency_penalty
|
| 351 |
+
if sampling_params.min_p > _SAMPLING_EPS:
|
| 352 |
+
self.min_p_reqs.add(req_id)
|
| 353 |
+
if sampling_params.frequency_penalty != 0.0:
|
| 354 |
+
self.frequency_penalties_reqs.add(req_id)
|
| 355 |
+
self.presence_penalties_cpu[
|
| 356 |
+
req_index] = sampling_params.presence_penalty
|
| 357 |
+
if sampling_params.presence_penalty != 0.0:
|
| 358 |
+
self.presence_penalties_reqs.add(req_id)
|
| 359 |
+
self.repetition_penalties_cpu[
|
| 360 |
+
req_index] = sampling_params.repetition_penalty
|
| 361 |
+
if sampling_params.repetition_penalty != 1.0:
|
| 362 |
+
self.repetition_penalties_reqs.add(req_id)
|
| 363 |
+
if sampling_params.min_tokens:
|
| 364 |
+
self.min_tokens[req_index] = (
|
| 365 |
+
sampling_params.min_tokens,
|
| 366 |
+
sampling_params.all_stop_token_ids)
|
| 367 |
+
|
| 368 |
+
if sampling_params.extra_args and "top_n_sigma" in sampling_params.extra_args:
|
| 369 |
+
self.top_n_sigma_cpu[
|
| 370 |
+
req_index] = sampling_params.extra_args["top_n_sigma"]
|
| 371 |
+
self.top_n_sigma_reqs.add(req_id)
|
| 372 |
+
else:
|
| 373 |
+
self.top_n_sigma_cpu[req_index] = -1
|
| 374 |
+
|
| 375 |
+
# NOTE(woosuk): self.generators should not include the requests that
|
| 376 |
+
# do not have their own generator.
|
| 377 |
+
if request.generator is not None:
|
| 378 |
+
self.generators[req_index] = request.generator
|
| 379 |
+
|
| 380 |
+
if sampling_params.logprobs is not None:
|
| 381 |
+
self.num_logprobs[req_id] = sampling_params.logprobs
|
| 382 |
+
if sampling_params.prompt_logprobs is not None:
|
| 383 |
+
self.num_prompt_logprobs[
|
| 384 |
+
req_id] = sampling_params.prompt_logprobs
|
| 385 |
+
if sampling_params.logit_bias is not None:
|
| 386 |
+
self.logit_bias[req_index] = sampling_params.logit_bias
|
| 387 |
+
|
| 388 |
+
if sampling_params.allowed_token_ids:
|
| 389 |
+
self.has_allowed_token_ids.add(req_id)
|
| 390 |
+
if self.allowed_token_ids_mask_cpu_tensor is None:
|
| 391 |
+
# Lazy allocation for this tensor, which can be large.
|
| 392 |
+
# False means we don't fill with -inf.
|
| 393 |
+
self.allowed_token_ids_mask = torch.zeros(
|
| 394 |
+
self.max_num_reqs,
|
| 395 |
+
self.vocab_size,
|
| 396 |
+
dtype=torch.bool,
|
| 397 |
+
device=self.device)
|
| 398 |
+
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
| 399 |
+
self.max_num_reqs,
|
| 400 |
+
self.vocab_size,
|
| 401 |
+
dtype=torch.bool,
|
| 402 |
+
device="cpu")
|
| 403 |
+
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
| 404 |
+
# False means we don't fill with -inf.
|
| 405 |
+
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
| 406 |
+
sampling_params.allowed_token_ids] = False
|
| 407 |
+
|
| 408 |
+
if sampling_params.bad_words_token_ids:
|
| 409 |
+
self.bad_words_token_ids[
|
| 410 |
+
req_index] = sampling_params.bad_words_token_ids
|
| 411 |
+
else:
|
| 412 |
+
assert request.pooling_params is not None
|
| 413 |
+
self.pooling_params[req_id] = request.pooling_params
|
| 414 |
+
|
| 415 |
+
# Add request lora ID
|
| 416 |
+
if request.lora_request:
|
| 417 |
+
lora_id = request.lora_request.lora_int_id
|
| 418 |
+
if lora_id not in self.lora_id_to_request_ids:
|
| 419 |
+
self.lora_id_to_request_ids[lora_id] = set()
|
| 420 |
+
|
| 421 |
+
self.request_lora_mapping[req_index] = lora_id
|
| 422 |
+
self.lora_id_to_request_ids[lora_id].add(request.req_id)
|
| 423 |
+
self.lora_id_to_lora_request[lora_id] = request.lora_request
|
| 424 |
+
else:
|
| 425 |
+
# No LoRA
|
| 426 |
+
self.request_lora_mapping[req_index] = 0
|
| 427 |
+
|
| 428 |
+
def remove_request(self, req_id: str) -> Optional[int]:
|
| 429 |
+
"""This method must always be followed by a call to condense()."""
|
| 430 |
+
|
| 431 |
+
req_index = self.req_id_to_index.pop(req_id, None)
|
| 432 |
+
if req_index is None:
|
| 433 |
+
return None
|
| 434 |
+
self._req_ids[req_index] = None
|
| 435 |
+
self.req_output_token_ids[req_index] = None
|
| 436 |
+
|
| 437 |
+
self.greedy_reqs.discard(req_id)
|
| 438 |
+
self.random_reqs.discard(req_id)
|
| 439 |
+
self.top_p_reqs.discard(req_id)
|
| 440 |
+
self.top_k_reqs.discard(req_id)
|
| 441 |
+
self.min_p_reqs.discard(req_id)
|
| 442 |
+
self.min_tokens.pop(req_index, None)
|
| 443 |
+
self.frequency_penalties_reqs.discard(req_id)
|
| 444 |
+
self.presence_penalties_reqs.discard(req_id)
|
| 445 |
+
self.repetition_penalties_reqs.discard(req_id)
|
| 446 |
+
self.spec_decode_unsupported_reqs.discard(req_id)
|
| 447 |
+
self.top_n_sigma_reqs.discard(req_id)
|
| 448 |
+
self.generators.pop(req_index, None)
|
| 449 |
+
self.num_logprobs.pop(req_id, None)
|
| 450 |
+
self.num_prompt_logprobs.pop(req_id, None)
|
| 451 |
+
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
| 452 |
+
|
| 453 |
+
# LoRA
|
| 454 |
+
lora_id = self.request_lora_mapping[req_index]
|
| 455 |
+
if lora_id != 0:
|
| 456 |
+
self.lora_id_to_request_ids[lora_id].discard(req_id)
|
| 457 |
+
if len(self.lora_id_to_request_ids[lora_id]) == 0:
|
| 458 |
+
self.lora_id_to_request_ids.pop(lora_id)
|
| 459 |
+
self.lora_id_to_lora_request.pop(lora_id)
|
| 460 |
+
self.request_lora_mapping[req_index] = 0
|
| 461 |
+
|
| 462 |
+
self.logit_bias[req_index] = None
|
| 463 |
+
self.has_allowed_token_ids.discard(req_id)
|
| 464 |
+
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
| 465 |
+
# False means we don't fill with -inf.
|
| 466 |
+
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
| 467 |
+
self.bad_words_token_ids.pop(req_index, None)
|
| 468 |
+
self.pooling_params.pop(req_id, None)
|
| 469 |
+
return req_index
|
| 470 |
+
|
| 471 |
+
def swap_states(self, i1: int, i2: int) -> None:
|
| 472 |
+
old_id_i1 = self._req_ids[i1]
|
| 473 |
+
old_id_i2 = self._req_ids[i2]
|
| 474 |
+
self._req_ids[i1], self._req_ids[i2] =\
|
| 475 |
+
self._req_ids[i2], self._req_ids[i1] # noqa
|
| 476 |
+
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
|
| 477 |
+
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
|
| 478 |
+
assert old_id_i1 is not None and old_id_i2 is not None
|
| 479 |
+
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
|
| 480 |
+
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
| 481 |
+
self.num_tokens[i1], self.num_tokens[i2] =\
|
| 482 |
+
self.num_tokens[i2], self.num_tokens[i1]
|
| 483 |
+
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
|
| 484 |
+
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
|
| 485 |
+
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
|
| 486 |
+
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
|
| 487 |
+
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
|
| 488 |
+
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
|
| 489 |
+
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
|
| 490 |
+
self.temperature_cpu[i2], self.temperature_cpu[i1]
|
| 491 |
+
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
|
| 492 |
+
self.top_p_cpu[i2], self.top_p_cpu[i1]
|
| 493 |
+
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
|
| 494 |
+
self.top_k_cpu[i2], self.top_k_cpu[i1]
|
| 495 |
+
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
|
| 496 |
+
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
|
| 497 |
+
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
|
| 498 |
+
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
| 499 |
+
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
|
| 500 |
+
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
| 501 |
+
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
|
| 502 |
+
self.min_p_cpu[i2], self.min_p_cpu[i1]
|
| 503 |
+
self.top_n_sigma_cpu[i1], self.top_n_sigma_cpu[i2] =\
|
| 504 |
+
self.top_n_sigma_cpu[i2], self.top_n_sigma_cpu[i1]
|
| 505 |
+
|
| 506 |
+
# NOTE: the following is unsafe
|
| 507 |
+
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
| 508 |
+
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
|
| 509 |
+
# instead, we need to temporiarily copy the data for one of the indices
|
| 510 |
+
# TODO(lucas): optimize this by only copying valid indices
|
| 511 |
+
tmp = self.token_ids_cpu[i1, ...].copy()
|
| 512 |
+
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
| 513 |
+
self.token_ids_cpu[i2, ...] = tmp
|
| 514 |
+
|
| 515 |
+
swap_dict_values(self.generators, i1, i2)
|
| 516 |
+
swap_dict_values(self.min_tokens, i1, i2)
|
| 517 |
+
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
| 518 |
+
|
| 519 |
+
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
|
| 520 |
+
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
| 521 |
+
self.logit_bias[i1], self.logit_bias[i2] =\
|
| 522 |
+
self.logit_bias[i2], self.logit_bias[i1]
|
| 523 |
+
|
| 524 |
+
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
| 525 |
+
self.allowed_token_ids_mask_cpu_tensor[i1], \
|
| 526 |
+
self.allowed_token_ids_mask_cpu_tensor[i2] =\
|
| 527 |
+
self.allowed_token_ids_mask_cpu_tensor[i2], \
|
| 528 |
+
self.allowed_token_ids_mask_cpu_tensor[i1]
|
| 529 |
+
self.block_table.swap_row(i1, i2)
|
| 530 |
+
|
| 531 |
+
def condense(self, empty_req_indices: list[int]) -> None:
|
| 532 |
+
"""Move non-empty requests down into lower, empty indices.
|
| 533 |
+
|
| 534 |
+
Args:
|
| 535 |
+
empty_req_indices: empty batch indices, sorted descending.
|
| 536 |
+
"""
|
| 537 |
+
num_reqs = self.num_reqs
|
| 538 |
+
if num_reqs == 0:
|
| 539 |
+
# The batched states are empty.
|
| 540 |
+
self._req_ids.clear()
|
| 541 |
+
self.req_output_token_ids.clear()
|
| 542 |
+
return
|
| 543 |
+
|
| 544 |
+
# NOTE(woosuk): This function assumes that the empty_req_indices
|
| 545 |
+
# is sorted in descending order.
|
| 546 |
+
last_req_index = num_reqs + len(empty_req_indices) - 1
|
| 547 |
+
while empty_req_indices:
|
| 548 |
+
# Find the largest non-empty index.
|
| 549 |
+
while last_req_index in empty_req_indices:
|
| 550 |
+
last_req_index -= 1
|
| 551 |
+
|
| 552 |
+
# Find the smallest empty index.
|
| 553 |
+
empty_index = empty_req_indices.pop()
|
| 554 |
+
if empty_index >= last_req_index:
|
| 555 |
+
break
|
| 556 |
+
|
| 557 |
+
# Swap the states.
|
| 558 |
+
req_id = self._req_ids[last_req_index]
|
| 559 |
+
output_token_ids = self.req_output_token_ids[last_req_index]
|
| 560 |
+
assert req_id is not None
|
| 561 |
+
self._req_ids[empty_index] = req_id
|
| 562 |
+
self._req_ids[last_req_index] = None
|
| 563 |
+
self.req_output_token_ids[empty_index] = output_token_ids
|
| 564 |
+
self.req_output_token_ids[last_req_index] = None
|
| 565 |
+
self.req_id_to_index[req_id] = empty_index
|
| 566 |
+
|
| 567 |
+
num_tokens = self.num_tokens[last_req_index]
|
| 568 |
+
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
| 569 |
+
last_req_index, :num_tokens]
|
| 570 |
+
self.num_tokens[empty_index] = num_tokens
|
| 571 |
+
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
|
| 572 |
+
last_req_index]
|
| 573 |
+
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
|
| 574 |
+
last_req_index]
|
| 575 |
+
self.num_computed_tokens_cpu[
|
| 576 |
+
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
| 577 |
+
self.block_table.move_row(last_req_index, empty_index)
|
| 578 |
+
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
| 579 |
+
last_req_index]
|
| 580 |
+
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
|
| 581 |
+
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
|
| 582 |
+
self.frequency_penalties_cpu[
|
| 583 |
+
empty_index] = self.frequency_penalties_cpu[last_req_index]
|
| 584 |
+
self.presence_penalties_cpu[
|
| 585 |
+
empty_index] = self.presence_penalties_cpu[last_req_index]
|
| 586 |
+
self.repetition_penalties_cpu[
|
| 587 |
+
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
| 588 |
+
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
|
| 589 |
+
self.top_n_sigma_cpu[
|
| 590 |
+
empty_index] = self.top_n_sigma_cpu[last_req_index]
|
| 591 |
+
generator = self.generators.pop(last_req_index, None)
|
| 592 |
+
if generator is not None:
|
| 593 |
+
self.generators[empty_index] = generator
|
| 594 |
+
|
| 595 |
+
min_token = self.min_tokens.pop(last_req_index, None)
|
| 596 |
+
if min_token is not None:
|
| 597 |
+
self.min_tokens[empty_index] = min_token
|
| 598 |
+
|
| 599 |
+
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
| 600 |
+
last_req_index]
|
| 601 |
+
|
| 602 |
+
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
|
| 603 |
+
|
| 604 |
+
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
| 605 |
+
self.allowed_token_ids_mask_cpu_tensor[
|
| 606 |
+
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
|
| 607 |
+
last_req_index]
|
| 608 |
+
|
| 609 |
+
bad_words_token_ids = self.bad_words_token_ids.pop(
|
| 610 |
+
last_req_index, None)
|
| 611 |
+
if bad_words_token_ids is not None:
|
| 612 |
+
self.bad_words_token_ids[empty_index] = bad_words_token_ids
|
| 613 |
+
# Decrement last_req_index since it is now empty.
|
| 614 |
+
last_req_index -= 1
|
| 615 |
+
|
| 616 |
+
# Trim lists to the batch size.
|
| 617 |
+
del self._req_ids[self.num_reqs:]
|
| 618 |
+
del self.req_output_token_ids[self.num_reqs:]
|
| 619 |
+
|
| 620 |
+
def refresh_sampling_metadata(self):
|
| 621 |
+
self.sampling_metadata = self._make_sampling_metadata()
|
| 622 |
+
|
| 623 |
+
def _make_sampling_metadata(self) -> Union[SamplingMetadata, SamplingMetadataTopNSigma]:
|
| 624 |
+
num_reqs = self.num_reqs
|
| 625 |
+
if not self.all_greedy:
|
| 626 |
+
temperature = copy_slice(self.temperature_cpu_tensor,
|
| 627 |
+
self.temperature, num_reqs)
|
| 628 |
+
else:
|
| 629 |
+
temperature = None
|
| 630 |
+
if not self.no_top_p:
|
| 631 |
+
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
| 632 |
+
if not self.no_top_k:
|
| 633 |
+
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
|
| 634 |
+
if not self.no_min_p:
|
| 635 |
+
copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)
|
| 636 |
+
|
| 637 |
+
if not self.no_penalties:
|
| 638 |
+
# Since syncing these tensors is expensive only copy them
|
| 639 |
+
# if necessary i.e. if there are requests which require
|
| 640 |
+
# penalties to be applied during sampling.
|
| 641 |
+
copy_slice(self.frequency_penalties_cpu_tensor,
|
| 642 |
+
self.frequency_penalties, num_reqs)
|
| 643 |
+
copy_slice(self.presence_penalties_cpu_tensor,
|
| 644 |
+
self.presence_penalties, num_reqs)
|
| 645 |
+
copy_slice(self.repetition_penalties_cpu_tensor,
|
| 646 |
+
self.repetition_penalties, num_reqs)
|
| 647 |
+
|
| 648 |
+
if not self.no_top_n_sigma:
|
| 649 |
+
copy_slice(self.top_n_sigma_cpu_tensor,
|
| 650 |
+
self.top_n_sigma, num_reqs)
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
needs_prompt_token_ids = (not self.no_penalties or
|
| 654 |
+
(self.num_reqs > 0
|
| 655 |
+
and self.logits_processing_needs_token_ids))
|
| 656 |
+
if needs_prompt_token_ids:
|
| 657 |
+
# The prompt tokens are used only for applying penalties or
|
| 658 |
+
# step pooling during the sampling/pooling process.
|
| 659 |
+
# Hence copy these tensors only when there are requests which
|
| 660 |
+
# need penalties/step_pooler to be applied.
|
| 661 |
+
prompt_token_ids = self._make_prompt_token_ids_tensor()
|
| 662 |
+
else:
|
| 663 |
+
prompt_token_ids = None
|
| 664 |
+
|
| 665 |
+
allowed_token_ids_mask: Optional[torch.Tensor] = None
|
| 666 |
+
if not self.no_allowed_token_ids:
|
| 667 |
+
assert self.allowed_token_ids_mask is not None
|
| 668 |
+
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
|
| 669 |
+
self.allowed_token_ids_mask, num_reqs)
|
| 670 |
+
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
|
| 671 |
+
|
| 672 |
+
return SamplingMetadataTopNSigma(
|
| 673 |
+
temperature=temperature,
|
| 674 |
+
all_greedy=self.all_greedy,
|
| 675 |
+
all_random=self.all_random,
|
| 676 |
+
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
| 677 |
+
top_k=None if self.no_top_k else self.top_k[:num_reqs],
|
| 678 |
+
generators=self.generators,
|
| 679 |
+
max_num_logprobs=self.max_num_logprobs,
|
| 680 |
+
prompt_token_ids=prompt_token_ids,
|
| 681 |
+
frequency_penalties=self.frequency_penalties[:num_reqs],
|
| 682 |
+
presence_penalties=self.presence_penalties[:num_reqs],
|
| 683 |
+
repetition_penalties=self.repetition_penalties[:num_reqs],
|
| 684 |
+
top_n_sigma=self.top_n_sigma[:num_reqs],
|
| 685 |
+
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
| 686 |
+
no_penalties=self.no_penalties,
|
| 687 |
+
no_top_n_sigma=self.no_top_n_sigma,
|
| 688 |
+
allowed_token_ids_mask=allowed_token_ids_mask,
|
| 689 |
+
bad_words_token_ids=self.bad_words_token_ids,
|
| 690 |
+
logitsprocs=self.logitsprocs,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
@property
|
| 694 |
+
def pooling_metadata(self) -> PoolingMetadata:
|
| 695 |
+
if len(self.pooling_params) == 0:
|
| 696 |
+
pooling_params = []
|
| 697 |
+
else:
|
| 698 |
+
# Note, for now this assumes that all request in the batch
|
| 699 |
+
# are either sampling or pooling requests
|
| 700 |
+
assert len(self.req_ids) == len(self.pooling_params)
|
| 701 |
+
pooling_params = [
|
| 702 |
+
self.pooling_params[req_id] for req_id in self.req_ids
|
| 703 |
+
]
|
| 704 |
+
|
| 705 |
+
return PoolingMetadata(
|
| 706 |
+
prompt_lens=torch.from_numpy(
|
| 707 |
+
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
|
| 708 |
+
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
| 709 |
+
pooling_params=pooling_params,
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
| 713 |
+
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
| 714 |
+
prompt_token_ids_cpu_tensor = torch.empty(
|
| 715 |
+
(self.num_reqs, max_prompt_len),
|
| 716 |
+
device="cpu",
|
| 717 |
+
dtype=torch.int64,
|
| 718 |
+
pin_memory=self.pin_memory,
|
| 719 |
+
)
|
| 720 |
+
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
|
| 721 |
+
prompt_token_ids[:] = self.token_ids_cpu[:self.
|
| 722 |
+
num_reqs, :max_prompt_len]
|
| 723 |
+
# Use the value of vocab_size as a pad since we don't have a
|
| 724 |
+
# token_id of this value.
|
| 725 |
+
for i in range(self.num_reqs):
|
| 726 |
+
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
|
| 727 |
+
return prompt_token_ids_cpu_tensor.to(device=self.device,
|
| 728 |
+
non_blocking=True)
|
| 729 |
+
|
| 730 |
+
def make_lora_inputs(
|
| 731 |
+
self, num_scheduled_tokens: np.ndarray
|
| 732 |
+
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
|
| 733 |
+
"""
|
| 734 |
+
Given the num_scheduled_tokens for each request in the batch, return
|
| 735 |
+
datastructures used to activate the current LoRAs.
|
| 736 |
+
Returns:
|
| 737 |
+
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
|
| 738 |
+
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
|
| 739 |
+
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
|
| 740 |
+
where, token_lora_mapping[i] is the LoRA id to use for ith token.
|
| 741 |
+
3. lora_requests: Set of relevant LoRA requests.
|
| 742 |
+
"""
|
| 743 |
+
|
| 744 |
+
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
|
| 745 |
+
prompt_lora_mapping = tuple(req_lora_mapping)
|
| 746 |
+
token_lora_mapping = tuple(
|
| 747 |
+
req_lora_mapping.repeat(num_scheduled_tokens))
|
| 748 |
+
active_lora_requests: set[LoRARequest] = set(
|
| 749 |
+
self.lora_id_to_lora_request.values())
|
| 750 |
+
|
| 751 |
+
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
| 752 |
+
|
| 753 |
+
@property
|
| 754 |
+
def num_reqs(self) -> int:
|
| 755 |
+
return len(self.req_id_to_index)
|
| 756 |
+
|
| 757 |
+
@property
|
| 758 |
+
def all_greedy(self) -> bool:
|
| 759 |
+
return len(self.random_reqs) == 0
|
| 760 |
+
|
| 761 |
+
@property
|
| 762 |
+
def all_random(self) -> bool:
|
| 763 |
+
return len(self.greedy_reqs) == 0
|
| 764 |
+
|
| 765 |
+
@property
|
| 766 |
+
def no_top_p(self) -> bool:
|
| 767 |
+
return len(self.top_p_reqs) == 0
|
| 768 |
+
|
| 769 |
+
@property
|
| 770 |
+
def no_top_k(self) -> bool:
|
| 771 |
+
return len(self.top_k_reqs) == 0
|
| 772 |
+
|
| 773 |
+
@property
|
| 774 |
+
def no_min_p(self) -> bool:
|
| 775 |
+
return len(self.min_p_reqs) == 0
|
| 776 |
+
|
| 777 |
+
@property
|
| 778 |
+
def no_penalties(self) -> bool:
|
| 779 |
+
return (len(self.presence_penalties_reqs) == 0
|
| 780 |
+
and len(self.frequency_penalties_reqs) == 0
|
| 781 |
+
and len(self.repetition_penalties_reqs) == 0)
|
| 782 |
+
@property
|
| 783 |
+
def no_top_n_sigma(self) -> bool:
|
| 784 |
+
return len(self.top_n_sigma_reqs) == 0
|
| 785 |
+
|
| 786 |
+
@property
|
| 787 |
+
def max_num_logprobs(self) -> Optional[int]:
|
| 788 |
+
return max(self.num_logprobs.values()) if self.num_logprobs else None
|
| 789 |
+
|
| 790 |
+
@property
|
| 791 |
+
def no_prompt_logprob(self) -> bool:
|
| 792 |
+
return not self.num_prompt_logprobs
|
| 793 |
+
|
| 794 |
+
@property
|
| 795 |
+
def no_allowed_token_ids(self) -> bool:
|
| 796 |
+
return len(self.has_allowed_token_ids) == 0
|
model-00002-of-000062.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1e29a512e3737d1826c80a2277a8b42021878847753aadbe5e1ae2a2df3d7f8d
|
| 3 |
+
size 1242564208
|
model-00003-of-000062.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f63aa17d947032e0a524b5798eee3becbfc9a9b6f8a352ead3232e7b34bb289
|
| 3 |
+
size 1242564208
|
model-00005-of-000062.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4e3907f683d7f8382d2a792304155e8533ffa3a94dd4bb5ff825124b0dba3835
|
| 3 |
+
size 24650809648
|
model-00045-of-000062.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:97458250006f5949c65b338a26503738200c5fb2415f4cc664a6b224aa9dce70
|
| 3 |
+
size 24650810432
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_openpangu_moe.py
ADDED
|
@@ -0,0 +1,653 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from typing import List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
import torch.utils.checkpoint
|
| 22 |
+
from torch import nn
|
| 23 |
+
from transformers.activations import ACT2FN
|
| 24 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 25 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 26 |
+
from transformers.modeling_outputs import (
|
| 27 |
+
BaseModelOutputWithPast,
|
| 28 |
+
CausalLMOutputWithPast,
|
| 29 |
+
)
|
| 30 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 31 |
+
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
|
| 32 |
+
from transformers.utils.import_utils import is_torch_fx_available
|
| 33 |
+
|
| 34 |
+
from .configuration_openpangu_moe import PanguUltraMoEConfig
|
| 35 |
+
|
| 36 |
+
if is_torch_fx_available():
|
| 37 |
+
if not is_torch_greater_or_equal_than_1_13:
|
| 38 |
+
import torch.fx
|
| 39 |
+
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class PanguUltraMoERMSNorm(nn.Module):
|
| 43 |
+
def __init__(self, hidden_dim, epsilon=1e-5):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.weight = nn.Parameter(torch.empty(hidden_dim))
|
| 46 |
+
self.epsilon = epsilon
|
| 47 |
+
|
| 48 |
+
def forward(self, input_x):
|
| 49 |
+
origin_dtype = input_x.dtype
|
| 50 |
+
var = input_x.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
| 51 |
+
input_x = input_x * torch.rsqrt(var + self.epsilon)
|
| 52 |
+
output_x = self.weight * input_x
|
| 53 |
+
return output_x.to(origin_dtype)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class PanguUltraMoERotaryEmbedding(nn.Module):
|
| 57 |
+
def __init__(
|
| 58 |
+
self, dim, max_position_embeddings=131072, base=25600000.0, device=None
|
| 59 |
+
):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.dim = dim
|
| 62 |
+
self.max_position_embeddings = max_position_embeddings
|
| 63 |
+
self.base = base
|
| 64 |
+
self._set_cache(
|
| 65 |
+
seq_len=max_position_embeddings,
|
| 66 |
+
device=device,
|
| 67 |
+
dtype=torch.get_default_dtype(),
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
def _set_cache(self, seq_len, device, dtype):
|
| 71 |
+
self.max_seq_len_cached = seq_len
|
| 72 |
+
dim = self.dim
|
| 73 |
+
|
| 74 |
+
inv_freq = 1.0 / (
|
| 75 |
+
self.base
|
| 76 |
+
** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
|
| 77 |
+
)
|
| 78 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 79 |
+
|
| 80 |
+
t = torch.arange(seq_len, device=device, dtype=torch.float32)
|
| 81 |
+
|
| 82 |
+
freqs = torch.outer(t, inv_freq)
|
| 83 |
+
|
| 84 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 85 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
| 86 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
| 87 |
+
|
| 88 |
+
def forward(self, x, kv_len, max_seq_len=None):
|
| 89 |
+
if max_seq_len is None:
|
| 90 |
+
self._set_cache(seq_len=kv_len, device=x.device, dtype=x.dtype)
|
| 91 |
+
elif max_seq_len > self.max_seq_len_cached:
|
| 92 |
+
self._set_cache(seq_len=max_seq_len, device=x.device, dtype=x.dtype)
|
| 93 |
+
|
| 94 |
+
batch_size = x.shape[0]
|
| 95 |
+
seq_len = x.shape[1]
|
| 96 |
+
if seq_len == 1:
|
| 97 |
+
cos = (
|
| 98 |
+
torch.index_select(self.cos_cached, dim=0, index=kv_len)
|
| 99 |
+
.unsqueeze(1)
|
| 100 |
+
.unsqueeze(1)
|
| 101 |
+
)
|
| 102 |
+
sin = (
|
| 103 |
+
torch.index_select(self.sin_cached, dim=0, index=kv_len)
|
| 104 |
+
.unsqueeze(1)
|
| 105 |
+
.unsqueeze(1)
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
cos = (
|
| 109 |
+
self.cos_cached[:seq_len]
|
| 110 |
+
.unsqueeze(0)
|
| 111 |
+
.unsqueeze(2)
|
| 112 |
+
.repeat(batch_size, 1, 1, 1)
|
| 113 |
+
)
|
| 114 |
+
sin = (
|
| 115 |
+
self.sin_cached[:seq_len]
|
| 116 |
+
.unsqueeze(0)
|
| 117 |
+
.unsqueeze(2)
|
| 118 |
+
.repeat(batch_size, 1, 1, 1)
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
cos = cos[0, :, 0, :]
|
| 122 |
+
sin = sin[0, :, 0, :]
|
| 123 |
+
return (
|
| 124 |
+
cos.to(dtype=x.dtype),
|
| 125 |
+
sin.to(dtype=x.dtype),
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def rotate_half(x):
|
| 130 |
+
"""Rotates half the hidden dims of the input."""
|
| 131 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 132 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 133 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
| 137 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
q (`torch.Tensor`): The query tensor.
|
| 141 |
+
k (`torch.Tensor`): The key tensor.
|
| 142 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 143 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 144 |
+
position_ids (`torch.Tensor`):
|
| 145 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
| 146 |
+
used to pass offsetted position ids when working with a KV-cache.
|
| 147 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 148 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 149 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 150 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 151 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 152 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 153 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 154 |
+
Returns:
|
| 155 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 156 |
+
"""
|
| 157 |
+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
| 158 |
+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
| 159 |
+
|
| 160 |
+
b, h, s, d = q.shape
|
| 161 |
+
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
| 162 |
+
|
| 163 |
+
b, h, s, d = k.shape
|
| 164 |
+
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
| 165 |
+
|
| 166 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 167 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 168 |
+
return q_embed, k_embed
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class MLP(nn.Module):
|
| 172 |
+
def __init__(self, config, hidden_size=None, intermediate_size=None):
|
| 173 |
+
super().__init__()
|
| 174 |
+
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
| 175 |
+
self.intermediate_size = (
|
| 176 |
+
config.intermediate_size if intermediate_size is None else intermediate_size
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 180 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 181 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 182 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 183 |
+
|
| 184 |
+
def forward(self, x):
|
| 185 |
+
output = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 186 |
+
return output
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class MoEGate(nn.Module):
|
| 190 |
+
def __init__(self, config):
|
| 191 |
+
super().__init__()
|
| 192 |
+
self.top_k = config.num_experts_per_tok
|
| 193 |
+
self.routed_scaling_factor = config.routed_scaling_factor
|
| 194 |
+
|
| 195 |
+
self.norm_topk_prob = config.norm_topk_prob
|
| 196 |
+
self.weight = nn.Parameter(
|
| 197 |
+
torch.empty((config.num_routed_experts, config.hidden_size))
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
def forward(self, hidden_states):
|
| 201 |
+
bsz, seq_len, h = hidden_states.shape
|
| 202 |
+
hidden_states = hidden_states.view(-1, h)
|
| 203 |
+
logits = F.linear(
|
| 204 |
+
hidden_states.to(torch.float32), self.weight.to(torch.float32), None
|
| 205 |
+
)
|
| 206 |
+
scores = logits.sigmoid()
|
| 207 |
+
scores_for_choice = scores.view(bsz * seq_len, -1)
|
| 208 |
+
_, topk_idx = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
|
| 209 |
+
topk_weight = scores.gather(1, topk_idx)
|
| 210 |
+
|
| 211 |
+
if self.top_k > 1 and self.norm_topk_prob:
|
| 212 |
+
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
| 213 |
+
topk_weight = topk_weight / denominator
|
| 214 |
+
topk_weight = topk_weight * self.routed_scaling_factor
|
| 215 |
+
|
| 216 |
+
return topk_idx, topk_weight
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class PanguUltraMoE(nn.Module):
|
| 220 |
+
def __init__(self, config):
|
| 221 |
+
super().__init__()
|
| 222 |
+
self.num_shared_experts = config.num_shared_experts
|
| 223 |
+
self.num_routed_experts = config.num_routed_experts
|
| 224 |
+
self.experts = nn.ModuleList(
|
| 225 |
+
[
|
| 226 |
+
MLP(config, intermediate_size=config.moe_intermediate_size)
|
| 227 |
+
for i in range(self.num_routed_experts)
|
| 228 |
+
]
|
| 229 |
+
)
|
| 230 |
+
self.gate = MoEGate(config)
|
| 231 |
+
if self.num_shared_experts is not None:
|
| 232 |
+
intermediate_size = config.moe_intermediate_size * self.num_shared_experts
|
| 233 |
+
self.shared_experts = MLP(
|
| 234 |
+
config=config, intermediate_size=intermediate_size
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
def forward(self, hidden_states):
|
| 238 |
+
if self.num_shared_experts is not None:
|
| 239 |
+
shared_output = self.shared_experts(hidden_states)
|
| 240 |
+
input_shape = hidden_states.shape
|
| 241 |
+
topk_ids, topk_weight = self.gate(hidden_states)
|
| 242 |
+
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
| 243 |
+
counts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
|
| 244 |
+
counts.scatter_(1, topk_ids, 1)
|
| 245 |
+
tokens_per_expert = counts.sum(dim=0)
|
| 246 |
+
idxs = topk_ids.view(-1).argsort()
|
| 247 |
+
sorted_tokens = hidden_states[idxs // topk_ids.shape[1]]
|
| 248 |
+
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
| 249 |
+
|
| 250 |
+
output_hidden_states = []
|
| 251 |
+
start_idx = 0
|
| 252 |
+
for i, num_tokens in enumerate(tokens_per_expert):
|
| 253 |
+
end_idx = start_idx + num_tokens
|
| 254 |
+
if num_tokens == 0:
|
| 255 |
+
continue
|
| 256 |
+
expert = self.experts[i]
|
| 257 |
+
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
| 258 |
+
expert_out = expert(tokens_for_this_expert)
|
| 259 |
+
output_hidden_states.append(expert_out)
|
| 260 |
+
start_idx = end_idx
|
| 261 |
+
|
| 262 |
+
if len(output_hidden_states) > 0:
|
| 263 |
+
cat_hidden_states = torch.cat(output_hidden_states, dim=0)
|
| 264 |
+
else:
|
| 265 |
+
cat_hidden_states = sorted_tokens.new_empty(0)
|
| 266 |
+
|
| 267 |
+
final_hidden_states = torch.empty_like(cat_hidden_states)
|
| 268 |
+
final_hidden_states[idxs] = cat_hidden_states
|
| 269 |
+
final_out = final_hidden_states.view(*topk_ids.shape, -1).to(topk_weight.dtype)
|
| 270 |
+
final_out = (
|
| 271 |
+
final_out.mul_(topk_weight.unsqueeze(dim=-1))
|
| 272 |
+
.sum(dim=1)
|
| 273 |
+
.to(final_hidden_states.dtype)
|
| 274 |
+
).view(*input_shape)
|
| 275 |
+
if self.num_shared_experts is not None:
|
| 276 |
+
final_out = final_out + shared_output
|
| 277 |
+
return final_out
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class PanguUltraMoEAttention(nn.Module):
|
| 281 |
+
def __init__(self, config: PanguUltraMoEConfig, layer_idx: Optional[int] = None):
|
| 282 |
+
super().__init__()
|
| 283 |
+
self.layer_idx = layer_idx
|
| 284 |
+
|
| 285 |
+
self.attention_dropout = config.attention_dropout
|
| 286 |
+
self.hidden_size = config.hidden_size
|
| 287 |
+
self.num_heads = config.num_attention_heads
|
| 288 |
+
self.attention_q_lora_dim = config.attention_q_lora_dim
|
| 289 |
+
self.attention_qk_rope_dim = config.attention_qk_rope_dim
|
| 290 |
+
self.attention_kv_lora_dim = config.attention_kv_lora_dim
|
| 291 |
+
self.attention_v_dim = config.attention_v_dim
|
| 292 |
+
self.attention_qk_dim = config.attention_qk_dim
|
| 293 |
+
self.q_head_dim = config.attention_qk_dim + config.attention_qk_rope_dim
|
| 294 |
+
|
| 295 |
+
if self.attention_q_lora_dim is None:
|
| 296 |
+
self.q_proj = nn.Linear(
|
| 297 |
+
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
|
| 298 |
+
)
|
| 299 |
+
else:
|
| 300 |
+
self.q_a_proj = nn.Linear(
|
| 301 |
+
self.hidden_size, config.attention_q_lora_dim, bias=False
|
| 302 |
+
)
|
| 303 |
+
self.q_a_layernorm = PanguUltraMoERMSNorm(config.attention_q_lora_dim)
|
| 304 |
+
self.q_b_proj = nn.Linear(
|
| 305 |
+
config.attention_q_lora_dim,
|
| 306 |
+
self.num_heads * self.q_head_dim,
|
| 307 |
+
bias=False,
|
| 308 |
+
)
|
| 309 |
+
self.kv_a_proj_with_mqa = nn.Linear(
|
| 310 |
+
self.hidden_size,
|
| 311 |
+
config.attention_kv_lora_dim + config.attention_qk_rope_dim,
|
| 312 |
+
bias=False,
|
| 313 |
+
)
|
| 314 |
+
self.kv_a_layernorm = PanguUltraMoERMSNorm(config.attention_kv_lora_dim)
|
| 315 |
+
self.kv_b_proj = nn.Linear(
|
| 316 |
+
config.attention_kv_lora_dim,
|
| 317 |
+
self.num_heads * (config.attention_qk_dim + self.attention_v_dim),
|
| 318 |
+
bias=False,
|
| 319 |
+
)
|
| 320 |
+
self.o_proj = nn.Linear(
|
| 321 |
+
self.num_heads * self.attention_v_dim,
|
| 322 |
+
self.hidden_size,
|
| 323 |
+
bias=False,
|
| 324 |
+
)
|
| 325 |
+
self.rotary_emb = PanguUltraMoERotaryEmbedding(
|
| 326 |
+
self.attention_qk_rope_dim,
|
| 327 |
+
max_position_embeddings=config.max_position_embeddings,
|
| 328 |
+
base=config.rope_theta,
|
| 329 |
+
)
|
| 330 |
+
self.softmax_scale = self.q_head_dim ** (-0.5)
|
| 331 |
+
|
| 332 |
+
def forward(
|
| 333 |
+
self,
|
| 334 |
+
hidden_states: torch.Tensor,
|
| 335 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 336 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 337 |
+
past_key_value: Optional[Cache] = None,
|
| 338 |
+
output_attentions: bool = False,
|
| 339 |
+
use_cache: bool = False,
|
| 340 |
+
**kwargs,
|
| 341 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 342 |
+
bsz, q_len, _ = hidden_states.size()
|
| 343 |
+
|
| 344 |
+
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
| 345 |
+
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
| 346 |
+
q_nope, q_pe = torch.split(
|
| 347 |
+
q, [self.attention_qk_dim, self.attention_qk_rope_dim], dim=-1
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
latent_kv = self.kv_a_proj_with_mqa(hidden_states)
|
| 351 |
+
kv_a, k_pe = torch.split(
|
| 352 |
+
latent_kv, [self.attention_kv_lora_dim, self.attention_qk_rope_dim], dim=-1
|
| 353 |
+
)
|
| 354 |
+
k_pe = k_pe.view(bsz, q_len, 1, self.attention_qk_rope_dim).transpose(1, 2)
|
| 355 |
+
kv = (
|
| 356 |
+
self.kv_b_proj(self.kv_a_layernorm(kv_a))
|
| 357 |
+
.view(
|
| 358 |
+
bsz, q_len, self.num_heads, self.attention_qk_dim + self.attention_v_dim
|
| 359 |
+
)
|
| 360 |
+
.transpose(1, 2)
|
| 361 |
+
)
|
| 362 |
+
kv_seq_len = kv.shape[-2]
|
| 363 |
+
if past_key_value is not None:
|
| 364 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 365 |
+
cos, sin = self.rotary_emb(kv, kv_seq_len)
|
| 366 |
+
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
| 367 |
+
|
| 368 |
+
k_nope, value = torch.split(
|
| 369 |
+
kv, [self.attention_qk_dim, self.attention_v_dim], dim=-1
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
def concat_nope_pe(nope, pe):
|
| 373 |
+
states = torch.empty(
|
| 374 |
+
[bsz, self.num_heads, q_len, self.q_head_dim],
|
| 375 |
+
dtype=nope.dtype,
|
| 376 |
+
device=nope.device,
|
| 377 |
+
)
|
| 378 |
+
states[:, :, :, : self.attention_qk_dim] = nope
|
| 379 |
+
states[:, :, :, self.attention_qk_dim :] = pe
|
| 380 |
+
return states
|
| 381 |
+
|
| 382 |
+
query = concat_nope_pe(q_nope, q_pe)
|
| 383 |
+
key = concat_nope_pe(k_nope, k_pe)
|
| 384 |
+
|
| 385 |
+
if past_key_value is not None:
|
| 386 |
+
key, value = past_key_value.update(
|
| 387 |
+
key, value, self.layer_idx, {"sin": sin, "cos": cos}
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
attn_weights = (
|
| 391 |
+
torch.matmul(query, key.transpose(2, 3)) * self.softmax_scale
|
| 392 |
+
+ attention_mask
|
| 393 |
+
)
|
| 394 |
+
attn_weights = nn.functional.softmax(
|
| 395 |
+
attn_weights, dim=-1, dtype=torch.float32
|
| 396 |
+
).to(query.dtype)
|
| 397 |
+
attn_weights = nn.functional.dropout(
|
| 398 |
+
attn_weights, p=self.attention_dropout, training=self.training
|
| 399 |
+
)
|
| 400 |
+
attn_output = torch.matmul(attn_weights, value)
|
| 401 |
+
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
|
| 402 |
+
attn_output = self.o_proj(attn_output)
|
| 403 |
+
|
| 404 |
+
return attn_output, past_key_value
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class PanguUltraMoEDecoderLayer(nn.Module):
|
| 408 |
+
def __init__(self, config: PanguUltraMoEConfig, layer_idx: int):
|
| 409 |
+
super().__init__()
|
| 410 |
+
self.hidden_size = config.hidden_size
|
| 411 |
+
self.self_attn = PanguUltraMoEAttention(config=config, layer_idx=layer_idx)
|
| 412 |
+
|
| 413 |
+
self.mlp = (
|
| 414 |
+
PanguUltraMoE(config)
|
| 415 |
+
if (
|
| 416 |
+
config.num_routed_experts is not None
|
| 417 |
+
and layer_idx >= config.num_dense_layers
|
| 418 |
+
)
|
| 419 |
+
else MLP(config)
|
| 420 |
+
)
|
| 421 |
+
self.input_layernorm = PanguUltraMoERMSNorm(
|
| 422 |
+
config.hidden_size, epsilon=config.rms_norm_eps
|
| 423 |
+
)
|
| 424 |
+
self.post_attention_layernorm = PanguUltraMoERMSNorm(
|
| 425 |
+
config.hidden_size, epsilon=config.rms_norm_eps
|
| 426 |
+
)
|
| 427 |
+
if getattr(config, "sandwich_norm", False):
|
| 428 |
+
self.sandwich_norm = True
|
| 429 |
+
self.pre_mlp_layernorm = PanguUltraMoERMSNorm(
|
| 430 |
+
config.hidden_size, epsilon=config.rms_norm_eps
|
| 431 |
+
)
|
| 432 |
+
self.post_mlp_layernorm = PanguUltraMoERMSNorm(
|
| 433 |
+
config.hidden_size, epsilon=config.rms_norm_eps
|
| 434 |
+
)
|
| 435 |
+
else:
|
| 436 |
+
self.sandwich_norm = False
|
| 437 |
+
|
| 438 |
+
def forward(
|
| 439 |
+
self,
|
| 440 |
+
hidden_states: torch.Tensor,
|
| 441 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 442 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 443 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 444 |
+
use_cache: Optional[bool] = False,
|
| 445 |
+
**kwargs,
|
| 446 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 447 |
+
residual = hidden_states
|
| 448 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 449 |
+
|
| 450 |
+
hidden_states, present_key_value = self.self_attn(
|
| 451 |
+
hidden_states=hidden_states,
|
| 452 |
+
attention_mask=attention_mask,
|
| 453 |
+
position_ids=position_ids,
|
| 454 |
+
past_key_value=past_key_value,
|
| 455 |
+
use_cache=use_cache,
|
| 456 |
+
**kwargs,
|
| 457 |
+
)
|
| 458 |
+
if self.sandwich_norm:
|
| 459 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 460 |
+
hidden_states = residual + hidden_states
|
| 461 |
+
residual = hidden_states
|
| 462 |
+
hidden_states = self.pre_mlp_layernorm(hidden_states)
|
| 463 |
+
else:
|
| 464 |
+
hidden_states = residual + hidden_states
|
| 465 |
+
residual = hidden_states
|
| 466 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 467 |
+
|
| 468 |
+
hidden_states = self.mlp(hidden_states)
|
| 469 |
+
|
| 470 |
+
if self.sandwich_norm:
|
| 471 |
+
hidden_states = self.post_mlp_layernorm(hidden_states)
|
| 472 |
+
hidden_states = residual + hidden_states
|
| 473 |
+
|
| 474 |
+
return (hidden_states, present_key_value)
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
class PanguUltraMoEPreTrainedModel(PreTrainedModel):
|
| 478 |
+
config_class = PanguUltraMoEConfig
|
| 479 |
+
base_model_prefix = "model"
|
| 480 |
+
supports_gradient_checkpointing = True
|
| 481 |
+
_no_split_modules = ["PanguUltraMoEDecoderLayer"]
|
| 482 |
+
_skip_keys_device_placement = "past_key_values"
|
| 483 |
+
_supports_cache_class = True
|
| 484 |
+
|
| 485 |
+
def _init_weights(self, module):
|
| 486 |
+
std = self.config.initializer_range
|
| 487 |
+
self._initialize_linear(module, std)
|
| 488 |
+
self._initialize_embedding(module, std)
|
| 489 |
+
|
| 490 |
+
def _initialize_linear(self, module, std):
|
| 491 |
+
if isinstance(module, nn.Linear):
|
| 492 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 493 |
+
if module.bias is not None:
|
| 494 |
+
module.bias.data.zero_()
|
| 495 |
+
|
| 496 |
+
def _initialize_embedding(self, module, std):
|
| 497 |
+
if isinstance(module, nn.Embedding):
|
| 498 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 499 |
+
if module.padding_idx is not None:
|
| 500 |
+
module.weight.data[module.padding_idx].zero_()
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class PanguUltraMoEModel(PanguUltraMoEPreTrainedModel):
|
| 504 |
+
def __init__(self, config: PanguUltraMoEConfig):
|
| 505 |
+
super().__init__(config)
|
| 506 |
+
|
| 507 |
+
self.vocab_size = config.vocab_size
|
| 508 |
+
self.hidden_size = config.hidden_size
|
| 509 |
+
self.padding_idx = config.pad_token_id
|
| 510 |
+
self.layer_num = config.num_hidden_layers
|
| 511 |
+
self.epsilon = config.rms_norm_eps
|
| 512 |
+
|
| 513 |
+
self.embed_tokens = nn.Embedding(
|
| 514 |
+
self.vocab_size, self.hidden_size, self.padding_idx
|
| 515 |
+
)
|
| 516 |
+
self.layers = nn.ModuleList(
|
| 517 |
+
[PanguUltraMoEDecoderLayer(config, idx) for idx in range(self.layer_num)]
|
| 518 |
+
)
|
| 519 |
+
self.norm = PanguUltraMoERMSNorm(self.hidden_size, epsilon=self.epsilon)
|
| 520 |
+
self.gradient_checkpointing = False
|
| 521 |
+
|
| 522 |
+
self.post_init()
|
| 523 |
+
|
| 524 |
+
def get_input_embeddings(self):
|
| 525 |
+
return self.embed_tokens
|
| 526 |
+
|
| 527 |
+
def set_input_embeddings(self, value):
|
| 528 |
+
self.embed_tokens = value
|
| 529 |
+
|
| 530 |
+
def forward(
|
| 531 |
+
self,
|
| 532 |
+
input_ids: torch.LongTensor = None,
|
| 533 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 534 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 535 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 536 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 537 |
+
use_cache: Optional[bool] = None,
|
| 538 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 539 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 540 |
+
raise ValueError("You have to specify input_ids or inputs_embeds.")
|
| 541 |
+
|
| 542 |
+
if input_ids is not None:
|
| 543 |
+
hidden_states = self.embed_tokens(input_ids)
|
| 544 |
+
batch_size, seq_length = input_ids.size()
|
| 545 |
+
else:
|
| 546 |
+
hidden_states = inputs_embeds
|
| 547 |
+
batch_size, seq_length = inputs_embeds.size()
|
| 548 |
+
|
| 549 |
+
if position_ids is None:
|
| 550 |
+
position_ids = torch.arange(
|
| 551 |
+
seq_length, dtype=torch.long, device=hidden_states.device
|
| 552 |
+
).unsqueeze(0)
|
| 553 |
+
|
| 554 |
+
past_key_values_length = 0
|
| 555 |
+
if use_cache:
|
| 556 |
+
use_legacy_cache = not isinstance(past_key_values, Cache)
|
| 557 |
+
if use_legacy_cache:
|
| 558 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 559 |
+
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
| 560 |
+
position_ids += past_key_values_length
|
| 561 |
+
|
| 562 |
+
attention_mask = _prepare_4d_causal_attention_mask(
|
| 563 |
+
attention_mask,
|
| 564 |
+
(batch_size, seq_length),
|
| 565 |
+
hidden_states,
|
| 566 |
+
past_key_values_length,
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
for decoder_layer in self.layers:
|
| 570 |
+
hidden_states, present_key_value = decoder_layer(
|
| 571 |
+
hidden_states,
|
| 572 |
+
attention_mask=attention_mask,
|
| 573 |
+
position_ids=position_ids,
|
| 574 |
+
past_key_value=past_key_values,
|
| 575 |
+
use_cache=use_cache,
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
hidden_states = self.norm(hidden_states)
|
| 579 |
+
|
| 580 |
+
if use_cache and use_legacy_cache:
|
| 581 |
+
present_key_value = present_key_value.to_legacy_cache()
|
| 582 |
+
|
| 583 |
+
return BaseModelOutputWithPast(
|
| 584 |
+
last_hidden_state=hidden_states,
|
| 585 |
+
past_key_values=present_key_value,
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
class PanguUltraMoEForCausalLM(PanguUltraMoEPreTrainedModel):
|
| 590 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 591 |
+
|
| 592 |
+
def __init__(self, config):
|
| 593 |
+
super().__init__(config)
|
| 594 |
+
self.model = PanguUltraMoEModel(config)
|
| 595 |
+
self.vocab_size = config.vocab_size
|
| 596 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 597 |
+
|
| 598 |
+
self.post_init()
|
| 599 |
+
|
| 600 |
+
def get_input_embeddings(self):
|
| 601 |
+
return self.model.embed_tokens
|
| 602 |
+
|
| 603 |
+
def set_input_embeddings(self, value):
|
| 604 |
+
self.model.embed_tokens = value
|
| 605 |
+
|
| 606 |
+
def get_output_embeddings(self):
|
| 607 |
+
return self.lm_head
|
| 608 |
+
|
| 609 |
+
def set_output_embeddings(self, new_embeddings):
|
| 610 |
+
self.lm_head = new_embeddings
|
| 611 |
+
|
| 612 |
+
def set_decoder(self, decoder):
|
| 613 |
+
self.model = decoder
|
| 614 |
+
|
| 615 |
+
def get_decoder(self):
|
| 616 |
+
return self.model
|
| 617 |
+
|
| 618 |
+
def forward(
|
| 619 |
+
self,
|
| 620 |
+
input_ids: torch.LongTensor = None,
|
| 621 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 622 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 623 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 624 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 625 |
+
labels: Optional[torch.LongTensor] = None,
|
| 626 |
+
use_cache: Optional[bool] = None,
|
| 627 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 628 |
+
|
| 629 |
+
outputs = self.model(
|
| 630 |
+
input_ids=input_ids,
|
| 631 |
+
attention_mask=attention_mask,
|
| 632 |
+
position_ids=position_ids,
|
| 633 |
+
past_key_values=past_key_values,
|
| 634 |
+
inputs_embeds=inputs_embeds,
|
| 635 |
+
use_cache=use_cache,
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
logits = self.lm_head(outputs[0])
|
| 639 |
+
logits = logits.float()
|
| 640 |
+
|
| 641 |
+
loss = None
|
| 642 |
+
if labels is not None:
|
| 643 |
+
loss = self.loss_function(
|
| 644 |
+
logits=logits, labels=labels, vocab_size=self.vocab_size
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
return CausalLMOutputWithPast(
|
| 648 |
+
loss=loss,
|
| 649 |
+
logits=logits,
|
| 650 |
+
past_key_values=outputs.past_key_values,
|
| 651 |
+
hidden_states=outputs.hidden_states,
|
| 652 |
+
attentions=outputs.attentions,
|
| 653 |
+
)
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "[unused10]",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "<unk>",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"unk_token": {
|
| 24 |
+
"content": "<unk>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
}
|
| 30 |
+
}
|
tokenization_openpangu.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
|
| 3 |
+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
| 6 |
+
# and OPT implementations in this library. It has been modified from its
|
| 7 |
+
# original forms to accommodate minor architectural differences compared
|
| 8 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
| 9 |
+
#
|
| 10 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 11 |
+
# you may not use this file except in compliance with the License.
|
| 12 |
+
# You may obtain a copy of the License at
|
| 13 |
+
#
|
| 14 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 15 |
+
#
|
| 16 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 17 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 18 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 19 |
+
# See the License for the specific language governing permissions and
|
| 20 |
+
# limitations under the License.
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
from shutil import copyfile
|
| 24 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 25 |
+
|
| 26 |
+
import sentencepiece as spm
|
| 27 |
+
|
| 28 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 29 |
+
from transformers.utils import logging
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
|
| 35 |
+
|
| 36 |
+
PRETRAINED_VOCAB_FILES_MAP = {}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def convert_bool(string):
|
| 40 |
+
if isinstance(string, str):
|
| 41 |
+
if string.lower() == "true":
|
| 42 |
+
return True
|
| 43 |
+
elif string.lower() == "false":
|
| 44 |
+
return False
|
| 45 |
+
else:
|
| 46 |
+
return string
|
| 47 |
+
else:
|
| 48 |
+
return string
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class PanguUltraMoETokenizer(PreTrainedTokenizer):
|
| 52 |
+
"""
|
| 53 |
+
Construct a tokenizer. Based on byte-level Byte-Pair-Encoding.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
vocab_file (`str`):
|
| 57 |
+
Path to the vocabulary file.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 61 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
| 62 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 63 |
+
_auto_class = "AutoTokenizer"
|
| 64 |
+
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
vocab_file,
|
| 68 |
+
unk_token="<unk>",
|
| 69 |
+
bos_token="<s>",
|
| 70 |
+
eos_token="</s>",
|
| 71 |
+
pad_token="</s>",
|
| 72 |
+
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
| 73 |
+
add_bos_token=True,
|
| 74 |
+
add_eos_token=False,
|
| 75 |
+
decode_with_prefix_space=False,
|
| 76 |
+
clean_up_tokenization_spaces=False,
|
| 77 |
+
**kwargs,
|
| 78 |
+
):
|
| 79 |
+
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
| 80 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 81 |
+
self.sp_model.Load(vocab_file)
|
| 82 |
+
super().__init__(
|
| 83 |
+
bos_token=bos_token,
|
| 84 |
+
eos_token=eos_token,
|
| 85 |
+
unk_token=unk_token,
|
| 86 |
+
pad_token=pad_token,
|
| 87 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 88 |
+
**kwargs,
|
| 89 |
+
)
|
| 90 |
+
self.vocab_file = vocab_file
|
| 91 |
+
self.add_bos_token = convert_bool(add_bos_token)
|
| 92 |
+
self.add_eos_token = add_eos_token
|
| 93 |
+
self.decode_with_prefix_space = decode_with_prefix_space
|
| 94 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 95 |
+
self.sp_model.Load(vocab_file)
|
| 96 |
+
self._no_prefix_space_tokens = None
|
| 97 |
+
|
| 98 |
+
""" Initialisation"""
|
| 99 |
+
|
| 100 |
+
@property
|
| 101 |
+
def no_prefix_space_tokens(self):
|
| 102 |
+
if self._no_prefix_space_tokens is None:
|
| 103 |
+
vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
|
| 104 |
+
self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
|
| 105 |
+
return self._no_prefix_space_tokens
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def vocab_size(self):
|
| 109 |
+
"""Returns vocab size"""
|
| 110 |
+
return self.sp_model.get_piece_size()
|
| 111 |
+
|
| 112 |
+
@property
|
| 113 |
+
def bos_token_id(self) -> Optional[int]:
|
| 114 |
+
return self.sp_model.bos_id()
|
| 115 |
+
|
| 116 |
+
@property
|
| 117 |
+
def eos_token_id(self) -> Optional[int]:
|
| 118 |
+
return super().eos_token_id
|
| 119 |
+
|
| 120 |
+
def get_vocab(self):
|
| 121 |
+
"""Returns vocab as a dict"""
|
| 122 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
| 123 |
+
vocab.update(self.added_tokens_encoder)
|
| 124 |
+
return vocab
|
| 125 |
+
|
| 126 |
+
def _tokenize(self, text):
|
| 127 |
+
"""Returns a tokenized string."""
|
| 128 |
+
return self.sp_model.encode(text, out_type=str)
|
| 129 |
+
|
| 130 |
+
def _convert_token_to_id(self, token):
|
| 131 |
+
"""Converts a token (str) in an id using the vocab."""
|
| 132 |
+
return self.sp_model.piece_to_id(token)
|
| 133 |
+
|
| 134 |
+
def _convert_id_to_token(self, index):
|
| 135 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 136 |
+
token = self.sp_model.IdToPiece(index)
|
| 137 |
+
return token
|
| 138 |
+
|
| 139 |
+
def _maybe_add_prefix_space(self, tokens, decoded):
|
| 140 |
+
if tokens and tokens[0] not in self.no_prefix_space_tokens:
|
| 141 |
+
return " " + decoded
|
| 142 |
+
else:
|
| 143 |
+
return decoded
|
| 144 |
+
|
| 145 |
+
def convert_tokens_to_string(self, tokens):
|
| 146 |
+
"""Converts a sequence of tokens (string) in a single string."""
|
| 147 |
+
current_sub_tokens = []
|
| 148 |
+
out_string = ""
|
| 149 |
+
prev_is_special = False
|
| 150 |
+
for token in tokens:
|
| 151 |
+
# make sure that special tokens are not decoded using sentencepiece model
|
| 152 |
+
if token in self.all_special_tokens:
|
| 153 |
+
# Decode the current sub-tokens first
|
| 154 |
+
if current_sub_tokens:
|
| 155 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
| 156 |
+
current_sub_tokens = []
|
| 157 |
+
# Append the special token without adding extra spaces
|
| 158 |
+
out_string += token
|
| 159 |
+
prev_is_special = True
|
| 160 |
+
else:
|
| 161 |
+
current_sub_tokens.append(token)
|
| 162 |
+
prev_is_special = False
|
| 163 |
+
# Decode any remaining sub-tokens
|
| 164 |
+
if current_sub_tokens:
|
| 165 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
| 166 |
+
# Clean up leading and trailing spaces
|
| 167 |
+
if self.clean_up_tokenization_spaces:
|
| 168 |
+
out_string = self.clean_up_tokenization(out_string)
|
| 169 |
+
out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
|
| 170 |
+
return out_string[1:]
|
| 171 |
+
|
| 172 |
+
# Override decode to set spaces_between_special_tokens to True as default
|
| 173 |
+
def decode(self,
|
| 174 |
+
token_ids,
|
| 175 |
+
spaces_between_special_tokens: bool = False,
|
| 176 |
+
**kwargs):
|
| 177 |
+
return super().decode(
|
| 178 |
+
token_ids=token_ids,
|
| 179 |
+
spaces_between_special_tokens=spaces_between_special_tokens,
|
| 180 |
+
**kwargs,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 184 |
+
"""
|
| 185 |
+
Save the vocabulary and special tokens file to a directory.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
save_directory (`str`):
|
| 189 |
+
The directory in which to save the vocabulary.
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
`Tuple(str)`: Paths to the files saved.
|
| 193 |
+
"""
|
| 194 |
+
if not os.path.isdir(save_directory):
|
| 195 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 196 |
+
return ("",)
|
| 197 |
+
out_vocab_file = os.path.join(
|
| 198 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
| 202 |
+
copyfile(self.vocab_file, out_vocab_file)
|
| 203 |
+
elif not os.path.isfile(self.vocab_file):
|
| 204 |
+
with open(out_vocab_file, "wb") as fi:
|
| 205 |
+
content_spiece_model = self.sp_model.serialized_model_proto()
|
| 206 |
+
fi.write(content_spiece_model)
|
| 207 |
+
|
| 208 |
+
return (out_vocab_file,)
|
| 209 |
+
|
| 210 |
+
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
| 211 |
+
if self.add_bos_token:
|
| 212 |
+
bos_token_ids = [self.bos_token_id]
|
| 213 |
+
else:
|
| 214 |
+
bos_token_ids = []
|
| 215 |
+
|
| 216 |
+
output = bos_token_ids + token_ids_0
|
| 217 |
+
|
| 218 |
+
if token_ids_1 is not None:
|
| 219 |
+
output = output + token_ids_1
|
| 220 |
+
|
| 221 |
+
if self.add_eos_token:
|
| 222 |
+
output = output + [self.eos_token_id]
|
| 223 |
+
|
| 224 |
+
return output
|
| 225 |
+
|
| 226 |
+
def get_special_tokens_mask(
|
| 227 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 228 |
+
) -> List[int]:
|
| 229 |
+
"""
|
| 230 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 231 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 232 |
+
|
| 233 |
+
Args:
|
| 234 |
+
token_ids_0 (`List[int]`):
|
| 235 |
+
List of IDs.
|
| 236 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 237 |
+
Optional second list of IDs for sequence pairs.
|
| 238 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 239 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 243 |
+
"""
|
| 244 |
+
if already_has_special_tokens:
|
| 245 |
+
return super().get_special_tokens_mask(
|
| 246 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
if token_ids_1 is None:
|
| 250 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 251 |
+
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
|
| 252 |
+
|
| 253 |
+
def create_token_type_ids_from_sequences(
|
| 254 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 255 |
+
) -> List[int]:
|
| 256 |
+
"""
|
| 257 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
|
| 258 |
+
use of token type ids, therefore a list of zeros is returned.
|
| 259 |
+
|
| 260 |
+
Args:
|
| 261 |
+
token_ids_0 (`List[int]`):
|
| 262 |
+
List of IDs.
|
| 263 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 264 |
+
Optional second list of IDs for sequence pairs.
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
`List[int]`: List of zeros.
|
| 268 |
+
"""
|
| 269 |
+
eos = [self.eos_token_id]
|
| 270 |
+
|
| 271 |
+
if token_ids_1 is None:
|
| 272 |
+
return len(token_ids_0 + eos) * [0]
|
| 273 |
+
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"add_bos_token": false, "add_eos_token": false, "add_prefix_space": true, "added_tokens_decoder": {"0": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "1": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "2": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45806": {"content": "<|User|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45813": {"content": "<|Bot|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45830": {"content": "[unused0]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45840": {"content": "[unused1]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45846": {"content": "[unused2]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45849": {"content": "[unused3]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45861": {"content": "[unused4]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45866": {"content": "[unused5]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45874": {"content": "[unused6]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45883": {"content": "[unused7]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45884": {"content": "[unused8]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45887": {"content": "[unused9]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45892": {"content": "[unused10]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45920": {"content": "[unused11]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45932": {"content": "[unused12]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45938": {"content": "[unused13]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45953": {"content": "[unused14]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45968": {"content": "[unused15]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45974": {"content": "[unused16]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45982": {"content": "[unused17]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45986": {"content": "[unused18]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46005": {"content": "[unused19]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46007": {"content": "[unused20]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46014": {"content": "[unused21]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46017": {"content": "[unused22]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46028": {"content": "[unused23]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46032": {"content": "[unused24]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46081": {"content": "[unused25]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46086": {"content": "[unused26]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46101": {"content": "[unused27]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46183": {"content": "[unused28]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46230": {"content": "[unused29]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46245": {"content": "[unused30]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46257": {"content": "[unused31]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144208": {"content": "[unused32]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144209": {"content": "[unused33]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}}, "auto_map": {"AutoTokenizer": ["tokenization_openpangu.PanguUltraMoETokenizer", null]}, "bos_token": "<s>", "clean_up_tokenization_spaces": false, "eos_token": "[unused10]", "legacy": true, "model_max_length": 1000000000000000019884624838656, "pad_token": "<unk>", "sp_model_kwargs": {}, "spaces_between_special_tokens": false, "tokenizer_class": "PanguUltraMoETokenizer", "unk_token": "<unk>", "use_default_system_prompt": false, "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{ '[unused9]系统:' + message['content'] + '[unused10]' }}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}"}
|