wangrongsheng committed on
Commit a86f2f6 · verified · 1 Parent(s): 621e54c

Add files using upload-large-folder tool

Files changed (49)
  1. .gitignore +1 -0
  2. LICENSE +34 -0
  3. README.md +108 -6
  4. README_EN.md +111 -0
  5. checklist.chk +99 -0
  6. config.json +40 -0
  7. configuration_openpangu_moe.py +82 -0
  8. doc/docker.md +31 -0
  9. doc/docker_EN.md +31 -0
  10. doc/vllm_ascend_for_openpangu_ultra_moe_718b.md +215 -0
  11. doc/vllm_ascend_for_openpangu_ultra_moe_718b_EN.md +216 -0
  12. generation_config.json +11 -0
  13. inference/generate.py +106 -0
  14. inference/generate.sh +70 -0
  15. inference/model.py +918 -0
  16. inference/runner.py +411 -0
  17. inference/runner_config/tp1.yaml +30 -0
  18. inference/runner_config/tp32.yaml +30 -0
  19. inference/split_weight.py +387 -0
  20. inference/split_weight.sh +13 -0
  21. inference/vllm_ascend/_build_info.py +3 -0
  22. inference/vllm_ascend/attention/attention.py +1220 -0
  23. inference/vllm_ascend/attention/mla_v1.py +1224 -0
  24. inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py +6 -0
  25. inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py +171 -0
  26. inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py +6 -0
  27. inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py +300 -0
  28. inference/vllm_ascend/envs.py +153 -0
  29. inference/vllm_ascend/models/__init__.py +68 -0
  30. inference/vllm_ascend/models/open_pangu.py +1127 -0
  31. inference/vllm_ascend/ops/fused_moe.py +1530 -0
  32. inference/vllm_ascend/patch/worker/patch_common/__init__.py +27 -0
  33. inference/vllm_ascend/patch/worker/patch_common/patch_config.py +97 -0
  34. inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py +26 -0
  35. inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py +159 -0
  36. inference/vllm_ascend/quantization/w8a8.py +757 -0
  37. inference/vllm_ascend/quantization/w8a8_dynamic.py +831 -0
  38. inference/vllm_ascend/utils.py +563 -0
  39. inference/vllm_ascend/worker/model_runner_v1.py +0 -0
  40. inference/vllm_ascend/worker/npu_input_batch.py +796 -0
  41. model-00002-of-000062.safetensors +3 -0
  42. model-00003-of-000062.safetensors +3 -0
  43. model-00005-of-000062.safetensors +3 -0
  44. model-00045-of-000062.safetensors +3 -0
  45. model.safetensors.index.json +0 -0
  46. modeling_openpangu_moe.py +653 -0
  47. special_tokens_map.json +30 -0
  48. tokenization_openpangu.py +273 -0
  49. tokenizer_config.json +1 -0
.gitignore ADDED
@@ -0,0 +1 @@
1
+ .DS_Store
LICENSE ADDED
@@ -0,0 +1,34 @@
1
+ OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0
2
+
3
+ This OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0 (the "Agreement") is a legal agreement between You and Huawei Technologies Co., Ltd. ("Huawei", "We" or "Us"), and it governs Your reproducing, use, modification, and distribution of openPangu as made available by Huawei under this Agreement.
4
+
5
+ By using, reproducing, modifying, distributing, performing or displaying any portion or element of openPangu, or otherwise accepting the terms of this Agreement, You agree to be bound by this Agreement.
6
+
7
+ 1. Definitions.
8
+ 1.1. “openPangu” or “Model” means openPangu large language models and software, including trained model weights, parameters (including optimizer states), accompanying source code and scripts released under this Agreement.
9
+ 1.2. “Derivative Model” means all (1) modifications to the Model, (2) works based on the Model, and (3) any other derivative works of the Model. For clarity, information or content results from operating or otherwise using the Model is not a Derivative Model.
10
+ 1.3. “You” or “Your” means an individual or Legal Entity exercising permissions granted by this Agreement and/or using the Model for any purpose.
11
+ 1.4. “Third Party” or “Third Parties” means individuals or legal entities that are not under common control with Us or You.
12
+
13
+ 2. License Grant. Subject to Your full compliance with the terms and conditions of this Agreement, We hereby grant to You a perpetual, worldwide, non-exclusive, non-transferable, no-charge, royalty-free license (except as stated in Section 3) to use, reproduce, modify, and distribute the Model.
14
+
15
+ 3. Conditions for License Grant. You represent and warrant that You will not access, download, install, run, deploy, integrate, modify, or otherwise use the Model, directly or indirectly, within the European Union.
16
+
17
+
18
+ 4. Redistribution.
19
+ 4.1. If You distribute the Model or Derivative Model, You shall retain in Your distribution (1) a copy of this agreement, and (2) all copyright notices and other notices of origin included in the Model that are applicable to Your distribution.
20
+ 4.2. Further, if You distribute or make available to Third Parties a product or service (including another AI model) based on the Model, You are required to (1) display the acknowledgement “Powered by openPangu” and (2) include a trademark notice “openPangu is a trademark of Huawei Technologies Co., Ltd.” on related webpages, user manuals, product documentations or other advertising materials mentioning features of the Model.
21
+ 4.3. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for Derivative Model made by You as a whole, provided Your use, reproduction, and distribution of the Model otherwise complies with the terms and conditions of this Agreement.
22
+
23
+ 5. Ownership. We do not claim ownership to any information or content generated using the Model or Derivative Model that are made by You. You are solely responsible for evaluating the accuracy and appropriateness of such information or content for Your use case.
24
+
25
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of Huawei, except as required for complying with Section 4.2.
26
+
27
+ 7. Indemnity. You will indemnify and hold harmless Huawei from and against any claim by any third party arising out of or related to Your use or distribution of the Model or Derivative Model made by You (e.g. a violation against Section 3). For avoidance of doubt, “third party” in this clause includes supervisory authorities.
28
+
29
+ 8. THE MODEL IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, NONINFRINGEMENT, ACCURACY, OR THE ABSENCE OF LATENT OR OTHER DEFECTS OR ERRORS, WHETHER OR NOT DISCOVERABLE, ALL TO THE GREATEST EXTENT PERMISSIBLE UNDER APPLICABLE LAW.
30
+
31
+ 9. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MODEL, IN WHOLE OR IN PART, NO MATTER HOW IT’S CAUSED OR THE LEGAL THEORY IT IS BASED ON, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
32
+
33
+
34
+ END OF THE TERMS AND CONDITIONS
README.md CHANGED
@@ -1,6 +1,108 @@
1
- ---
2
- license: other
3
- license_name: openpangu-model-license-agreement-version-1.0
4
- license_link: >-
5
- https://ai.gitcode.com/ascend-tribe/openpangu-ultra-moe-718b-model/blob/main/LICENSE
6
- ---
1
+ # 开源盘古 Ultra-MoE-718B
2
+ 中文 | [English](README_EN.md)
3
+
4
+ ## 1. 简介
5
+ openPangu-Ultra-MoE-718B 是基于昇腾NPU从零训练的大规模混合专家语言模型,总参数量为718B,激活参数量为39B。openPangu-Ultra-MoE-718B 训练了约19T tokens,具备快慢思考融合能力。
6
+
7
+ ## 2. 模型架构
8
+ openPangu-Ultra-MoE-718B 的模型架构采用了业界主流的Multi-head Latent Attention (MLA)、Multi-Token Prediction (MTP)、大稀疏比等架构,以及一些特有的设计:
9
+
10
+ - Depth-Scaled Sandwich-Norm和TinyInit:通过调整层归一化结构与参数初始化,提升训练稳定性。
11
+ - 基于EP-Group的负载均衡策略:通过优化负载均衡损失函数,改善专家特化效果。
12
+
13
+ ## 3. 测评结果
14
+
15
+ | 测评集 | 测评指标 | 慢思考 |
16
+ |:----------------:|:----------------------------:|:-----:|
17
+ | **通用能力** | | |
18
+ | C-Eval | Acc | 91.06 |
19
+ | CLUEWSC | Acc | 94.67 |
20
+ | MMLU-Pro | Exact Match | 82.40 |
21
+ | ArenaHard_v0.1 | w/o Style Control | 96.80 |
22
+ | GPQA-Diamond | Avg@4 | 76.77 |
23
+ | SuperGPQA | Acc | 61.67 |
24
+ | IF-Eval | Prompt Strict | 80.59 |
25
+ | SysBench | Constraint Satisfaction Rate | 91.43 |
26
+ | **数学能力** | | |
27
+ | CNMO 2024 | Avg@32 | 80.73 |
28
+ | AIME25 | Avg@16 | 75.21 |
29
+ | AIME24 | Avg@16 | 80.21 |
30
+ | MATH-500 | Avg@1 | 97.40 |
31
+ | **代码能力** | | |
32
+ | LiveCodeBench | Avg@3 (01/25~05/25) | 61.14 |
33
+ | MBPP+ | Avg@2 | 81.48 |
34
+
35
+ **注:** 评测过程中,system prompt 为空。
36
+
37
+
38
+ ## 4. 部署和使用
39
+ ### 4.1 环境准备
40
+ #### 硬件规格
41
+ Atlas 800T A2 (64GB, >=32卡),驱动与固件安装包获取请参照[[Atlas 800T A2](https://www.hiascend.com/hardware/firmware-drivers/community?product=4&model=26&cann=8.2.RC1.alpha003&driver=Ascend+HDK+25.0.RC1)]
42
+
43
+ #### 软件环境
44
+ - 方式一:基于裸机环境安装以下配套软件
45
+ - 操作系统:Linux(推荐openEuler>=24.03)
46
+ - CANN==8.1.RC1,安装准备及流程请参照[[CANN Install](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/softwareinst/instg/instg_0001.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)]
47
+ - python==3.10
48
+ - torch==2.1.0
49
+ - torch-npu==2.1.0.post12
50
+ - transformers>=4.48.2
51
+
52
+ - 方式二:从docker镜像启动容器
53
+
54
+ 参考[[Docker使用指南](doc/docker.md)]
55
+
56
+ 以上软件配套经过验证,理论可以支持更高的版本,如有疑问,可以提交issue。
57
+
58
+ ### 4.2 权重完整性校验
59
+ 请参考以下方法对下载内容进行完整性校验,hash 值存储在 checklist.chk 文件中。
60
+
61
+ ```
62
+ #!/usr/bin/env bash
63
+ ARCH=$(uname -m)
64
+ MODEL_PATH="${TARGET_FOLDER}/${MODEL_FOLDER_PATH}"
65
+ cd "$MODEL_PATH" || exit 1
66
+ if [ "$ARCH" = "arm64" ]; then
67
+ shasum -a 256 -c checklist.chk  # macOS (arm64) 通常没有 sha256sum,改用 shasum 进行校验
68
+ else
69
+ sha256sum -c checklist.chk
70
+ fi
71
+ ```
72
+
73
+ ### 4.3 推理权重转换
74
+ 本次样例 openPangu-Ultra-MoE-718B 推理采用 Tensor Parallel 并行策略,叠加昇腾 NPU 融合大算子,需要提前对 safetensors 权重进行切分,下述内容提供32卡并行推理的权重切分示例,切分后的权重会保存在`model/`目录下:
75
+ ```bash
76
+ cd inference
77
+ bash split_weight.sh
78
+ ```
79
+
80
+ ### 4.4 推理样例
81
+ openPangu-Ultra-MoE-718B 在 Atlas 800T A2 上4机32卡bfloat16推理示例,主节点选取节点IP0:
82
+ ```bash
83
+ cd inference
84
+ # 主节点IP0: ${NNODES} ${NODE_RANK} ${NPROC_PER_NODE} ${MASTER_ADDR} ${PROMPT}
85
+ bash generate.sh 4 0 8 IP0 "3*7=?"
86
+ # 从节点IP1
87
+ bash generate.sh 4 1 8 IP0 "3*7=?"
88
+ # 从节点IP2
89
+ bash generate.sh 4 2 8 IP0 "3*7=?"
90
+ # 从节点IP3
91
+ bash generate.sh 4 3 8 IP0 "3*7=?"
92
+ ```
93
+ 模型默认为慢思考模式,可以通过以下手段切换至快思考模式:如`generate.py`示例中`fast_thinking_template`所示,在用户输入结尾添加` /no_think`标记可以将当前轮次切换至快思考模式。
94
+
95
+ ### 4.5 使用推理框架
96
+ vllm_ascend:参考[[vllm_ascend_for_openPangu_ultra_moe_718b](doc/vllm_ascend_for_openpangu_ultra_moe_718b.md)]
97
+
98
+ ## 5. 模型许可证
99
+ 除文件中对开源许可证另有约定外,openPangu-Ultra-MoE-718B 模型根据 OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0 授权,旨在允许使用并促进人工智能技术的进一步发展。有关详细信息,请参阅模型存储库根目录中的 [LICENSE](LICENSE) 文件。
100
+
101
+ ## 6. 免责声明
102
+ 由于 openPangu-Ultra-MoE-718B (“模型”)所依赖的技术固有的限制,以及人工智能生成的内容是由盘古自动生成的,华为无法对以下事项做出任何保证:
103
+ - 该模型的输出通过AI算法自动生成,不能排除某些信息可能存在缺陷、不合理或引起不适的可能性,生成的内容不代表华为的态度或立场;
104
+ - 无法保证该模型100%准确、可靠、功能齐全、及时、安全、无错误、不间断、持续稳定或无任何故障;
105
+ - 该模型的输出内容不构成任何建议或决策,也不保证生成的内容的真实性、完整性、准确性、及时性、合法性、功能性或实用性。生成的内容不能替代医疗、法律等领域的专业人士回答您的问题。生成的内容仅供参考,不代表华为的任何态度、立场或观点。您需要根据实际情况做出独立判断,华为不承担任何责任。
106
+
107
+ ## 7. 反馈
108
+ 如果有任何意见和建议,请提交issue或联系[openPangu@huawei.com](mailto:openPangu@huawei.com)。
README_EN.md ADDED
@@ -0,0 +1,111 @@
1
+ # openPangu-Ultra-MoE-718B
2
+ English | [中文](README.md)
3
+
4
+ ## 1. Introduction
5
+ openPangu-Ultra-MoE-718B is a large-scale mixture-of-experts language model trained from scratch on Ascend NPUs, with 718B total parameters and 39B activated parameters per token. It was trained on approximately 19 trillion tokens and is equipped with the capability to switch between fast and slow thinking.
6
+
7
+ ## 2. Model Architecture
8
+ The architecture of openPangu-Ultra-MoE-718B adopts mainstream components such as Multi-head Latent Attention (MLA), Multi-Token Prediction (MTP), and a high MoE sparsity ratio, and additionally features several distinctive designs:
9
+
10
+ - Depth-Scaled Sandwich-Norm and TinyInit: These techniques adjust the layer normalization structure and parameter initialization for improved training stability.
11
+
12
+ - EP-Group load balancing loss: This technique optimizes the load balancing loss, achieving better expert specialization.
13
+
14
+
15
+ ## 3. Results
16
+
17
+ | Benchmark | Metric | Slow-thinking |
18
+ |:-------------------------:|:------------------------------:|:-----------------:|
19
+ | **General** | | |
20
+ | C-Eval | Acc | 91.06 |
21
+ | CLUEWSC | Acc | 94.67 |
22
+ | MMLU-Pro | Exact Match | 82.40 |
23
+ | ArenaHard_v0.1 | w/o Style Control | 96.80 |
24
+ | GPQA-Diamond | Avg@4 | 76.77 |
25
+ | SuperGPQA | Acc | 61.67 |
26
+ | IF-Eval | Prompt Strict | 80.59 |
27
+ | SysBench | Constraint Satisfaction Rate | 91.43 |
28
+ | **Math** | | |
29
+ | CNMO 2024 | Avg@32 | 80.73 |
30
+ | AIME25 | Avg@16 | 75.21 |
31
+ | AIME24 | Avg@16 | 80.21 |
32
+ | MATH-500 | Avg@1 | 97.40 |
33
+ | **Coding** | | |
34
+ | LiveCodeBench | Avg@3 (01/25~05/25) | 61.14 |
35
+ | MBPP+ | Avg@2 | 81.48 |
36
+
37
+ **Note:** The system prompt is empty during the evaluation process.
38
+
39
+
40
+ ## 4. Deployment
41
+ ### 4.1 Environment
42
+ #### Hardware Requirements
43
+ Atlas 800T A2 (64GB, >=32 NPUs). Please refer to [[Atlas 800T A2](https://www.hiascend.com/hardware/firmware-drivers/community?product=4&model=26&cann=8.2.RC1.alpha003&driver=Ascend+HDK+25.0.RC1)] to obtain the driver and firmware installation packages.
44
+
45
+ #### System Requirements & Dependencies
46
+ - Method 1: Install the following supporting software in a bare-metal environment.
47
+ - System: Linux (openEuler ≥ 24.03 recommended)
48
+ - CANN==8.1.RC1, please refer to [[CANN Install](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/softwareinst/instg/instg_0001.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)] for installation
49
+ - python==3.10
50
+ - torch==2.1.0
51
+ - torch-npu==2.1.0.post12
52
+ - transformers>=4.48.2
53
+
54
+ - Method 2: Start a container from a docker image.
55
+
56
+ Refer to the [[Docker User Guide](doc/docker_EN.md)]
57
+
58
+ The software versions listed above have been verified; newer versions should also work in principle. If you have any questions, please submit an issue.
59
+
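+ For the bare-metal setup (Method 1), the Python-level dependencies listed above can be installed in one step. A minimal sketch, assuming CANN, the driver/firmware, and a Python 3.10 environment are already in place (the exact torch-npu wheel must match your CANN version, so prefer the official Ascend guide if in doubt; `pyyaml` is needed by `inference/generate.py`):
+ ```bash
+ # Sketch only: versions taken from the dependency list above
+ pip install torch==2.1.0 torch-npu==2.1.0.post12 "transformers>=4.48.2" pyyaml
+ ```
+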
60
+ ### 4.2 Integrity Check
61
+ Please refer to the following methods to verify the integrity of the downloaded content. The hash values are stored in the `checklist.chk` file.
62
+
63
+ ```
64
+ #!/usr/bin/env bash
65
+ ARCH=$(uname -m)
66
+ MODEL_PATH="${TARGET_FOLDER}/${MODEL_FOLDER_PATH}"
67
+ cd "$MODEL_PATH" || exit 1
68
+ if [ "$ARCH" = "arm64" ]; then
69
+ shasum -a 256 -c checklist.chk  # macOS (arm64) typically lacks sha256sum; shasum verifies the same list
70
+ else
71
+ sha256sum -c checklist.chk
72
+ fi
73
+ ```
74
+
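+ For example (the folder names below are placeholders; point them at your actual download location):
+ ```bash
+ # Hypothetical paths for illustration only
+ export TARGET_FOLDER=/data/models
+ export MODEL_FOLDER_PATH=openPangu-Ultra-MoE-718B
+ cd "${TARGET_FOLDER}/${MODEL_FOLDER_PATH}" && sha256sum -c checklist.chk   # every line should end in ": OK"
+ ```
+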
75
+ ### 4.3 Model Weights Conversion
76
+ This inference example of openPangu-Ultra-MoE-718B adopts a Tensor Parallel strategy combined with fused Ascend NPU operators, which requires the safetensors weights to be pre-sharded. The following provides an example of weight sharding for 32-NPU parallel inference, with the split weights saved in the `model/` directory.
77
+ ```bash
78
+ cd inference
79
+ bash split_weight.sh
80
+ ```
81
+
82
+ ### 4.4 Inference Examples
83
+ The following provides a simple bfloat16 inference example of the openPangu-Ultra-MoE-718B deployed on a 4-node 32-NPU Atlas 800T A2 cluster, for which the node IP0 is selected as the master node:
84
+ ```bash
85
+ cd inference
86
+ # Master node IP0: ${NNODES} ${NODE_RANK} ${NPROC_PER_NODE} ${MASTER_ADDR} ${PROMPT}
87
+ bash generate.sh 4 0 8 IP0 "3*7=?"
88
+ # Worker node IP1
89
+ bash generate.sh 4 1 8 IP0 "3*7=?"
90
+ # Worker node IP2
91
+ bash generate.sh 4 2 8 IP0 "3*7=?"
92
+ # Worker node IP3
93
+ bash generate.sh 4 3 8 IP0 "3*7=?"
94
+ ```
95
+ The model operates in slow-thinking mode by default. To switch the current turn to fast-thinking mode, append the ` /no_think` flag to the end of the user input, as demonstrated by the `fast_thinking_template` in the `generate.py` example.
96
+
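+ For example, the multi-node launch above can answer the same question in fast-thinking mode by appending the flag to the prompt argument (a sketch; each node passes the same prompt):
+ ```bash
+ # Master node IP0; the trailing " /no_think" switches this turn to fast thinking
+ bash generate.sh 4 0 8 IP0 "3*7=? /no_think"
+ ```
+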
97
+ ### 4.5 Using Inference Framework
98
+ vllm-ascend: please refer to [[vllm_ascend_for_openpangu_ultra_moe_718b_EN](doc/vllm_ascend_for_openpangu_ultra_moe_718b_EN.md)]
99
+
100
+ ## 5. Model License
101
+ Unless otherwise noted, the openPangu-Ultra-MoE-718B model is licensed under the terms and conditions of the OPENPANGU MODEL LICENSE AGREEMENT VERSION 1.0, which is intended to allow use of the model and promote the further development of artificial intelligence technologies. Please refer to the [LICENSE](LICENSE) file located in the root directory of the model repository for details.
102
+
103
+ ## 6. Disclaimer
104
+ Due to the technical limitations inherent in the technology on which the openPangu-Ultra-MoE-718B (“Model”) relies, and because AI-generated content is produced automatically by the Model, Huawei cannot make any guarantees regarding the following matters:
105
+
106
+ - The output of this Model is automatically generated by AI algorithms; it cannot be ruled out that some of the information may be flawed, unreasonable, or cause discomfort, and the generated content does not represent Huawei's attitude or standpoint;
107
+ - There is no guarantee that this Model is 100% accurate, reliable, functional, timely, secure, safe, error-free, uninterrupted, continuously stable, or free of any faults;
108
+ - The output of this Model does not constitute any advice or decision for you, nor does it guarantee the authenticity, completeness, accuracy, timeliness, legality, functionality, or practicality of the generated content. The generated content cannot replace professionals in medical, legal, and other fields in answering your questions. The generated content is for your reference only and does not represent any attitude, standpoint, or position of Huawei. You need to make independent judgments based on your actual situation, and Huawei does not assume any responsibilities.
109
+
110
+ ## 7. Contact
111
+ If you have any questions, please raise an issue or contact us at [openPangu@huawei.com](mailto:openPangu@huawei.com).
checklist.chk ADDED
@@ -0,0 +1,99 @@
1
+ 714e6540d1371ad78a5816a6c69e2f0e330594939dd14b9dac7041e784d701a4 *./tokenizer_config.json
2
+ 6b16f1558c0cd4ae6ef1a2c605713be0a514f50e1ce2d2c878979ce988c148ec *./tokenizer.model
3
+ 81c7a7c24ed70acdeeff295a5a091a401cacfb49dcc30308a5b451956b051223 *./tokenization_openpangu.py
4
+ b34cf5e7c7660889303b6e2d0a346c440356385c9db551d06f6615cf9fc600d1 *./special_tokens_map.json
5
+ 035c0bc169b317e19096ac2f81d6fa790850533f573bedda9d3e88a1183aed9b *./modeling_openpangu_moe.py
6
+ 6580557dcb86f3285d3c056eeb21eb9b702e0fa908577ba471cc91ec6c9b3fcc *./model.safetensors.index.json
7
+ bc6505adabc0498ad07b49187858788c65c13dbf9446fd0bcf177a3e1b27220d *./inference/vllm_ascend/worker/npu_input_batch.py
8
+ 62c6734d1283e3d649a6478d2004f46bfee2f7878af7f2849c979b124e355302 *./inference/vllm_ascend/worker/model_runner_v1.py
9
+ e2457c558f048876afe069d1226e7080ac214478f1a9ac28ae472928b81b5a06 *./inference/vllm_ascend/utils.py
10
+ 6adfaa8a67ea9b561dec2e6a2392f6fc85ff376fb2030d8761c34c6c6d3f4cbf *./inference/vllm_ascend/quantization/w8a8_dynamic.py
11
+ 743bd96cfc109975a11fe5412c4b5de46f880501dcbbbdd10e11cbeb865fa4f2 *./inference/vllm_ascend/quantization/w8a8.py
12
+ e712ea36caf16c2a9dd21c5288f9d8e34c7fd2face444da44dca6db6c21f6c1b *./inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py
13
+ 8c59df8086bde0cd4df674403f83000921a34403651a8ff2b31de9b28768247a *./inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py
14
+ 8436ab93933989431160e55627b5dce5326f0fc5ec18263653902764ac8ace7b *./inference/vllm_ascend/patch/worker/patch_common/patch_config.py
15
+ 63a6ba0d0b0158d4586219c979bf96d5fe87b74123af93f1c8d9ed842db96500 *./inference/vllm_ascend/patch/worker/patch_common/__init__.py
16
+ 09273eb0e4696d2fb530881ba1ad9d331897dd81c0cd2f203ed3d0a475b4d39b *./inference/vllm_ascend/ops/fused_moe.py
17
+ b654e72ece161b3f04080e5c4d2476641c024939ac5308115fe1c65a6c5c7215 *./inference/vllm_ascend/models/open_pangu.py
18
+ e98aa2549f02017a35b07499216fe569e86400684087821820cf2d971c8fcbac *./inference/vllm_ascend/models/__init__.py
19
+ 52a968f10ebaebeb626248afd3e1d1b92f8fbfcaad19ebf05cafbc0bd03192cb *./inference/vllm_ascend/envs.py
20
+ 91eab52cdc19603b7b705b302e25345d849e18fa66875261a1135d5382392123 *./inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py
21
+ d07256c9014f911f81269e65aad6c0d7dd61d4e82f5cb399e05285d5c1bc8fa8 *./inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py
22
+ f9577c29bc4dc19a4cc41ccfcca17065402c9dd92221bef987c74808b23ed124 *./inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py
23
+ 9070682b058a79d2b2874ba5e07ce72beff6efb870f75cdac30cdcf6ba8fadc7 *./inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py
24
+ 2254aeca0be7b8922318e10c4a950f39afb30ba5fe3b46564a58671b237ac612 *./inference/vllm_ascend/attention/mla_v1.py
25
+ ba6d7edcf1cf464d6fd787b12a9bda2a16fea0ac0d5d1e54136baec503d6e696 *./inference/vllm_ascend/attention/attention.py
26
+ 4aaf57e6f6d2e139b3847b10ee59d738398ebfc4927a22325b27dad384874aec *./inference/vllm_ascend/_build_info.py
27
+ 5cd02a8ec3b7494e08c6cea07927c377b5da177b9cf704988d3c2ca367528c09 *./inference/split_weight.sh
28
+ a2ea110364d7eecf039e180b42986f1c830acc3fd6ac487f45f93f89f285c669 *./inference/split_weight.py
29
+ bd8f9bb4a9cd7d182a499307746c939514b4487ebbecbdf9527482f2c31aed9a *./inference/runner_config/tp32.yaml
30
+ 379d51e424a24b6599a7e5b85187ced84a8d016c051133432d0a4eaa58a5c900 *./inference/runner_config/tp1.yaml
31
+ 72f5bf3c6e4a990d5c9b2412f4b6bf9534b99ce7e1c196f7a42740a6f006e7ad *./inference/runner.py
32
+ 85841e6a1bc03eff679e3cf00958864f86e3260dc90e69379132f2e3dc6674ad *./inference/model.py
33
+ d646343af70b851f1942ee6dbdc1807686c13a48d4d660ebac777dba29eafdd1 *./inference/generate.sh
34
+ d4d48848ec2c7945670a6196323a570d1830e4edcf5336919641e58bcbc9da0a *./inference/generate.py
35
+ feb36ae08104d00af5bd32e6e20b67a11588e02911b15b3261650b22e1db3ad8 *./generation_config.json
36
+ 634ef0ce7809d0e44d31fecf72863146a08e609a8ba9cbe16b427f0de12fe2e0 *./configuration_openpangu_moe.py
37
+ 9988eca928867694568154f72de3cedaca34b258954652b3c95772d9a5f5118e *./config.json
38
+ fca1e83c102d7a4b9c9f1bcfdc8766df2037d77550ca33dc45c8342fd5b73d0d *./model-00062-of-000062.safetensors
39
+ 0548a05b2afa374e731f8785ad5ee7302335dcba952d3797b759ad920f3f4fce *./model-00061-of-000062.safetensors
40
+ 4351cc1440af69fca1735baa78c09658b08c9ae67fcc076ad4fa9ea2d25e084c *./model-00060-of-000062.safetensors
41
+ 8802bb0b554a1e6404b58b13f328c82fb4f2e2d6d7524f7b4f14d7bb6e81d0f3 *./model-00059-of-000062.safetensors
42
+ 59a1f7667adcfff55145b4b21789e7c494aa80d8bc2fa466d0fa43cd0f3ff43e *./model-00058-of-000062.safetensors
43
+ 65c1cd81cf879606bcddc4af60b506e9bcda80955be72693e642d3a325ffd8e7 *./model-00057-of-000062.safetensors
44
+ b72df9faac3f73cf7191dbb454229f74bc22e1a98d5ad7ef2aea85e0b14d4123 *./model-00056-of-000062.safetensors
45
+ 4ded3c2baad08d7646cf3da25dcecda3bcbdd542f90acf76848d93000b3ab23b *./model-00055-of-000062.safetensors
46
+ 86cc6e275ab190244cc2bfc0cff5689fb7c145623ec54d03ed351df6af829ec2 *./model-00054-of-000062.safetensors
47
+ a5bba51810248ad0a9a54e55b56b946732737ee5f5f289fab580b7c36f21d3f6 *./model-00053-of-000062.safetensors
48
+ cd153a6a0fd5dea192768257c71d47655d7a7a1dbdffdaffeab4dfb7ad27b3eb *./model-00052-of-000062.safetensors
49
+ 85000bb98b77232c47542490df67faec31e3d5105b7d4de34e46466d53220bd6 *./model-00051-of-000062.safetensors
50
+ 0da40653e3dcb3a32aa929674062224f3a74fa3eeefb6dcc5a6698cd9f596708 *./model-00050-of-000062.safetensors
51
+ 678d2b3ac3c73f387a18a446d83dd866ffc925f82ff56541d3b71a747c6b7d06 *./model-00049-of-000062.safetensors
52
+ a49b5e3f1be43a6a3bff9fec25d5d8dffad5a8ebd96c8e389933f40287f1888e *./model-00048-of-000062.safetensors
53
+ 54d984279e435df0426cb052ed99067b7a4a1637234a24a8f05bcb7bbd89d0d2 *./model-00047-of-000062.safetensors
54
+ 4fe5e98cb4e492639bafe37c3e3144c3fe8f27a9fd45e48c92c780f39599becc *./model-00046-of-000062.safetensors
55
+ 97458250006f5949c65b338a26503738200c5fb2415f4cc664a6b224aa9dce70 *./model-00045-of-000062.safetensors
56
+ fd548640dbe4cc04ef4834ac32cada57fb43b0fb9971f0e36030d94041dd1b0d *./model-00044-of-000062.safetensors
57
+ 58847a3be6d606e21941769346e86165891f4fa7242cc84cda7edc9620298ad2 *./model-00043-of-000062.safetensors
58
+ f0be4dc1d9326543061e858d566955aefa9d4a757ebd8f92df648bd9c16a236b *./model-00042-of-000062.safetensors
59
+ d6de2cffa758d32d790c89ac7df7aa313ec1a2923faf52549357a3e8ff16d74f *./model-00041-of-000062.safetensors
60
+ cf9fb7c2ca6e977d9e6f19bd3204ba8b8ad15d33cef9a0ad9f6e9b867319fc8b *./model-00040-of-000062.safetensors
61
+ 2f12c46798cd8f51740964ea58b59ef682eca6ee2ae9402283214f1f0fd4113c *./model-00039-of-000062.safetensors
62
+ 0a531b364281f6060567549b741a1a53a81c1d9211170354bf32c88494b486e9 *./model-00038-of-000062.safetensors
63
+ e3fe1e5795ffb480aa934410434f30711a4dc982f7eea3df5ac1da57783bc619 *./model-00037-of-000062.safetensors
64
+ e36d0dbfbfbb7a792246c931c89d95be2e7afff1872ff455114c58d82b7398d2 *./model-00036-of-000062.safetensors
65
+ af8bfd55a58902cdd8293c67f2c9c5e069ca38f3086645a05bfb9140fe6d51ba *./model-00035-of-000062.safetensors
66
+ 9aa7ab7d596db78af87e87a92e8111fc673000b3132cffe02fa2e164d77b8b32 *./model-00034-of-000062.safetensors
67
+ dfa6683035c9caca4de02afc6b3b4fc5fe786e0ee687ddfa1f72c4171eb33821 *./model-00033-of-000062.safetensors
68
+ cdb597fd48c542dd1508d9ce069b8e7c19007908fbdd2dc8a443d73bee6f1754 *./model-00032-of-000062.safetensors
69
+ 457f476851463f36e3fdcb54e54dc39c9890e64b6b06b100ac5013c3c72385e4 *./model-00031-of-000062.safetensors
70
+ 3d1ee1e248181a08f2401d314f71d342a95e8c9fbee5877800b392be54b1343f *./model-00030-of-000062.safetensors
71
+ 1856663f2077287ba21f3bd821705c1f55bac979e768d67fc31c52727b34ae2f *./model-00029-of-000062.safetensors
72
+ 0cfe77fd4bfac0f9623411d1234872e141b211eb0b3cf61238380f5b34c3c043 *./model-00028-of-000062.safetensors
73
+ 933263ed6db42b1e16407b4400b70260425b1cc11f9ed8fdaad6ef5935f05fb4 *./model-00027-of-000062.safetensors
74
+ d15f1371a364df11676b105291e481fbbd1999ea2152ec0b14905f5d9cb854fa *./model-00026-of-000062.safetensors
75
+ 744dedfcb6bc74624351cc299772ee6be389147301b05f4fe645ebac7bedb53b *./model-00025-of-000062.safetensors
76
+ fd9cf078f3a819e230a78fdf8201e37fd25696f576e1fd0d46fb122deb11c2c8 *./model-00024-of-000062.safetensors
77
+ 3ec91acf57a576b8550d14312dda2acf345ce09769407d1596e0cbdff5a1200a *./model-00023-of-000062.safetensors
78
+ 6df9b0923f9ec0ca7ab20dd0934fc36467d02f02d0cb996904b2f1181d37502b *./model-00022-of-000062.safetensors
79
+ 0e4ac1edb16a327ed30dd75325fdb9d6e7da6d35724fc086c66c74822d6a1de7 *./model-00021-of-000062.safetensors
80
+ 9837d1382bdab63f5781e6b76f3634775a94cd792adadb4d51763693aa670c36 *./model-00020-of-000062.safetensors
81
+ 6a1de462c450f0ddc2e379224cc5ca0ebd31b9bcb9102ed4eb7effe8298de963 *./model-00019-of-000062.safetensors
82
+ 1bf72bcf187656c13ac2f003fa534bbb79d22374711a7d5315766b2299800c4e *./model-00018-of-000062.safetensors
83
+ 41bda994ada5d86c56f660db9350acd900201d1bc131379ce892e60564154e5f *./model-00017-of-000062.safetensors
84
+ 009a80118b5e38fc072a5bcaf20efa832148c14e55765ae2e4176edda11d6608 *./model-00016-of-000062.safetensors
85
+ 310713a4eeebdb60bd62b1c9c0a51bcd6efd6fdefde5c31d4adc8ce854d06f23 *./model-00015-of-000062.safetensors
86
+ 91ef0b66652923f58267a69bb39a18da89ce71732a7a92b458447bb938fb17e7 *./model-00014-of-000062.safetensors
87
+ deabb0cc16d4bcaa6576583acefa13c92e0bc16c6209e8db0fdea2bb55b45501 *./model-00013-of-000062.safetensors
88
+ 95d7980844ee1411f71358f8268d537f0f6c2d5d87a15e101e956d1b3c9c61b2 *./model-00012-of-000062.safetensors
89
+ 4661dd325fda474f8490505bfe9b3652ae2976284825d17cfb7bd60d01aacae2 *./model-00011-of-000062.safetensors
90
+ 9f3168f097ba85d266b062732aab8cb28daae08fef108305291819750be1b384 *./model-00010-of-000062.safetensors
91
+ 91547f7296b08e93903a250eec9a25562f714a7bdfab511ae4fd0f358aaec832 *./model-00009-of-000062.safetensors
92
+ 5ee6ca3f506708b453076c56027fe7366958ef4baa06fdd69f0992e15544ad17 *./model-00008-of-000062.safetensors
93
+ 07fa95a4e6b3e9b3475076f160e839cb715ab426fb78a63d8e2decb636cb8987 *./model-00007-of-000062.safetensors
94
+ a0b071272706a4d4d4ed5a331d60590930a3529c13c4bc158b6f1b0bc3dd8c85 *./model-00006-of-000062.safetensors
95
+ 4e3907f683d7f8382d2a792304155e8533ffa3a94dd4bb5ff825124b0dba3835 *./model-00005-of-000062.safetensors
96
+ 2bd5c0012a3cedf4582160173f857480dd58426282a0e5826609a02b5aff5b3e *./model-00004-of-000062.safetensors
97
+ 3f63aa17d947032e0a524b5798eee3becbfc9a9b6f8a352ead3232e7b34bb289 *./model-00003-of-000062.safetensors
98
+ 1e29a512e3737d1826c80a2277a8b42021878847753aadbe5e1ae2a2df3d7f8d *./model-00002-of-000062.safetensors
99
+ c692b00cbc19ee5a2d4a9bb71496f12a846d14e06d7e0ac79b46abc3243ee115 *./model-00001-of-000062.safetensors
config.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "architectures": [
3
+ "PanguUltraMoEForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_openpangu_moe.PanguUltraMoEConfig",
8
+ "AutoModel": "modeling_openpangu_moe.PanguUltraMoEModel",
9
+ "AutoModelForCausalLM": "modeling_openpangu_moe.PanguUltraMoEForCausalLM"
10
+ },
11
+ "num_dense_layers": 3,
12
+ "hidden_act": "silu",
13
+ "hidden_size": 7680,
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 18432,
16
+ "attention_kv_lora_dim": 512,
17
+ "max_position_embeddings": 131072,
18
+ "model_type": "pangu_ultra_moe",
19
+ "moe_intermediate_size": 2048,
20
+ "num_routed_experts": 256,
21
+ "num_shared_experts": 1,
22
+ "num_attention_heads": 128,
23
+ "num_experts_per_tok": 8,
24
+ "num_hidden_layers": 61,
25
+ "num_key_value_heads": 128,
26
+ "num_mtp_layers": 1,
27
+ "attention_q_lora_dim": 1536,
28
+ "attention_qk_dim": 128,
29
+ "attention_qk_rope_dim": 64,
30
+ "rms_norm_eps": 1e-05,
31
+ "rope_theta": 25600000,
32
+ "routed_scaling_factor": 2.5,
33
+ "sandwich_norm": true,
34
+ "tie_word_embeddings": false,
35
+ "torch_dtype": "bfloat16",
36
+ "transformers_version": "4.48.2",
37
+ "use_cache": true,
38
+ "attention_v_dim": 128,
39
+ "vocab_size": 153600
40
+ }
configuration_openpangu_moe.py ADDED
@@ -0,0 +1,82 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+
4
+ """openPanguUltraMoE 718B model configuration"""
5
+
6
+ from transformers.configuration_utils import PretrainedConfig
7
+
8
+
9
+ class PanguUltraMoEConfig(PretrainedConfig):
10
+
11
+ model_type = "pangu_ultra_moe"
12
+ keys_to_ignore_at_inference = ["past_key_values"]
13
+
14
+ def __init__(
15
+ self,
16
+ vocab_size=153600,
17
+ hidden_size=7680,
18
+ intermediate_size=18432,
19
+ moe_intermediate_size=2048,
20
+ num_hidden_layers=61,
21
+ num_mtp_layers=1,
22
+ num_attention_heads=128,
23
+ num_key_value_heads=128,
24
+ num_shared_experts=1,
25
+ num_routed_experts=256,
26
+ routed_scaling_factor=2.5,
27
+ attention_kv_lora_dim=512,
28
+ attention_q_lora_dim=1536,
29
+ attention_qk_rope_dim=64,
30
+ attention_v_dim=128,
31
+ attention_qk_dim=128,
32
+ num_experts_per_tok=8,
33
+ num_dense_layers=3,
34
+ norm_topk_prob=True,
35
+ hidden_act="silu",
36
+ max_position_embeddings=131072,
37
+ initializer_range=0.02,
38
+ rms_norm_eps=1e-5,
39
+ use_cache=True,
40
+ pad_token_id=None,
41
+ bos_token_id=0,
42
+ eos_token_id=1,
43
+ tie_word_embeddings=False,
44
+ rope_theta=25600000,
45
+ attention_dropout=0.0,
46
+ **kwargs,
47
+ ):
48
+ self.vocab_size = vocab_size
49
+ self.max_position_embeddings = max_position_embeddings
50
+ self.hidden_size = hidden_size
51
+ self.num_hidden_layers = num_hidden_layers
52
+ self.num_attention_heads = num_attention_heads
53
+ self.num_key_value_heads = num_key_value_heads
54
+ self.hidden_act = hidden_act
55
+ self.initializer_range = initializer_range
56
+ self.rms_norm_eps = rms_norm_eps
57
+ self.use_cache = use_cache
58
+ self.rope_theta = rope_theta
59
+
60
+ self.num_dense_layers = num_dense_layers
61
+ self.intermediate_size = intermediate_size
62
+ self.moe_intermediate_size = moe_intermediate_size
63
+ self.num_shared_experts = num_shared_experts
64
+ self.num_routed_experts = num_routed_experts
65
+ self.routed_scaling_factor = routed_scaling_factor
66
+ self.num_experts_per_tok = num_experts_per_tok
67
+ self.norm_topk_prob = norm_topk_prob
68
+ self.attention_kv_lora_dim = attention_kv_lora_dim
69
+ self.attention_q_lora_dim = attention_q_lora_dim
70
+ self.attention_qk_rope_dim = attention_qk_rope_dim
71
+ self.attention_v_dim = attention_v_dim
72
+ self.attention_qk_dim = attention_qk_dim
73
+ self.attention_dropout = attention_dropout
74
+ self.num_mtp_layers = num_mtp_layers
75
+
76
+ super().__init__(
77
+ pad_token_id=pad_token_id,
78
+ bos_token_id=bos_token_id,
79
+ eos_token_id=eos_token_id,
80
+ tie_word_embeddings=tie_word_embeddings,
81
+ **kwargs,
82
+ )
doc/docker.md ADDED
@@ -0,0 +1,31 @@
1
+ ## Atlas 800T A2一键部署
2
+
3
+ 参考使用vllm-ascend社区开源[镜像](https://vllm-ascend.readthedocs.io/en/latest/quick_start.html)
4
+
5
+ ```bash
6
+ export IMAGE=m.daocloud.io/quay.io/ascend/vllm-ascend:v0.9.2rc1
7
+ export NAME=atlas_800t_a2
8
+ docker run --rm \
9
+ --name $NAME \
10
+ --net=host \
11
+ --device /dev/davinci0 \
12
+ --device /dev/davinci1 \
13
+ --device /dev/davinci2 \
14
+ --device /dev/davinci3 \
15
+ --device /dev/davinci4 \
16
+ --device /dev/davinci5 \
17
+ --device /dev/davinci6 \
18
+ --device /dev/davinci7 \
19
+ --device /dev/davinci_manager \
20
+ --device /dev/devmm_svm \
21
+ --device /dev/hisi_hdc \
22
+ -v /usr/local/dcmi:/usr/local/dcmi \
23
+ -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
24
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
25
+ -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
26
+ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
27
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
28
+ -v /data:/data \
29
+ -v /tmp:/tmp \
30
+ -it $IMAGE bash
31
+ ```
doc/docker_EN.md ADDED
@@ -0,0 +1,31 @@
1
+ ## Deploy on Atlas 800T A2
2
+
3
+ Please refer to the vllm-ascend community open-source [image](https://vllm-ascend.readthedocs.io/en/latest/quick_start.html)
4
+
5
+ ```bash
6
+ export IMAGE=m.daocloud.io/quay.io/ascend/vllm-ascend:v0.9.2rc1
7
+ export NAME=atlas_800t_a2
8
+ docker run --rm \
9
+ --name $NAME \
10
+ --net=host \
11
+ --device /dev/davinci0 \
12
+ --device /dev/davinci1 \
13
+ --device /dev/davinci2 \
14
+ --device /dev/davinci3 \
15
+ --device /dev/davinci4 \
16
+ --device /dev/davinci5 \
17
+ --device /dev/davinci6 \
18
+ --device /dev/davinci7 \
19
+ --device /dev/davinci_manager \
20
+ --device /dev/devmm_svm \
21
+ --device /dev/hisi_hdc \
22
+ -v /usr/local/dcmi:/usr/local/dcmi \
23
+ -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
24
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
25
+ -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
26
+ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
27
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
28
+ -v /data:/data \
29
+ -v /tmp:/tmp \
30
+ -it $IMAGE bash
31
+ ```
doc/vllm_ascend_for_openpangu_ultra_moe_718b.md ADDED
@@ -0,0 +1,215 @@
1
+ ## openPangu-Ultra-MoE-718B在[vllm-ascend](https://github.com/vllm-project/vllm-ascend)部署指导文档
2
+
3
+ ### 部署环境说明
4
+
5
+ Atlas 800T A2(64GB) 64卡可以部署openPangu-Ultra-MoE-718B(bf16),32卡可部署盘古 Ultra MoE (int8),选用vllm-ascend社区镜像v0.9.1-dev,多个节点都需拉取镜像。
6
+ ```bash
7
+ docker pull quay.io/ascend/vllm-ascend:v0.9.1-dev
8
+ ```
9
+
10
+ * 网络环境检测
11
+ 在每个节点上依次执行以下命令。所有结果必须为 success 且状态必须为 UP:
12
+ ```bash
13
+ # Check the remote switch ports
14
+ for i in {0..7}; do hccn_tool -i $i -lldp -g | grep Ifname; done
15
+ # Get the link status of the Ethernet ports (UP or DOWN)
16
+ for i in {0..7}; do hccn_tool -i $i -link -g ; done
17
+ # Check the network health status
18
+ for i in {0..7}; do hccn_tool -i $i -net_health -g ; done
19
+ # View the network detected IP configuration
20
+ for i in {0..7}; do hccn_tool -i $i -netdetect -g ; done
21
+ # View gateway configuration
22
+ for i in {0..7}; do hccn_tool -i $i -gateway -g ; done
23
+ # View NPU network configuration
24
+ cat /etc/hccn.conf
25
+ ```
26
+
27
+ ### 镜像启动和推理代码适配
28
+
29
+ 以下操作需在每个节点都执行。
30
+
31
+ 启动镜像。
32
+ ```bash
33
+ # Update the vllm-ascend image
34
+ export IMAGE=quay.io/ascend/vllm-ascend:v0.9.1-dev # Use correct image id
35
+ export NAME=vllm-ascend # Custom docker name
36
+
37
+ # Run the container using the defined variables
38
+ # Note if you are running bridge network with docker, Please expose available ports for multiple nodes communication in advance
39
+ # To prevent device interference from other docker containers, add the argument "--privileged"
40
+ docker run --rm \
41
+ --name $NAME \
42
+ --network host \
43
+ --device /dev/davinci0 \
44
+ --device /dev/davinci1 \
45
+ --device /dev/davinci2 \
46
+ --device /dev/davinci3 \
47
+ --device /dev/davinci4 \
48
+ --device /dev/davinci5 \
49
+ --device /dev/davinci6 \
50
+ --device /dev/davinci7 \
51
+ --device /dev/davinci_manager \
52
+ --device /dev/devmm_svm \
53
+ --device /dev/hisi_hdc \
54
+ -v /usr/local/dcmi:/usr/local/dcmi \
55
+ -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
56
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
57
+ -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
58
+ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
59
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
60
+ -v /mnt/sfs_turbo/.cache:/root/.cache \
61
+ -it $IMAGE bash
62
+ ```
63
+
64
+ 如果未进入容器,需以root用户进入容器。
65
+ ```
66
+ docker exec -itu root $NAME /bin/bash
67
+ ```
68
+
69
+ 下载vllm(v0.9.2),替换镜像内置的vllm代码。
70
+ ```bash
71
+ pip install --no-deps vllm==0.9.2 pybase64==1.4.1
72
+ ```
73
+
74
+ 下载[vllm-ascend (v0.9.2rc1)](https://github.com/vllm-project/vllm-ascend/releases/tag/v0.9.2rc1),替换镜像内置的vllm-ascend代码(`/vllm-workspace/vllm-ascend/`)。例如下载Assets中的[Source code
75
+ (tar.gz)](https://github.com/vllm-project/vllm-ascend/archive/refs/tags/v0.9.2rc1.tar.gz)得到v0.9.2rc1.tar.gz,然后解压并替换:
76
+ ```bash
77
+ tar -zxvf vllm-ascend-0.9.2rc1.tar.gz -C /vllm-workspace/vllm-ascend/ --strip-components=1
78
+ export PYTHONPATH=/vllm-workspace/vllm-ascend/:${PYTHONPATH}
79
+ ```
80
+
81
+ 使用当前代码仓中适配盘古模型的vllm-ascend代码替换`/vllm-workspace/vllm-ascend/vllm_ascend/`中的部分代码。
82
+ ```bash
83
+ yes | cp -r inference/vllm_ascend/* /vllm-workspace/vllm-ascend/vllm_ascend/
84
+ ```
85
+
86
+ ### BF16推理
87
+
88
+ 以下操作需在每个节点都执行。
89
+
90
+ 运行命令:
91
+ ```bash
92
+ # This obtained through ifconfig
93
+ # nic_name is the network interface name corresponding to local_ip
94
+ local_ip=`hostname -I | cut -d' ' -f1`
95
+ nic_name=$(ifconfig | grep -B 1 "$local_ip" | head -n 1 | awk '{print $1}' | sed 's/://')
96
+ export HCCL_IF_IP=$local_ip
97
+ export GLOO_SOCKET_IFNAME=$nic_name
98
+ export TP_SOCKET_IFNAME=$nic_name
99
+ export HCCL_SOCKET_IFNAME=$nic_name
100
+ export OMP_PROC_BIND=false
101
+ export OMP_NUM_THREADS=100
102
+ export VLLM_USE_V1=1
103
+ export HCCL_BUFFSIZE=1024
104
+ export VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP=1
105
+ export VLLM_ASCEND_ENABLE_TOP_N_SIGMA=1 # enable top-n-sigma sampling
106
+ export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
107
+
108
+ MASTER_NODE_IP=xxx.xxx.xxx.xxx # master/head node ip
109
+ NODE_RANK=xxx # current node rank (0~7)
110
+ NUM_NODES=8 # number of nodes
111
+ NUM_NPUS_LOCAL=8 # number of NPUs per node
112
+ DATA_PARALLEL_SIZE_LOCAL=4 # DP size per node, can be set to 1, 2, or 4
113
+ LOCAL_CKPT_DIR=/root/.cache/pangu_ultra_moe # The pangu_ultra_moe bf16 weight
114
+ # Specifying HOST=127.0.0.1 (localhost) means the server can only be accessed from the master device.
115
+ # Specifying HOST=0.0.0.0 allows the vLLM server to be accessed from other devices on the same network or even from the internet, provided proper network configuration (e.g., firewall rules, port forwarding) is in place.
116
+ HOST=xxx.xxx.xxx.xxx
117
+
118
+ if [[ $NODE_RANK -ne 0 ]]; then
119
+ headless="--headless"
120
+ else
121
+ headless=""
122
+ fi
123
+
124
+ vllm serve $LOCAL_CKPT_DIR \
125
+ --host $HOST \
126
+ --port 8004 \
127
+ --data-parallel-size $((NUM_NODES*DATA_PARALLEL_SIZE_LOCAL)) \
128
+ --data-parallel-size-local $DATA_PARALLEL_SIZE_LOCAL \
129
+ --data-parallel-start-rank $((DATA_PARALLEL_SIZE_LOCAL*NODE_RANK)) \
130
+ --data-parallel-address $MASTER_NODE_IP \
131
+ --data-parallel-rpc-port 13389 \
132
+ --tensor-parallel-size $((NUM_NPUS_LOCAL/DATA_PARALLEL_SIZE_LOCAL)) \
133
+ --seed 1024 \
134
+ --served-model-name pangu_ultra_moe \
135
+ --enable-expert-parallel \
136
+ --max-num-seqs 8 \
137
+ --max-model-len 32768 \
138
+ --max-num-batched-tokens 4096 \
139
+ --trust-remote-code \
140
+ --no-enable-prefix-caching \
141
+ --gpu-memory-utilization 0.9 \
142
+ ${headless} \
143
+ --additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'
144
+ ```
145
+
146
+ ### 发请求测试
147
+
148
+ 服务启动后,在主节点或者其他节点向主节点发送测试请求:
149
+
150
+ ```bash
151
+ curl http://${MASTER_NODE_IP}:8004/v1/chat/completions \
152
+ -H "Content-Type: application/json" \
153
+ -d '{
154
+ "model": "pangu_ultra_moe",
155
+ "messages": [
156
+ {
157
+ "role": "user",
158
+ "content": "Who are you?"
159
+ }
160
+ ],
161
+ "max_tokens": 512,
162
+ "temperature": 0.7,
163
+ "top_p": 1.0,
164
+ "top_k": -1,
165
+ "vllm_xargs": {"top_n_sigma": 0.05}
166
+ }'
167
+ ```
168
+
169
+ ### Int8推理
170
+
171
+ #### ModelSlim量化
172
+
173
+ openPangu-Ultra-MoE-718B模型支持使用开源量化框架[ModelSlim](https://gitcode.com/Ascend/msit/blob/br_noncom_pangu_ultra_moe_8.1.RC1_POC_20251231/msmodelslim/example/Pangu/README.md)进行量化,当前模型支持W8A8权重激活量化。
174
+
175
+ ##### openPangu-Ultra-MoE-718B W8A8 动态量化
176
+
177
+ ```bash
178
+ python3 quant_pangu_ultra_moe_w8a8.py --model_path {浮点权重路径} --save_path {W8A8量化权重路径} --dynamic
179
+ ```
180
+
181
+ ##### openPangu-Ultra-MoE-718B W8A8 混合量化 + MTP 量化
182
+
183
+ 生成openPangu-Ultra-MoE-718B模型W8A8量化权重(含MTP)
184
+ ```bash
185
+ python3 quant_pangu_ultra_moe_w8a8.py --model_path {浮点权重路径} --save_path {W8A8量化权重路径} --dynamic --quant_mtp mix
186
+ ```
187
+
188
+ 相较于BF16模型,int8量化模型的config.json增加以下字段:
189
+ ```
190
+ "mla_quantize": "w8a8",
191
+ "quantize": "w8a8_dynamic",
192
+ ```
193
+
194
+ 如果MTP量化,增加字段:
195
+ ```
196
+ "mtp_quantize": "w8a8_dynamic",
197
+ ```
198
+ ModelSlim量化脚本生成量化模型后会自动追加上述字段到config.json中。
199
+
200
+ #### Int8推理
201
+
202
+ 相较于BF16模型推理,int8量化模型推理仅需使用4节点(32卡),修改变量
203
+ ```bash
204
+ NUM_NODES=4
205
+ ```
206
+
207
+ 启动命令需要修改为对应的量化权重路径,另外增加`--quantization ascend`:
208
+ ```bash
209
+ LOCAL_CKPT_DIR=/root/.cache/pang_ultra_moe_w8a8
210
+
211
+ vllm serve $LOCAL_CKPT_DIR \
212
+ ...
213
+ --quantization ascend
214
+ ...
215
+ ```
doc/vllm_ascend_for_openpangu_ultra_moe_718b_EN.md ADDED
@@ -0,0 +1,216 @@
1
+ ## Deployment Guide of the openPangu-Ultra-MoE-718B Based on [vllm-ascend](https://github.com/vllm-project/vllm-ascend)
2
+
3
+ ### Deployment Environment Description
4
+
5
+ The Atlas 800T A2 (64 GB) can deploy openPangu-Ultra-MoE-718B (bf16) on 64 NPUs and openPangu-Ultra-MoE-718B (int8) on 32 NPUs. The vllm-ascend community image v0.9.1-dev is used and must be pulled on every node.
6
+ ```bash
7
+ docker pull quay.io/ascend/vllm-ascend:v0.9.1-dev
8
+ ```
9
+
10
+ * Network Detection
11
+ Run the following commands on each node: All the results must be success and the status must be UP.
12
+ ```bash
13
+ # Check the remote switch ports
14
+ for i in {0..7}; do hccn_tool -i $i -lldp -g | grep Ifname; done
15
+ # Get the link status of the Ethernet ports (UP or DOWN)
16
+ for i in {0..7}; do hccn_tool -i $i -link -g ; done
17
+ # Check the network health status
18
+ for i in {0..7}; do hccn_tool -i $i -net_health -g ; done
19
+ # View the network detected IP configuration
20
+ for i in {0..7}; do hccn_tool -i $i -netdetect -g ; done
21
+ # View gateway configuration
22
+ for i in {0..7}; do hccn_tool -i $i -gateway -g ; done
23
+ # View NPU network configuration
24
+ cat /etc/hccn.conf
25
+ ```
26
+
27
+ ### Docker Startup and Inference Code Adaptation
28
+
29
+ Perform the following operations on all nodes.
30
+
31
+ Run the following command to start the container:
32
+ ```bash
33
+ # Update the vllm-ascend image
34
+ export IMAGE=quay.io/ascend/vllm-ascend:v0.9.1-dev # Use correct image id
35
+ export NAME=vllm-ascend # Custom docker name
36
+
37
+ # Run the container using the defined variables
38
+ # Note if you are running bridge network with docker, Please expose available ports for multiple nodes communication in advance
39
+ # To prevent device interference from other docker containers, add the argument "--privileged"
40
+ docker run --rm \
41
+ --name $NAME \
42
+ --network host \
43
+ --device /dev/davinci0 \
44
+ --device /dev/davinci1 \
45
+ --device /dev/davinci2 \
46
+ --device /dev/davinci3 \
47
+ --device /dev/davinci4 \
48
+ --device /dev/davinci5 \
49
+ --device /dev/davinci6 \
50
+ --device /dev/davinci7 \
51
+ --device /dev/davinci_manager \
52
+ --device /dev/devmm_svm \
53
+ --device /dev/hisi_hdc \
54
+ -v /usr/local/dcmi:/usr/local/dcmi \
55
+ -v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
56
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
57
+ -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
58
+ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
59
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
60
+ -v /mnt/sfs_turbo/.cache:/root/.cache \
61
+ -it $IMAGE bash
62
+ ```
63
+
64
+ If not inside the container, enter the container as the root user:
65
+ ```
66
+ docker exec -itu root $NAME /bin/bash
67
+ ```
68
+
69
+ Download vllm (v0.9.2) to replace the built-in vllm code of the image.
70
+ ```bash
71
+ pip install --no-deps vllm==0.9.2 pybase64==1.4.1
72
+ ```
73
+
74
+ Download [vllm-ascend (v0.9.2rc1)](https://github.com/vllm-project/vllm-ascend/releases/tag/v0.9.2rc1) and replace the built-in vllm-ascend code in the image (/vllm-workspace/vllm-ascend/). For example, download [Source code (tar.gz)](https://github.com/vllm-project/vllm-ascend/archive/refs/tags/v0.9.2rc1.tar.gz) from Assets to get v0.9.2rc1.tar.gz, then extract and replace:
75
+
76
+ ```bash
77
+ tar -zxvf vllm-ascend-0.9.2rc1.tar.gz -C /vllm-workspace/vllm-ascend/ --strip-components=1
78
+ export PYTHONPATH=/vllm-workspace/vllm-ascend/:${PYTHONPATH}
79
+ ```
80
+
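+ A quick, optional way to confirm that the replacement code is the one being picked up (run inside the container):
+ ```bash
+ python3 -c "import vllm; print(vllm.__version__)"   # expect 0.9.2
+ python3 -c "import vllm_ascend, os; print(os.path.dirname(vllm_ascend.__file__))"
+ # the second path should point into /vllm-workspace/vllm-ascend/vllm_ascend
+ ```
+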
81
+ Use the Pangu model-adapted vllm-ascend code from the current repository to replace parts of the code in `/vllm-workspace/vllm-ascend/vllm_ascend/`:
82
+
83
+ ```bash
84
+ yes | cp -r inference/vllm_ascend/* /vllm-workspace/vllm-ascend/vllm_ascend/
85
+ ```
86
+
87
+ ### BF16 Inference
88
+
89
+ Perform the following operations on all nodes.
90
+
91
+ Run command:
92
+ ```bash
93
+ # This obtained through ifconfig
94
+ # nic_name is the network interface name corresponding to local_ip
95
+ local_ip=`hostname -I | cut -d' ' -f1`
96
+ nic_name=$(ifconfig | grep -B 1 "$local_ip" | head -n 1 | awk '{print $1}' | sed 's/://')
97
+ export HCCL_IF_IP=$local_ip
98
+ export GLOO_SOCKET_IFNAME=$nic_name
99
+ export TP_SOCKET_IFNAME=$nic_name
100
+ export HCCL_SOCKET_IFNAME=$nic_name
101
+ export OMP_PROC_BIND=false
102
+ export OMP_NUM_THREADS=100
103
+ export VLLM_USE_V1=1
104
+ export HCCL_BUFFSIZE=1024
105
+ export VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP=1
106
+ export VLLM_ASCEND_ENABLE_TOP_N_SIGMA=1 # enable top-n-sigma sampling
107
+ export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
108
+
109
+ MASTER_NODE_IP=xxx.xxx.xxx.xxx # master/head node ip
110
+ NODE_RANK=xxx # current node rank (0~7)
111
+ NUM_NODES=8 # number of nodes
112
+ NUM_NPUS_LOCAL=8 # number of NPUs per node
113
+ DATA_PARALLEL_SIZE_LOCAL=4 # DP size per node, can be set to 1, 2, or 4
114
+ LOCAL_CKPT_DIR=/root/.cache/pangu_ultra_moe # The pangu_ultra_moe bf16 weight
115
+ # Specifying HOST=127.0.0.1 (localhost) means the server can only be accessed from the master device.
116
+ # Specifying HOST=0.0.0.0 allows the vLLM server to be accessed from other devices on the same network or even from the internet, provided proper network configuration (e.g., firewall rules, port forwarding) is in place.
117
+ HOST=xxx.xxx.xxx.xxx
118
+
119
+ if [[ $NODE_RANK -ne 0 ]]; then
120
+ headless="--headless"
121
+ else
122
+ headless=""
123
+ fi
124
+
125
+ vllm serve $LOCAL_CKPT_DIR \
126
+ --host $HOST \
127
+ --port 8004 \
128
+ --data-parallel-size $((NUM_NODES*DATA_PARALLEL_SIZE_LOCAL)) \
129
+ --data-parallel-size-local $DATA_PARALLEL_SIZE_LOCAL \
130
+ --data-parallel-start-rank $((DATA_PARALLEL_SIZE_LOCAL*NODE_RANK)) \
131
+ --data-parallel-address $MASTER_NODE_IP \
132
+ --data-parallel-rpc-port 13389 \
133
+ --tensor-parallel-size $((NUM_NPUS_LOCAL/DATA_PARALLEL_SIZE_LOCAL)) \
134
+ --seed 1024 \
135
+ --served-model-name pangu_ultra_moe \
136
+ --enable-expert-parallel \
137
+ --max-num-seqs 8 \
138
+ --max-model-len 32768 \
139
+ --max-num-batched-tokens 4096 \
140
+ --trust-remote-code \
141
+ --no-enable-prefix-caching \
142
+ --gpu-memory-utilization 0.9 \
143
+ ${headless} \
144
+ --additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'
145
+ ```
146
+
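+ For reference, the parallelism arguments above multiply out as follows (a worked example with the default values; adjust if you change them):
+ ```bash
+ # NUM_NODES=8, NUM_NPUS_LOCAL=8, DATA_PARALLEL_SIZE_LOCAL=4
+ # data-parallel-size   = 8 * 4 = 32
+ # tensor-parallel-size = 8 / 4 = 2
+ # NPUs used            = 32 * 2 = 64   (8 nodes x 8 NPUs, matching the bf16 deployment)
+ ```
+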
147
+ ### Test Request
148
+
149
+ After the server is launched, send a test request from the master node or any other node:
150
+
151
+ ```bash
152
+ curl http://${MASTER_NODE_IP}:8004/v1/chat/completions \
153
+ -H "Content-Type: application/json" \
154
+ -d '{
155
+ "model": "pangu_ultra_moe",
156
+ "messages": [
157
+ {
158
+ "role": "user",
159
+ "content": "Who are you?"
160
+ }
161
+ ],
162
+ "max_tokens": 512,
163
+ "temperature": 0.7,
164
+ "top_p": 1.0,
165
+ "top_k": -1,
166
+ "vllm_xargs": {"top_n_sigma": 0.05}
167
+ }'
168
+ ```
169
+
170
+ ### Int8 Inference
171
+
172
+ #### ModelSlim Quantization
173
+
174
+ The openPangu-Ultra-MoE-718B model supports quantization with the open-source quantization framework [ModelSlim](https://gitcode.com/Ascend/msit/blob/br_noncom_pangu_ultra_moe_8.1.RC1_POC_20251231/msmodelslim/example/Pangu/README.md). The current model supports W8A8 weight-and-activation quantization.
175
+
176
+ ##### openPangu-Ultra-MoE-718B W8A8 Dynamic quantization
177
+
178
+ ```bash
179
+ python3 quant_pangu_ultra_moe_w8a8.py --model_path {bf16 weight path} --save_path {W8A8 weight path} --dynamic
180
+ ```
181
+
182
+ ##### openPangu-Ultra-MoE-718B W8A8 Hybrid quantization + MTP quantization
183
+
184
+ Generate the W8A8 quantized weights (including MTP) for the openPangu-Ultra-MoE-718B model:
185
+ ```bash
186
+ python3 quant_pangu_ultra_moe_w8a8.py --model_path {bf16 weight path} --save_path {W8A8 weight path} --dynamic --quant_mtp mix
187
+ ```
188
+
189
+ Compared with the BF16 model, the following fields are added to the config.json file of the int8 quantization model:
190
+ ```
191
+ "mla_quantize": "w8a8",
192
+ "quantize": "w8a8_dynamic",
193
+ ```
194
+
195
+ If MTP quantization is enabled, the following field is also added:
196
+ ```
197
+ "mtp_quantize": "w8a8_dynamic",
198
+ ```
199
+ The ModelSlim quantization script automatically appends the preceding fields to config.json when it generates the quantized model.
200
+
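+ A quick way to confirm the fields were written (an optional sketch; assumes `jq` is available and `{W8A8 weight path}` is the directory passed to `--save_path`):
+ ```bash
+ jq '{quantize, mla_quantize, mtp_quantize}' "{W8A8 weight path}/config.json"
+ ```
+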
201
+ #### Int8 Inference
202
+
203
+ Compared with BF16 inference, int8 quantized-model inference requires only four nodes (32 NPUs). Modify the variable accordingly:
204
+ ```bash
205
+ NUM_NODES=4
206
+ ```
207
+
208
+ In the startup command, change the checkpoint directory to the corresponding quantized weight path and add `--quantization ascend`:
209
+ ```bash
210
+ LOCAL_CKPT_DIR=/root/.cache/pangu_ultra_moe_w8a8
211
+
212
+ vllm serve $LOCAL_CKPT_DIR \
213
+ ...
214
+ --quantization ascend
215
+ ...
216
+ ```
generation_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 45892,
5
+ "do_sample": true,
6
+ "temperature": 0.7,
7
+ "top_p": 1.0,
8
+ "top_n_sigma": 0.05,
9
+ "top_k": -1,
10
+ "transformers_version": "4.48.2"
11
+ }
inference/generate.py ADDED
@@ -0,0 +1,106 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+
4
+ import argparse
5
+ import logging
6
+
7
+ import torch
8
+ import yaml
9
+ from model import PanguUltraMoEForCausalLM
10
+ from runner import ModelRunner
11
+
12
+ root_logger = logging.getLogger()
13
+ root_logger.handlers.clear()
14
+ logging.basicConfig(
15
+ format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
16
+ level=logging.INFO,
17
+ )
18
+ torch.manual_seed(42)
19
+ torch.npu.manual_seed_all(42)
20
+
21
+ """
22
+ NOTE:
23
+ For enhancing model safety, we recommend the following system prompt.
24
+ It is suggested to be removed for other normal use cases and model evaluation.
25
+ """
26
+ safe_word = "你必须严格遵守法律法规和社会道德规范。" \
27
+ "生成任何内容时,都应避免涉及暴力、色情、恐怖主义、种族歧视、性别歧视等不当内容。" \
28
+ "一旦检测到输入或输出有此类倾向,应拒绝回答并发出警告。例如,如果输入内容包含暴力威胁或色情描述," \
29
+ "应返回错误信息:“您的输入包含不当内容,无法处理。”"
30
+
31
+
32
+ # basic token generator
33
+ def generate_default_prompt(bs):
34
+ # prompt batch size define actual model forward batch size
35
+ fast_thinking_template = "[unused9]系统:%s[unused10][unused9]用户:{} /no_think[unused10][unused9]助手:" % (safe_word,)
36
+ slow_thinking_template = "[unused9]系统:%s[unused10][unused9]用户:{}[unused10][unused9]助手:" % (safe_word,)
37
+ preset_prompts = [slow_thinking_template.format(args.prompt)]
38
+ preset_prompts = preset_prompts * (bs // len(preset_prompts) + 1)
39
+ preset_prompts = preset_prompts[:bs]
40
+ logging.info(f"prompt batch size: {bs}")
41
+ return preset_prompts
42
+
43
+
44
+ def generate_chat_prompt(bs):
45
+ preset_prompts = [
46
+ {"role": "system", "content": safe_word},
47
+ {"role": "user", "content": args.prompt}
48
+ ]
49
+ preset_prompts = [preset_prompts] * (bs // len(preset_prompts) + 1)
50
+ preset_prompts = preset_prompts[:bs]
51
+ logging.info(f"chat prompt batch size: {bs}")
52
+ return preset_prompts
53
+
54
+
55
+ def generate_prompt(bs, tokenizer_mode):
56
+ if tokenizer_mode == "default":
57
+ return generate_default_prompt(bs)
58
+ else:
59
+ return generate_chat_prompt(bs)
60
+
61
+
62
+ def parse_args():
63
+ parser = argparse.ArgumentParser(description="llm run parameters")
64
+ parser.add_argument("--yaml_file_path", type=str, help="inference configurations")
65
+ parser.add_argument(
66
+ "--local_rank",
67
+ type=int,
68
+ default=0,
69
+ help="Local rank id for torch distributed launch",
70
+ )
71
+ parser.add_argument("--prompt", type=str, default="3*7=?", help="user prompts")
72
+ parser_args = parser.parse_args()
73
+ return parser_args
74
+
75
+
76
+ def main(runner_config):
77
+ bs = runner_config.get("data_config").get("batch_size", 1)
78
+ tokenizer_mode = runner_config.get("model_config").get("tokenizer_mode", "default")
79
+ preset_prompts = generate_prompt(bs, tokenizer_mode)
80
+ logging.info(f"input prompts: {preset_prompts}")
81
+ model_runner = ModelRunner(runner_config)
82
+ torch.npu.set_compile_mode(jit_compile=False)
83
+ model_runner.init_model(PanguUltraMoEForCausalLM)
84
+ # warmup
85
+ model_runner.model_generate(preset_prompts, warm_up=True)
86
+ # generate
87
+ model_runner.model_generate(preset_prompts)
88
+
89
+
90
+ def read_yaml(yaml_file_path):
91
+ data = None
+ try:
92
+ with open(yaml_file_path, "r", encoding="utf-8") as file:
93
+ data = yaml.safe_load(file)
94
+ except FileNotFoundError:
95
+ logging.error(f"No such yaml file: {yaml_file_path}")
96
+ except yaml.YAMLError as e:
97
+ logging.error(f"Load yaml file failed: {e}")
98
+ return data
99
+
100
+
101
+ if __name__ == "__main__":
102
+ args = parse_args()
103
+ yaml_file_path = args.yaml_file_path
104
+ runner_config = read_yaml(yaml_file_path)
105
+ main(runner_config)
106
+ logging.info("model run success")
inference/generate.sh ADDED
@@ -0,0 +1,70 @@
1
+ #!/bin/bash
2
+ # coding=utf-8
3
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
4
+ #
5
+
6
+ # use example:
7
+ # bash generate.sh ${NNODES} ${NODE_RANK} ${NPROC_PER_NODE} ${MASTER_ADDR} ${PROMPTS}
8
+
9
+ # input args
10
+ export NNODES=$1
11
+ export NODE_RANK=$2
12
+ export NPROC_PER_NODE=$3
13
+ export MASTER_ADDR=$4 # master node IP
14
+ export prompt=$5
15
+ export MASTER_PORT=6038 # master node port
16
+ export WORLD_SIZE=32
17
+ export YAML=runner_config/tp32.yaml
18
+ export RANK_OFFSET=`expr $NODE_RANK \* ${NPROC_PER_NODE}`
19
+
20
+ # setup env
21
+ export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
22
+ export HCCL_SOCKET_IFNAME=enp # network interface name prefix; set it according to your environment, e.g. enp or eth
23
+ export HCCL_IF_IP=`hostname -I|awk -F " " '{print$1}'` # get current node IP
24
+ export HCCL_IF_BASE_PORT=23456
25
+ export HCCL_OP_EXPANSION_MODE=AIV
26
+ export HCCL_CONNECT_TIMEOUT=1200
27
+ export HCCL_EXEC_TIMEOUT=1200
28
+ if [[ -d "/usr/local/Ascend/ascend-toolkit/latest" ]]; then
29
+ export ASCEND_HOME_PATH=/usr/local/Ascend/ascend-toolkit/latest
30
+ else
31
+ export ASCEND_HOME_PATH=/usr/local/Ascend/latest
32
+ fi
33
+ export PYTHONPATH=${PYTHONPATH}:${ASCEND_HOME_PATH}/python/site-packages/
34
+
35
+ # set result path
36
+ DATE=`date +%Y%m%d`
37
+ export MODEL_NAME="pangu_ultra_moe"
38
+ NAME=${MODEL_NAME}_${WORLD_SIZE}p
39
+ export TASK_QUEUE_ENABLE=2 # eager mode: optimize host-side dispatch performance
40
+ export RES_PATH="res/${DATE}/${NAME}"
41
+ WORK_DIR=`pwd`
42
+ DUMP_PRECISION_PATH=${WORK_DIR}'/'${RES_PATH}'/dump_data'
43
+ mkdir -p ${WORK_DIR}'/'${RES_PATH}
44
+ mkdir -p ${DUMP_PRECISION_PATH}
45
+
46
+ # launch multi proc
47
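+ # bind each local rank to a disjoint range of CPU cores via taskset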
+ cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
48
+ avg_core_per_rank=`expr $cores \/ $NPROC_PER_NODE`
49
+ core_gap=`expr $avg_core_per_rank \- 1`
50
+ for((i=0; i<${NPROC_PER_NODE}; i++))
51
+ do
52
+ echo $i
53
+ start=`expr $i \* $avg_core_per_rank`
54
+ end=`expr $start \+ $core_gap`
55
+ cmdopt=$start"-"$end
56
+ export LOCAL_RANK=$i
57
+ export RANK=$(expr $i + $RANK_OFFSET)
58
+ export RANK_ID=$RANK
59
+ if [ $i -eq 0 ];then
60
+ taskset -c $cmdopt python3 generate.py \
61
+ --prompt "$prompt" \
62
+ --yaml_file_path=${YAML} 2>&1 | tee ${WORK_DIR}/${RES_PATH}/log_${LOCAL_RANK}.log &
63
+ else
64
+ taskset -c $cmdopt python3 generate.py \
65
+ --prompt "$prompt" \
66
+ --yaml_file_path=${YAML} &> ${WORK_DIR}/${RES_PATH}/log_${LOCAL_RANK}.log &
67
+ fi
68
+ done
69
+
70
+ wait
inference/model.py ADDED
@@ -0,0 +1,918 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ import sys
19
+ import warnings
20
+ from typing import Dict, List, Optional, Tuple
21
+
22
+ import torch
23
+ import torch.distributed as dist
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ import torch_npu
27
+ from torch import nn
28
+ from torch.distributed.distributed_c10d import _world
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache
31
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
34
+ from transformers.utils.import_utils import is_torch_fx_available
35
+
36
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
37
+ from configuration_openpangu_moe import PanguUltraMoEConfig
38
+
39
+ if is_torch_fx_available():
40
+ if not is_torch_greater_or_equal_than_1_13:
41
+ import torch.fx
42
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
43
+
44
+
45
+ class PanguUltraMoERMSNorm(nn.Module):
46
+ def __init__(self, hidden_dim, epsilon=1e-5):
47
+ super().__init__()
48
+ self.weight = nn.Parameter(torch.empty(hidden_dim))
49
+ self.epsilon = epsilon
50
+
51
+ def forward(self, hidden_states, *args):
52
+ if len(args) == 0:
53
+ result = torch_npu.npu_rms_norm(hidden_states, self.weight, self.epsilon)[0]
54
+ return result
55
+ elif len(args) == 1 and args[0] is None:
56
+ result = torch_npu.npu_rms_norm(hidden_states, self.weight, self.epsilon)[0]
57
+ residual = hidden_states
58
+ return (result, residual)
59
+ elif len(args) == 1:
60
+ residual = args[0]
61
+ y, _, x = torch_npu.npu_add_rms_norm(
62
+ residual, hidden_states, self.weight, self.epsilon
63
+ )
64
+ return (y, x)
65
+ else:
66
+ raise NotImplementedError("PanguUltraMoERMSNorm received an unexpected number of arguments")
67
+
68
+
69
+ class PanguUltraMoERotaryEmbedding(nn.Module):
70
+ def __init__(
71
+ self, dim, max_position_embeddings=131072, base=25600000.0, device=None
72
+ ):
73
+ super().__init__()
74
+ self.dim = dim
75
+ self.max_position_embeddings = max_position_embeddings
76
+ self.base = base
77
+ self._set_cache(
78
+ seq_len=max_position_embeddings,
79
+ device=device,
80
+ dtype=torch.get_default_dtype(),
81
+ )
82
+
83
+ def _set_cache(self, seq_len, device, dtype):
84
+ self.max_seq_len_cached = seq_len
85
+ dim = self.dim
86
+
87
+ inv_freq = 1.0 / (
88
+ self.base
89
+ ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
90
+ )
91
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
92
+
93
+ t = torch.arange(seq_len, device=device, dtype=torch.float32)
94
+
95
+ freqs = torch.outer(t, inv_freq)
96
+
97
+ emb = torch.cat((freqs, freqs), dim=-1)
98
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
99
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
100
+
101
+ def forward(self, x, kv_len, max_seq_len=None):
102
+ if max_seq_len is None:
103
+ self._set_cache(seq_len=kv_len, device=x.device, dtype=x.dtype)
104
+ elif max_seq_len > self.max_seq_len_cached:
105
+ self._set_cache(seq_len=max_seq_len, device=x.device, dtype=x.dtype)
106
+
107
+ batch_size = x.shape[0]
108
+ seq_len = x.shape[1]
109
+ if seq_len == 1:
110
+ cos = (
111
+ torch.index_select(self.cos_cached, dim=0, index=kv_len)
112
+ .unsqueeze(1)
113
+ .unsqueeze(1)
114
+ )
115
+ sin = (
116
+ torch.index_select(self.sin_cached, dim=0, index=kv_len)
117
+ .unsqueeze(1)
118
+ .unsqueeze(1)
119
+ )
120
+ else:
121
+ cos = (
122
+ self.cos_cached[:seq_len]
123
+ .unsqueeze(0)
124
+ .unsqueeze(2)
125
+ .repeat(batch_size, 1, 1, 1)
126
+ )
127
+ sin = (
128
+ self.sin_cached[:seq_len]
129
+ .unsqueeze(0)
130
+ .unsqueeze(2)
131
+ .repeat(batch_size, 1, 1, 1)
132
+ )
133
+
134
+ cos = cos[0, :, 0, :]
135
+ sin = sin[0, :, 0, :]
136
+ return (
137
+ cos.to(dtype=x.dtype),
138
+ sin.to(dtype=x.dtype),
139
+ )
140
+
141
+
142
+ def rotate_half(x):
143
+ """Rotates half the hidden dims of the input."""
144
+ x1 = x[..., : x.shape[-1] // 2]
145
+ x2 = x[..., x.shape[-1] // 2 :]
146
+ return torch.cat((-x2, x1), dim=-1)
147
+
148
+
149
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
150
+ """Applies Rotary Position Embedding to the query and key tensors.
151
+
152
+ Args:
153
+ q (`torch.Tensor`): The query tensor.
154
+ k (`torch.Tensor`): The key tensor.
155
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
156
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
157
+ position_ids (`torch.Tensor`):
158
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
159
+ used to pass offsetted position ids when working with a KV-cache.
160
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
161
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
162
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
163
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
164
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
165
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
166
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
167
+ Returns:
168
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
169
+ """
170
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
171
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
172
+
173
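+ # de-interleave the head dim (x0,x1,x2,x3,...) -> (x0,x2,...,x1,x3,...) so rotate_half can split it into two contiguous halves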
+ b, h, s, d = q.shape
174
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
175
+
176
+ b, h, s, d = k.shape
177
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
178
+
179
+ q_embed = (q * cos) + (rotate_half(q) * sin)
180
+ k_embed = (k * cos) + (rotate_half(k) * sin)
181
+ return q_embed, k_embed
182
+
183
+
184
+ class MLP(nn.Module):
185
+ def __init__(self, config, runner_config, hidden_size=None, intermediate_size=None):
186
+ super().__init__()
187
+ self.runner_config = runner_config
188
+ self.moe_tp_size = self.runner_config.get("parallel_config").get(
189
+ "moe_tp_size", 1
190
+ )
191
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
192
+ self.intermediate_size = (
193
+ config.intermediate_size if intermediate_size is None else intermediate_size
194
+ )
195
+
196
+ self.intermediate_size_per_rank = self.intermediate_size // self.moe_tp_size
197
+ self.merge_up_gate_proj = nn.Linear(
198
+ self.hidden_size, self.intermediate_size_per_rank * 2, bias=False
199
+ )
200
+ self.down_proj = nn.Linear(
201
+ self.intermediate_size_per_rank, self.hidden_size, bias=False
202
+ )
203
+ self.act_fn = ACT2FN[config.hidden_act]
204
+
205
+ def forward(self, x):
206
+ merged_x = self.merge_up_gate_proj(x)
207
+ gate_state, up_state = merged_x.chunk(2, dim=-1)
208
+ intermediate_hidden_states = self.act_fn(gate_state) * up_state
209
+ down_proj = self.down_proj(intermediate_hidden_states)
210
+ if self.moe_tp_size > 1:
211
+ dist.all_reduce(down_proj)
212
+ return down_proj
213
+
214
+
215
+ class MoE(nn.Module):
216
+ def __init__(self, config, runner_config, hidden_size=None, intermediate_size=None):
217
+ super().__init__()
218
+ self.runner_config = runner_config
219
+ self.moe_tp_size = self.runner_config.get("parallel_config").get(
220
+ "moe_tp_size", 1
221
+ )
222
+ self.num_experts = config.num_routed_experts
223
+
224
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
225
+ self.intermediate_size = (
226
+ config.intermediate_size if intermediate_size is None else intermediate_size
227
+ )
228
+ self.intermediate_size_per_rank = self.intermediate_size // self.moe_tp_size
229
+
230
+ self.act_fn = ACT2FN[config.hidden_act]
231
+
232
+ self.group_w1_w3 = nn.Parameter(
233
+ torch.ones(
234
+ self.num_experts, self.intermediate_size_per_rank * 2, self.hidden_size
235
+ ),
236
+ requires_grad=False,
237
+ )
238
+ self.group_w2 = nn.Parameter(
239
+ torch.ones(
240
+ self.num_experts, self.hidden_size, self.intermediate_size_per_rank
241
+ ),
242
+ requires_grad=False,
243
+ )
244
+
245
+ def forward(self, hidden_states, expert_tokens, seq_len=None):
246
+ mm1_mm3 = torch_npu.npu_grouped_matmul(
247
+ [hidden_states],
248
+ [torch.transpose(self.group_w1_w3, 1, 2)],
249
+ group_list=expert_tokens,
250
+ group_type=0,
251
+ split_item=3,
252
+ )[0]
253
+ mm1, mm3 = mm1_mm3.chunk(2, dim=-1)
254
+ intermediate_hidden_states = self.act_fn(mm1) * mm3
255
+ hidden_states = torch_npu.npu_grouped_matmul(
256
+ [intermediate_hidden_states],
257
+ [torch.transpose(self.group_w2, 1, 2)],
258
+ group_list=expert_tokens,
259
+ group_type=0,
260
+ split_item=3,
261
+ )[0]
262
+ return hidden_states
263
+
264
+
265
+ class MoEGate(nn.Module):
266
+ def __init__(self, config):
267
+ super().__init__()
268
+ self.top_k = config.num_experts_per_tok
269
+ self.routed_scaling_factor = config.routed_scaling_factor
270
+
271
+ self.norm_topk_prob = config.norm_topk_prob
272
+ self.weight = nn.Parameter(
273
+ torch.empty((config.num_routed_experts, config.hidden_size))
274
+ )
275
+
276
+ def forward(self, hidden_states):
277
+ bsz, seq_len, h = hidden_states.shape
278
+ hidden_states = hidden_states.view(-1, h)
279
+ logits = F.linear(
280
+ hidden_states.to(torch.float32), self.weight.to(torch.float32), None
281
+ )
282
+ scores = logits.sigmoid()
283
+ scores_for_choice = scores.view(bsz * seq_len, -1)
284
+ _, topk_idx = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
285
+ topk_weight = scores.gather(1, topk_idx)
286
+
287
+ if self.top_k > 1 and self.norm_topk_prob:
288
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
289
+ topk_weight = topk_weight / denominator
290
+ topk_weight = topk_weight * self.routed_scaling_factor
291
+
292
+ return topk_idx, topk_weight
293
+
294
+
295
+ class PanguUltraMoE(nn.Module):
296
+ def __init__(self, config, runner_config):
297
+ super().__init__()
298
+ self.runner_config = runner_config
299
+ self.hidden_dim = config.hidden_size
300
+ self.moe_tp_size = self.runner_config.get("parallel_config").get(
301
+ "moe_tp_size", 1
302
+ )
303
+ self.batch_size_decode = self.runner_config.get("data_config").get(
304
+ "batch_size", 1
305
+ )
306
+ self.batch_size_prefill = self.batch_size_decode
307
+ self.num_experts_per_tok = config.num_experts_per_tok
308
+ self.num_experts = config.num_routed_experts
309
+ self.num_shared_experts = config.num_shared_experts
310
+ self.top_k = config.num_experts_per_tok
311
+
312
+ self.experts_per_rank = config.num_routed_experts
313
+ self.experts = MoE(
314
+ config, self.runner_config, intermediate_size=config.moe_intermediate_size
315
+ )
316
+
317
+ self.gate = MoEGate(config)
318
+ if self.num_shared_experts is not None:
319
+ intermediate_size = config.moe_intermediate_size * self.num_shared_experts
320
+ self.shared_experts = MLP(
321
+ config, self.runner_config, intermediate_size=intermediate_size
322
+ )
323
+ self.row_idx_decode_len = self.batch_size_decode * self.top_k
324
+ self.row_idx_decode = (
325
+ torch.arange(0, self.row_idx_decode_len, dtype=torch.int32)
326
+ .view(self.top_k, -1)
327
+ .permute(1, 0)
328
+ .int()
329
+ .contiguous()
330
+ .npu()
331
+ )
332
+
333
+ def forward(self, hidden_states):
334
+ identity = hidden_states
335
+ topk_idx, topk_weight = self.gate(hidden_states)
336
+ y = self.moe_npu(hidden_states, topk_idx, topk_weight)
337
+ if self.num_shared_experts is not None:
338
+ y = y + self.shared_experts(identity)
339
+ return y
340
+
341
+ def moe_npu(self, x, topk_ids, topk_weight):
342
+ batch_size, sequence_length, h = x.shape
343
+ hidden_states = x.view(-1, x.shape[-1])
344
+
345
+ routing_weights = topk_weight.to(x.dtype)
346
+ expert_idx = topk_ids.int()
347
+ if sequence_length == 1:
348
+ row_idx = self.row_idx_decode
349
+ else:
350
+ row_idx_prefill_len = self.batch_size_prefill * sequence_length * self.top_k
351
+ row_idx = (
352
+ torch.arange(
353
+ 0, row_idx_prefill_len, dtype=torch.int32, device=topk_weight.device
354
+ )
355
+ .view(self.top_k, -1)
356
+ .permute(1, 0)
357
+ .int()
358
+ .contiguous()
359
+ )
360
+
361
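+ # dispatch: npu_moe_init_routing replicates each token for its top-k experts and reorders the copies by expert id before the grouped matmuls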
+ active_num = batch_size * sequence_length
362
+ expanded_x, expanded_row_idx, expanded_expert_idx = (
363
+ torch_npu.npu_moe_init_routing(
364
+ hidden_states,
365
+ row_idx=row_idx,
366
+ expert_idx=expert_idx,
367
+ active_num=active_num,
368
+ )
369
+ )
370
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
371
+ expanded_expert_idx, self.num_experts
372
+ )
373
+ expert_tokens = expert_tokens.to(torch.int64)
374
+
375
+ hidden_states_ordered_by_experts = self.experts(
376
+ expanded_x, expert_tokens, seq_len=sequence_length
377
+ )
378
+
379
+ hidden_states = torch_npu.npu_moe_finalize_routing(
380
+ hidden_states_ordered_by_experts,
381
+ skip1=None,
382
+ skip2=None,
383
+ bias=None,
384
+ scales=routing_weights,
385
+ expanded_src_to_dst_row=expanded_row_idx,
386
+ export_for_source_row=expert_idx,
387
+ )
388
+ if self.moe_tp_size > 1:
389
+ dist.all_reduce(hidden_states)
390
+ hidden_states = hidden_states.view(batch_size, -1, self.hidden_dim)
391
+ return hidden_states
392
+
393
+
394
+ class PanguUltraMoEAttention(nn.Module):
395
+ def __init__(
396
+ self,
397
+ config: PanguUltraMoEConfig,
398
+ layer_idx: Optional[int] = None,
399
+ runner_config: Optional[Dict] = None,
400
+ ):
401
+ super().__init__()
402
+ if runner_config is not None:
403
+ self.attn_tp_size = runner_config.get("parallel_config").get(
404
+ "attn_tp_size", 1
405
+ )
406
+ else:
407
+ self.attn_tp_size = 1
408
+ self.layer_idx = layer_idx
409
+
410
+ self.hidden_size = config.hidden_size
411
+ self.num_heads = config.num_attention_heads
412
+ self.num_heads_per_rank = self.num_heads // self.attn_tp_size
413
+ self.num_key_value_heads_per_rank = self.num_heads_per_rank
414
+
415
+ self.max_position_embeddings = config.max_position_embeddings
416
+ self.rope_theta = config.rope_theta
417
+ self.attention_q_lora_dim = config.attention_q_lora_dim
418
+ self.attention_qk_rope_dim = config.attention_qk_rope_dim
419
+ self.attention_kv_lora_dim = config.attention_kv_lora_dim
420
+ self.attention_v_dim = config.attention_v_dim
421
+ self.attention_qk_dim = config.attention_qk_dim
422
+ self.q_head_dim = config.attention_qk_dim + config.attention_qk_rope_dim
423
+
424
+ if self.attention_q_lora_dim is None:
425
+ self.q_proj = nn.Linear(
426
+ self.hidden_size, self.num_heads_per_rank * self.q_head_dim, bias=False
427
+ )
428
+ else:
429
+ self.q_a_proj = nn.Linear(
430
+ self.hidden_size, config.attention_q_lora_dim, bias=False
431
+ )
432
+ self.q_a_layernorm = PanguUltraMoERMSNorm(config.attention_q_lora_dim)
433
+ self.q_b_proj = nn.Linear(
434
+ config.attention_q_lora_dim,
435
+ self.num_heads_per_rank * self.q_head_dim,
436
+ bias=False,
437
+ )
438
+
439
+ self.kv_a_proj_with_mqa = nn.Linear(
440
+ self.hidden_size,
441
+ config.attention_kv_lora_dim + config.attention_qk_rope_dim,
442
+ bias=False,
443
+ )
444
+ self.kv_a_layernorm = PanguUltraMoERMSNorm(config.attention_kv_lora_dim)
445
+
446
+ self.kv_b_proj_w_k = nn.Parameter(
447
+ torch.zeros(
448
+ self.num_heads_per_rank,
449
+ self.attention_qk_dim,
450
+ self.attention_kv_lora_dim,
451
+ )
452
+ )
453
+ self.kv_b_proj_w_v = nn.Parameter(
454
+ torch.zeros(
455
+ self.num_heads_per_rank,
456
+ self.attention_kv_lora_dim,
457
+ self.attention_v_dim,
458
+ )
459
+ )
460
+
461
+ self.o_proj = nn.Linear(
462
+ self.num_heads_per_rank * self.attention_v_dim,
463
+ self.hidden_size,
464
+ bias=False,
465
+ )
466
+
467
+ self.softmax_scale = self.q_head_dim ** (-0.5)
468
+
469
+ def bmm_5d(self, x, y):
470
+ b, s, n, _, d = x.shape
471
+ x = x.view(b * s, n, d).transpose(0, 1)
472
+ output = torch.matmul(x, y)
473
+ output = output.transpose(1, 0).view(b, s, n, -1)
474
+ return output
475
+
476
+ def prepare_qkv(
477
+ self,
478
+ hidden_states: torch.Tensor,
479
+ cos_sin: torch.Tensor = None,
480
+ kv_len: torch.IntTensor = None,
481
+ position_ids: Optional[torch.LongTensor] = None,
482
+ past_key_value: Optional[Cache] = None,
483
+ **kwargs,
484
+ ):
485
+ bsz, q_len, _ = hidden_states.size()
486
+
487
+ if self.attention_q_lora_dim is None:
488
+ q = self.q_proj(hidden_states)
489
+ else:
490
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
491
+
492
+ compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
493
+ compressed_kv, k_pe = torch.split(
494
+ compressed_kv,
495
+ [self.attention_kv_lora_dim, self.attention_qk_rope_dim],
496
+ dim=-1,
497
+ )
498
+
499
+ q = q.view(bsz, q_len, self.num_heads_per_rank, self.q_head_dim)
500
+ q_nope, q_pe = torch.split(
501
+ q, [self.attention_qk_dim, self.attention_qk_rope_dim], dim=-1
502
+ )
503
+ q_pe = q_pe.transpose(1, 2)
504
+ q_nope = self.bmm_5d(
505
+ q_nope.view(bsz, q_len, self.num_heads_per_rank, 1, self.attention_qk_dim),
506
+ self.kv_b_proj_w_k,
507
+ )
508
+ q_nope = q_nope.view(
509
+ bsz, q_len, self.num_heads_per_rank, self.attention_kv_lora_dim
510
+ )
511
+ q_nope = q_nope.transpose(1, 2)
512
+
513
+ k_pe = k_pe.view(bsz, q_len, 1, self.attention_qk_rope_dim).transpose(1, 2)
514
+ k_nope = (
515
+ self.kv_a_layernorm(compressed_kv)
516
+ .view(bsz, -1, 1, self.attention_kv_lora_dim)
517
+ .transpose(1, 2)
518
+ )
519
+
520
+ cos, sin = cos_sin
521
+ q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
522
+
523
+ query_states = torch.cat([q_nope, q_pe], dim=-1)
524
+ key_states = torch.cat([k_nope, k_pe], dim=-1)
525
+
526
+ kv_seq_len = k_nope.shape[-2]
527
+ if past_key_value is not None:
528
+ past_key_states = past_key_value[self.layer_idx][0]
529
+ torch_npu.scatter_update_(past_key_states, kv_len, key_states, -2)
530
+ if q_len == 1:
531
+ key_states = past_key_states
532
+ kv_seq_len = past_key_value[0][0].size()[-2]
533
+ value_states = key_states
534
+ return query_states, key_states, value_states, kv_seq_len
535
+
536
+ def apply_attention_npu(
537
+ self,
538
+ query_states,
539
+ key_states,
540
+ value_states,
541
+ kv_seq_len,
542
+ attention_mask: Optional[torch.Tensor] = None,
543
+ actual_seq_lengths_kv: list = None,
544
+ output_attentions: bool = False,
545
+ past_key_value: Optional[Cache] = None,
546
+ ):
547
+ # eager attention: scaled Q·K^T plus additive mask, softmax, then weighted sum over the compressed KV
548
+ bsz, _, q_len, _ = query_states.size()
549
+ attn_weights = (
550
+ torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
551
+ )
552
+ if attention_mask is not None:
553
+ attn_weights = attn_weights + attention_mask
554
+ else:
555
+ raise ValueError("attention mask must not be None")
556
+
557
+ attn_weights = nn.functional.softmax(
558
+ attn_weights, dim=-1, dtype=torch.float32
559
+ ).to(query_states.dtype)
560
+ value_states = value_states[..., : self.attention_kv_lora_dim]
561
+ attn_output = torch.matmul(attn_weights, value_states)
562
+
563
+ attn_output = attn_output.transpose(1, 2).contiguous()
564
+ attn_output = self.bmm_5d(attn_output.unsqueeze(3), self.kv_b_proj_w_v)
565
+ attn_output = self.o_proj(attn_output.reshape(bsz, q_len, -1))
566
+ if self.attn_tp_size > 1:
567
+ dist.all_reduce(attn_output)
568
+ return attn_output
569
+
570
+ def forward(
571
+ self,
572
+ hidden_states: torch.Tensor,
573
+ kv_len: torch.IntTensor = None,
574
+ actual_seq_lengths_kv: list = None,
575
+ cos_sin: torch.Tensor = None,
576
+ attention_mask: Optional[torch.Tensor] = None,
577
+ position_ids: Optional[torch.LongTensor] = None,
578
+ past_key_value: Optional[Cache] = None,
579
+ output_attentions: bool = False,
580
+ use_cache: bool = False,
581
+ **kwargs,
582
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
583
+ if "padding_mask" in kwargs:
584
+ warnings.warn(
585
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
586
+ )
587
+ query_states, key_states, value_states, kv_seq_len = self.prepare_qkv(
588
+ hidden_states=hidden_states,
589
+ cos_sin=cos_sin,
590
+ kv_len=kv_len,
591
+ position_ids=position_ids,
592
+ past_key_value=past_key_value,
593
+ )
594
+ output = self.apply_attention_npu(
595
+ query_states=query_states,
596
+ key_states=key_states,
597
+ value_states=value_states,
598
+ kv_seq_len=kv_seq_len,
599
+ actual_seq_lengths_kv=actual_seq_lengths_kv,
600
+ attention_mask=attention_mask,
601
+ output_attentions=output_attentions,
602
+ past_key_value=past_key_value,
603
+ )
604
+ return output
605
+
606
+
607
+ class PanguUltraMoEDecoderLayer(nn.Module):
608
+ def __init__(
609
+ self, config: PanguUltraMoEConfig, runner_config: Dict, layer_idx: int
610
+ ):
611
+ super().__init__()
612
+ self.runner_config = runner_config
613
+ self.hidden_size = config.hidden_size
614
+
615
+ self.self_attn = PanguUltraMoEAttention(
616
+ config=config, runner_config=self.runner_config, layer_idx=layer_idx
617
+ )
618
+
619
+ self.mlp = (
620
+ PanguUltraMoE(config, self.runner_config)
621
+ if (
622
+ config.num_routed_experts is not None
623
+ and layer_idx >= config.num_dense_layers
624
+ )
625
+ else MLP(config, self.runner_config)
626
+ )
627
+ self.input_layernorm = PanguUltraMoERMSNorm(
628
+ config.hidden_size, epsilon=config.rms_norm_eps
629
+ )
630
+ self.post_attention_layernorm = PanguUltraMoERMSNorm(
631
+ config.hidden_size, epsilon=config.rms_norm_eps
632
+ )
633
+ if getattr(config, "sandwich_norm", False):
634
+ self.sandwich_norm = True
635
+ self.pre_mlp_layernorm = PanguUltraMoERMSNorm(
636
+ config.hidden_size, epsilon=config.rms_norm_eps
637
+ )
638
+ self.post_mlp_layernorm = PanguUltraMoERMSNorm(
639
+ config.hidden_size, epsilon=config.rms_norm_eps
640
+ )
641
+ else:
642
+ self.sandwich_norm = False
643
+
644
+ def forward(
645
+ self,
646
+ hidden_states: torch.Tensor,
647
+ kv_len: torch.IntTensor,
648
+ actual_seq_lengths_kv: list,
649
+ cos_sin: torch.Tensor,
650
+ past_residual: Optional[torch.Tensor] = None,
651
+ attention_mask: Optional[torch.Tensor] = None,
652
+ position_ids: Optional[torch.LongTensor] = None,
653
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
654
+ **kwargs,
655
+ ) -> Tuple[torch.FloatTensor]:
656
+ hidden_states, residual = self.input_layernorm(hidden_states, past_residual)
657
+
658
+ # Self Attention
659
+ hidden_states = self.self_attn(
660
+ hidden_states=hidden_states,
661
+ kv_len=kv_len,
662
+ actual_seq_lengths_kv=actual_seq_lengths_kv,
663
+ cos_sin=cos_sin,
664
+ attention_mask=attention_mask,
665
+ position_ids=position_ids,
666
+ past_key_value=past_key_value,
667
+ )
668
+
669
+ if self.sandwich_norm:
670
+ hidden_states = self.post_attention_layernorm(hidden_states)
671
+ hidden_states, residual = self.pre_mlp_layernorm(hidden_states, residual)
672
+ else:
673
+ hidden_states, residual = self.post_attention_layernorm(
674
+ hidden_states, residual
675
+ )
676
+
677
+ hidden_states = self.mlp(hidden_states)
678
+
679
+ if self.sandwich_norm:
680
+ hidden_states = self.post_mlp_layernorm(hidden_states)
681
+
682
+ outputs = (residual, hidden_states)
683
+ return outputs
684
+
685
+
686
+ class PanguUltraMoEPreTrainedModel(PreTrainedModel):
687
+ config_class = PanguUltraMoEConfig
688
+ base_model_prefix = "model"
689
+ supports_gradient_checkpointing = True
690
+ _no_split_modules = ["PanguUltraMoEDecoderLayer"]
691
+ _skip_keys_device_placement = "past_key_values"
692
+ _supports_cache_class = True
693
+
694
+ def _init_weights(self, module):
695
+ pass
696
+
697
+
698
+ class PanguUltraMoEModel(PanguUltraMoEPreTrainedModel):
699
+ def __init__(self, config: PanguUltraMoEConfig, runner_config: Dict):
700
+ super().__init__(config)
701
+ self.config = config
702
+ self.runner_config = runner_config
703
+ self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
704
+ self.rank_offset = int(os.getenv("RANK_OFFSET", "0"))
705
+ self.global_rank = self.local_rank + self.rank_offset
706
+ self.embed_tp_size = self.runner_config.get("parallel_config").get(
707
+ "embed_tp_size", 1
708
+ )
709
+ self.padding_idx = config.pad_token_id
710
+ self.vocab_size = config.vocab_size
711
+ self.vocab_size_per_rank = self.vocab_size // self.embed_tp_size
712
+
713
+ self.embed_tokens = nn.Embedding(
714
+ self.vocab_size_per_rank, config.hidden_size, self.padding_idx
715
+ )
716
+ self.layers = nn.ModuleList(
717
+ [
718
+ PanguUltraMoEDecoderLayer(config, self.runner_config, layer_idx)
719
+ for layer_idx in range(config.num_hidden_layers)
720
+ ]
721
+ )
722
+ self.norm = PanguUltraMoERMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
723
+ self.gradient_checkpointing = False
724
+ # Initialize weights and apply final processing
725
+ self.post_init()
726
+ self.rotary_emb = PanguUltraMoERotaryEmbedding(
727
+ self.config.attention_qk_rope_dim,
728
+ max_position_embeddings=self.config.max_position_embeddings,
729
+ base=self.config.rope_theta,
730
+ )
731
+
732
+ def forward(
733
+ self,
734
+ input_ids: torch.LongTensor,
735
+ kv_len: torch.IntTensor = None,
736
+ actual_seq_lengths_kv: list = None,
737
+ attention_mask: Optional[torch.Tensor] = None,
738
+ position_ids: Optional[torch.LongTensor] = None,
739
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
740
+ ):
741
+
742
+ batch_size, seq_length = input_ids.shape
743
+ past_key_values_length = past_key_values[0][0].size()[-2]
744
+
745
+ if position_ids is None:
746
+ device = input_ids.device
747
+ position_ids = torch.arange(
748
+ past_key_values_length,
749
+ seq_length + past_key_values_length,
750
+ dtype=torch.long,
751
+ device=device,
752
+ )
753
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
754
+ else:
755
+ position_ids = position_ids.view(-1, seq_length).long()
756
+
757
+ if self.embed_tp_size > 1:
758
+ new_input_ids = input_ids - self.global_rank * self.vocab_size_per_rank
759
+ mask = (new_input_ids >= 0) & (
760
+ new_input_ids < self.vocab_size_per_rank
761
+ ) # (bs, qlen)
762
+ new_input_ids_per_rank = new_input_ids * mask
763
+ inputs_embeds = self.embed_tokens(new_input_ids_per_rank) * mask.unsqueeze(
764
+ -1
765
+ )
766
+ dist.all_reduce(inputs_embeds)
767
+ else:
768
+ inputs_embeds = self.embed_tokens(input_ids)
769
+ hidden_states = inputs_embeds
770
+
771
+ cos_sin = self.rotary_emb(
772
+ hidden_states, kv_len, self.config.max_position_embeddings
773
+ )
774
+ residual = None
775
+
776
+ for decoder_layer in self.layers:
777
+ residual, hidden_states = decoder_layer(
778
+ hidden_states,
779
+ kv_len,
780
+ actual_seq_lengths_kv,
781
+ cos_sin=cos_sin,
782
+ past_residual=residual,
783
+ attention_mask=attention_mask,
784
+ position_ids=position_ids,
785
+ past_key_value=past_key_values,
786
+ )
787
+
788
+ hidden_states, _ = self.norm(hidden_states, residual)
789
+
790
+ return hidden_states
791
+
792
+
793
+ class PanguUltraMoEForCausalLM(PanguUltraMoEPreTrainedModel):
794
+ _tied_weights_keys = ["lm_head.weight"]
795
+
796
+ def __init__(self, config, runner_config):
797
+ super().__init__(config)
798
+ self.config = config
799
+ self.runner_config = runner_config
800
+ self.embed_tp_size = self.runner_config.get("parallel_config").get(
801
+ "embed_tp_size", 1
802
+ )
803
+ self.model = PanguUltraMoEModel(config, self.runner_config)
804
+ self.vocab_size = config.vocab_size
805
+ self.lm_head = nn.Linear(
806
+ config.hidden_size, config.vocab_size // self.embed_tp_size, bias=False
807
+ )
808
+
809
+ def forward(
810
+ self,
811
+ input_ids: torch.LongTensor = None,
812
+ kv_len: torch.IntTensor = None,
813
+ actual_seq_lengths_kv: list = None,
814
+ attention_mask: Optional[torch.Tensor] = None,
815
+ position_ids: Optional[torch.LongTensor] = None,
816
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
817
+ ):
818
+ outputs = self.model(
819
+ input_ids=input_ids,
820
+ kv_len=kv_len,
821
+ actual_seq_lengths_kv=actual_seq_lengths_kv,
822
+ attention_mask=attention_mask,
823
+ position_ids=position_ids,
824
+ past_key_values=past_key_values,
825
+ )
826
+
827
+ hidden_states = outputs
828
+
829
+ if hidden_states.size()[1] > 1:
830
+ gather_index, _ = torch.max(position_ids, dim=-1)
831
+ gather_index = (
832
+ gather_index.unsqueeze(1)
833
+ .unsqueeze(2)
834
+ .repeat(1, 1, hidden_states.shape[-1])
835
+ )
836
+ hidden_states = torch.gather(hidden_states, 1, gather_index)
837
+
838
+ logits = self.lm_head(hidden_states)
839
+ if self.embed_tp_size > 1:
840
+ new_logits = torch.zeros_like(logits).repeat(self.embed_tp_size, 1, 1)
841
+ dist.all_gather_into_tensor(new_logits, logits, group=_world._default_pg)
842
+ new_logits = new_logits.reshape(
843
+ self.embed_tp_size, logits.shape[0], logits.shape[1], -1
844
+ ).permute(1, 2, 0, 3)
845
+ logits = new_logits.reshape(logits.shape[0], logits.shape[1], -1)
846
+ logits = logits.float()
847
+
848
+ return logits
849
+
850
+ def init_cache(self, input_ids):
851
+ batch_size, seq_len = input_ids.size()
852
+
853
+ cache_seq_len = self.config.max_position_embeddings
854
+
855
+ past_key_values = ()
856
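+ # MLA-style cache: store a single compressed latent per token (kv_lora + rope dims) instead of full per-head K/V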
+ cache_key_shape = (
857
+ batch_size,
858
+ 1,
859
+ cache_seq_len,
860
+ self.config.attention_kv_lora_dim + self.config.attention_qk_rope_dim,
861
+ )
862
+ dtype = self.config.torch_dtype
863
+
864
+ for _ in range(self.config.num_hidden_layers):
865
+ key_cache = torch.zeros(
866
+ cache_key_shape, dtype=dtype, device=input_ids.device
867
+ )
868
+ past_key_values += ((key_cache,),)
869
+
870
+ return past_key_values
871
+
872
+ def prepare_inputs_for_generation(
873
+ self,
874
+ input_ids,
875
+ past_key_values=None,
876
+ attention_mask=None,
877
+ inputs_embeds=None,
878
+ is_prefill=None,
879
+ kv_len=None,
880
+ share_mask_tril=None,
881
+ **kwargs,
882
+ ):
883
+ batch_size, seq_len = input_ids.size()
884
+ if past_key_values is None:
885
+ past_key_values = self.init_cache(input_ids)
886
+ if is_prefill:
887
+ position_ids = attention_mask.long().cumsum(-1) - 1
888
+ position_ids.masked_fill_(attention_mask == 0, 1)
889
+ attention_mask = share_mask_tril
890
+ kv_len = torch.zeros(
891
+ (position_ids.size()[0]), dtype=torch.int32, device=input_ids.device
892
+ )
893
+ actual_seq_lengths_kv = None
894
+ past_key_values_length = 0
895
+ input_mask = None
896
+ else:
897
+ attention_mask = None
898
+ position_ids = kv_len.unsqueeze(1)
899
+ actual_seq_lengths_kv = (kv_len + 1).cpu().detach().numpy().tolist()
900
+ past_key_values_length = self.config.max_position_embeddings - seq_len
901
+ input_mask = share_mask_tril
902
+
903
+ attention_mask = _prepare_4d_causal_attention_mask(
904
+ input_mask, (batch_size, seq_len), input_ids.float(), past_key_values_length
905
+ )
906
+
907
+ model_inputs = {}
908
+ model_inputs.update(
909
+ {
910
+ "input_ids": input_ids,
911
+ "position_ids": position_ids,
912
+ "past_key_values": past_key_values,
913
+ "attention_mask": attention_mask,
914
+ "kv_len": kv_len,
915
+ "actual_seq_lengths_kv": actual_seq_lengths_kv,
916
+ }
917
+ )
918
+ return model_inputs
inference/runner.py ADDED
@@ -0,0 +1,411 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+
4
+ import copy
5
+ import logging
6
+ import os
7
+ import time
8
+
9
+ import torch
10
+ from torch.distributed.distributed_c10d import _world
11
+ from transformers import AutoTokenizer
12
+
13
+ root_logger = logging.getLogger()
14
+ root_logger.handlers.clear()
15
+ logging.basicConfig(
16
+ format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
17
+ level=logging.INFO,
18
+ )
19
+
20
+ torch.manual_seed(42)
21
+ torch.npu.manual_seed_all(42)
22
+
23
+
24
+ def get_init_attn_mask(mask_length, device, valid_len=None):
25
+ share_mask_tril = ~torch.tril(
26
+ torch.ones((mask_length, mask_length), dtype=torch.bool, device=device)
27
+ )
28
+ if valid_len is not None:
29
+ share_mask_tril[-valid_len:, :] = torch.zeros(valid_len, mask_length)
30
+ return share_mask_tril
31
+
32
+
33
+ def get_decode_mask(mask_length, device, position):
34
+ decode_mask = torch.zeros((1, mask_length), device=device)
35
+ decode_mask[0, :position] = 1
36
+ return decode_mask
37
+
38
+
39
+ def sample(input_logits: torch.Tensor, temperature=1.0, top_p=0.0, top_k=0, top_n_sigma=-1.0, **kwargs):
40
+ # shape of input_logits: [batch_size, 1, vocab_size]
41
+ # greedy
42
+ if temperature <= 0.0 or top_k == 1 or top_p == 0.0 or top_n_sigma == 0.0:
43
+ return torch.argmax(input_logits, dim=-1)
44
+
45
+ logits = input_logits / temperature
46
+
47
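+ # roughly -FLT_MAX: masked logits become ~zero probability after softmax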
+ filter_value = -3.4028e+38
48
+
49
+ # top_n_sigma truncation
50
+ if top_n_sigma > 0.0:
51
+ max_vals, _ = logits.max(dim=-1, keepdim=True)
52
+ std_vals = logits.std(dim=-1, keepdim=True)
53
+ threshold = max_vals - top_n_sigma * std_vals
54
+ mask = logits < threshold
55
+ logits = torch.where(mask, filter_value, logits)
56
+
57
+ # top_k truncation
58
+ if top_k > 0:
59
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
60
+ logits[indices_to_remove] = filter_value
61
+
62
+ # top_p truncation
63
+ if 0.0 < top_p < 1.0:
64
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
65
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
66
+
67
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
68
+ # keep at least 1 token
69
+ sorted_indices_to_remove[..., -1:] = 0
70
+
71
+ indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
72
+ logits = logits.masked_fill(indices_to_remove, filter_value)
73
+
74
+ probs = logits.softmax(dim=-1)
75
+ outputs = torch.multinomial(probs.squeeze(1), num_samples=1)
76
+ return outputs
77
+
78
+
79
+ class ModelRunner:
80
+ def __init__(self, runner_config):
81
+ self.runner_config = runner_config
82
+ self.model_name = runner_config.get("model_name", "default_model_name")
83
+ model_path = self.runner_config.get("model_path")
84
+ self.dtype = runner_config.get("model_config").get("dtype", torch.bfloat16)
85
+ self.max_position_embeddings = runner_config.get("data_config").get(
86
+ "max_position_embeddings", 131072
87
+ )
88
+ self.input_max_len = runner_config.get("data_config").get("input_max_len", 1024)
89
+ self.max_new_tokens = runner_config.get("data_config").get("max_new_tokens", 32)
90
+ self.batch_size = runner_config.get("data_config").get("batch_size", 16)
91
+ self.sampling_params = runner_config.get("sampling_config", {})
92
+ self.tokenizer = None
93
+ self.model = None
94
+ self.device = None
95
+ self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
96
+ self.rank_offset = int(os.getenv("RANK_OFFSET", "0"))
97
+ self.global_rank = self.local_rank + self.rank_offset
98
+ self.world_size = int(os.getenv("WORLD_SIZE", "1"))
99
+ if self.world_size == 1:
100
+ self.model_path = model_path
101
+ else:
102
+ self.model_path = os.path.join(model_path, f"rank_{self.global_rank}")
103
+
104
+ self.res_path = os.getenv("RES_PATH", "./")
105
+ self.enable_profiler = runner_config.get("model_config").get(
106
+ "enable_profiler", 0
107
+ )
108
+ self.use_pretrained_model = True
109
+ self.execute_mode = runner_config.get("exe_mode", "dynamo")
110
+ self.tokenizer_mode = runner_config.get("model_config").get(
111
+ "tokenizer_mode", "default"
112
+ )
113
+ self.init_device()
114
+ self.start_time = None
115
+ self.end_time = None
116
+ self.with_ckpt = runner_config.get("model_config").get("with_ckpt", 1)
117
+
118
+ @staticmethod
119
+ def repeat_batch(tensor, repeat_num):
120
+ if repeat_num == 1:
121
+ return tensor
122
+ return tensor.repeat(repeat_num, *[1] * (tensor.dim() - 1))
123
+
124
+ def init_device(self):
125
+ logging.info(
126
+ "Set execution using npu index: %s, global: %s",
127
+ self.local_rank,
128
+ self.global_rank,
129
+ )
130
+ self.device = torch.device("%s:%s" % ("npu", self.local_rank))
131
+ torch.npu.set_device(self.device)
132
+ if torch.npu.is_available() and self.world_size > 1:
133
+ if _world._default_pg is None:
134
+ torch.distributed.init_process_group(
135
+ backend="hccl", world_size=self.world_size, rank=self.global_rank
136
+ )
137
+
138
+ def init_model(self, model, config=None):
139
+ if self.with_ckpt:
140
+ self.use_pretrained_model = True
141
+ config = None
142
+ else:
143
+ self.use_pretrained_model = False
144
+ from configuration_openpangu_moe import PanguUltraMoEConfig as config
145
+ logging.info(f"use_pretrained_model: {self.use_pretrained_model}")
146
+
147
+ if self.use_pretrained_model:
148
+ self.load_model(model)
149
+ else:
150
+ self.init_model_from_config(model, config=config)
151
+ self.to_device()
152
+ self.compile_model()
153
+ self.init_tokenizer()
154
+
155
+ def init_model_from_config(self, model, config):
156
+ if config is None:
157
+ raise Exception("config cannot be None")
158
+ config_file = f"{self.model_path}/config.json"
159
+ model_config = config.from_pretrained(
160
+ config_file,
161
+ torch_dtype=self.dtype,
162
+ max_position_embeddings=self.max_position_embeddings,
163
+ )
164
+ self.model = model(model_config, runner_config=self.runner_config).to(
165
+ self.dtype
166
+ )
167
+
168
+ def load_model(self, model):
169
+ logging.info("Try to load pretrained model in path: %s", self.model_path)
170
+ self.model = model.from_pretrained(
171
+ self.model_path,
172
+ low_cpu_mem_usage=True,
173
+ ignore_mismatched_sizes=True,
174
+ torch_dtype=self.dtype,
175
+ max_position_embeddings=self.max_position_embeddings,
176
+ runner_config=self.runner_config,
177
+ )
178
+ for name, params in self.model.named_parameters():
179
+ logging.info(
180
+ "Param of %s: %s, %s, %s",
181
+ self.model_name,
182
+ name,
183
+ params.size(),
184
+ params.dtype,
185
+ )
186
+
187
+ def to_device(self):
188
+ self.model.to(self.device)
189
+ logging.info("Model weights H2D finished.")
190
+
191
+ def init_tokenizer(self):
192
+ self.tokenizer = AutoTokenizer.from_pretrained(
193
+ self.model_path,
194
+ trust_remote_code=True,
195
+ local_files_only=True,
196
+ padding_side="right",
197
+ truncation_side="right",
198
+ )
199
+ if self.tokenizer.pad_token is None:
200
+ self.tokenizer.pad_token = self.tokenizer.eos_token
201
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
202
+
203
+ def compile_model(self):
204
+ logging.info("The final model structure is: \n %s", self.model)
205
+ if self.execute_mode == "dynamo":
206
+ logging.info("Try to compile model")
207
+ self.graph_compile()
208
+
209
+ def graph_compile(self):
210
+ import torchair as tng
211
+ import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
212
+ from torchair.configs.compiler_config import CompilerConfig
213
+
214
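+ # torchair compiles the model into an NPU graph; frozen_parameter and tiling_schedule_optimize enable Ascend-specific graph optimizations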
+ compiler_config = CompilerConfig()
215
+ compiler_config.experimental_config.frozen_parameter = True
216
+ compiler_config.experimental_config.tiling_schedule_optimize = True
217
+ npu_backend = tng.get_npu_backend(compiler_config=compiler_config)
218
+ self.model = torch.compile(
219
+ self.model, dynamic=True, fullgraph=True, backend=npu_backend
220
+ )
221
+
222
+ def mark_inputs(self, model_inputs):
223
+ if self.execute_mode == "dynamo":
224
+ input_ids = model_inputs.get("input_ids")
225
+ kv_len = model_inputs.get("kv_len")
226
+ attention_mask = model_inputs.get("attention_mask")
227
+ position_ids = model_inputs.get("position_ids")
228
+ past_key_values = model_inputs.get("past_key_values")
229
+
230
+ # prefill with dynamic sequence length, decode with static sequence length
231
+ torch._dynamo.mark_static(kv_len)
232
+ for item in past_key_values:
233
+ for sub_item in item:
234
+ torch._dynamo.mark_static(sub_item)
235
+
236
+ torch._dynamo.mark_static(input_ids)
237
+ if attention_mask is not None:
238
+ torch._dynamo.mark_static(attention_mask)
239
+ torch._dynamo.mark_static(position_ids)
240
+
241
+ def model_input_prepare(self, input_dict):
242
+ input_ids = input_dict.get("input_ids")
243
+ attention_mask = input_dict.get("attention_mask")
244
+ past_key_values = input_dict.get("past_key_values")
245
+ is_prefill = input_dict.get("is_prefill")
246
+ kv_len = input_dict.get("kv_len")
247
+ share_mask_tril = input_dict.get("share_mask_tril")
248
+ model_inputs = self.model.prepare_inputs_for_generation(
249
+ input_ids=input_ids,
250
+ attention_mask=attention_mask,
251
+ past_key_values=past_key_values,
252
+ is_prefill=is_prefill,
253
+ kv_len=kv_len,
254
+ input_lens=input_dict.get("input_lens"),
255
+ share_mask_tril=share_mask_tril,
256
+ )
257
+ return model_inputs
258
+
259
+ def model_inference(self, model_inputs, warm_up=False):
260
+ torch.npu.synchronize()
261
+ if warm_up:
262
+ self.mark_inputs(model_inputs)
263
+ if self.start_time is None:
264
+ self.start_time = time.time()
265
+ with torch.no_grad():
266
+ logits = self.model(**model_inputs)
267
+ torch.npu.synchronize()
268
+ self.end_time = time.time()
269
+ if torch.distributed.get_rank() != 0:
270
+ logging.info(
271
+ f"{self.model_name} inference time cost {(self.end_time - self.start_time)*1000:.2f} ms"
272
+ )
273
+ self.start_time = time.time()
274
+ return logits
275
+
276
+ def model_generate(self, prompts, warm_up=False, **kwargs):
277
+ calling_func = {
278
+ "default": self.tokenizer,
279
+ "chat": self.tokenizer.apply_chat_template,
280
+ }
281
+ kwargs = {
282
+ "return_tensors": "pt",
283
+ "truncation": True,
284
+ "padding": "max_length",
285
+ "max_length": self.input_max_len,
286
+ }
287
+ if self.tokenizer_mode == "chat":
288
+ chat_kwargs = {"add_generation_prompt": True, "return_dict": True}
289
+ kwargs.update(chat_kwargs)
290
+ tokenizer = calling_func.get(self.tokenizer_mode, self.tokenizer)
291
+ inputs = tokenizer(prompts, **kwargs).to(self.device)
292
+
293
+ # get init input_dict
294
+ share_mask_tril = get_init_attn_mask(
295
+ self.max_position_embeddings, self.device, valid_len=self.input_max_len
296
+ )
297
+ share_mask_tril = share_mask_tril[None, None, ...]
298
+
299
+ input_lens = copy.deepcopy(inputs.input_ids.size()[1])
300
+ logging.info("Padding max prompts lens is : %d", input_lens)
301
+ input_dict = {
302
+ "input_ids": inputs.input_ids,
303
+ "generate_ids": inputs.input_ids,
304
+ "input_lens": input_lens,
305
+ "kv_len": None,
306
+ "past_key_values": None,
307
+ "attention_mask": inputs.attention_mask,
308
+ "share_mask_tril": share_mask_tril,
309
+ "is_prefill": True,
310
+ }
311
+
312
+ if torch.distributed.get_rank() == 0:
313
+ logging.info(
314
+ f"inputs.input_ids {inputs.input_ids[:,:30]} eod id {self.tokenizer.eos_token_id}"
315
+ )
316
+
317
+ generate_tokens = 0
318
+ cnt = 0
319
+ all_done = [False for _ in range(input_dict["input_ids"].size(0))]
320
+ done_len = [-1 for _ in range(input_dict["input_ids"].size(0))]
321
+ while True:
322
+ jump_flag = self.get_jump_flag(cnt, warm_up, generate_tokens)
323
+ if jump_flag:
324
+ break
325
+
326
+ # exit until all reach eod
327
+ if input_dict["input_ids"].size(1) == 1:
328
+ for bi in range(input_dict["input_ids"].size(0)):
329
+ if (
330
+ input_dict["input_ids"][bi, 0].item()
331
+ == self.tokenizer.eos_token_id
332
+ ):
333
+ all_done[bi] = True
334
+ done_len[bi] = generate_tokens
335
+ if all(all_done):
336
+ break
337
+ model_inputs = self.model_input_prepare(input_dict)
338
+ # fix decode mask
339
+ if model_inputs["position_ids"].shape[1] == 1:
340
+ model_inputs["attention_mask"].fill_(-3.4028e38)
341
+ for bi in range(model_inputs["position_ids"].size(0)):
342
+ max_l = model_inputs["position_ids"][bi].max().item()
343
+ model_inputs["attention_mask"][bi, :, :, : max_l + 1] = 0
344
+ outputs = self.model_inference(model_inputs, warm_up=warm_up)
345
+ self.model_output_process(model_inputs, outputs, input_dict)
346
+ # prof.step()
347
+ generate_tokens += 1
348
+ cnt += 1
349
+
350
+ generate_ids = input_dict["generate_ids"][:, input_lens:].clip(
351
+ 0, self.model.config.vocab_size - 1
352
+ )
353
+ for bi in range(generate_ids.size(0)):
354
+ if done_len[bi] != -1:
355
+ generate_ids[bi, done_len[bi] :] = self.tokenizer.eos_token_id
356
+ res = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True)
357
+
358
+ if isinstance(res, list):
359
+ for answer in res:
360
+ logging.info("Inference decode result: \n%s", answer)
361
+ else:
362
+ logging.info("Inference decode result: \n%s", res)
363
+ return res
364
+
365
+ def get_jump_flag(self, cnt, warm_up, generate_tokens):
366
+ default_decode_dump = 2
367
+ # warm-up only runs a few decode steps (default_decode_dump)
368
+ jump_flag_warm = warm_up and cnt >= default_decode_dump
369
+ # do not generate after max_token
370
+ jump_flag_oversize = generate_tokens >= self.max_new_tokens
371
+ jump_flag = jump_flag_oversize or jump_flag_warm
372
+ return jump_flag
373
+
374
+ def model_output_process(self, model_inputs, outputs, input_dict):
375
+ next_batch = self.batch_size
376
+ attn_tp_size = self.runner_config.get("parallel_config").get("attn_tp_size", 1)
377
+ if self.world_size % attn_tp_size != 0:
378
+ raise Exception(
379
+ f"world_size ({self.world_siz}) not divisible by attn_tp_size ({attn_tp_size})!"
380
+ )
381
+ attn_dp_size = self.world_size // attn_tp_size
382
+ input_dict["is_prefill"] = False
383
+ input_dict["input_lens"] = input_dict["input_lens"] + 1
384
+
385
+ kv_len = torch.max(model_inputs.get("position_ids"), axis=1)[0] + 1
386
+ input_dict["kv_len"] = kv_len
387
+
388
+ logits = outputs
389
+ past_key_values = model_inputs.get("past_key_values")
390
+ input_dict["past_key_values"] = past_key_values
391
+
392
+ attention_mask = None
393
+
394
+ share_mask_tril = get_decode_mask(
395
+ mask_length=self.max_position_embeddings,
396
+ device=self.device,
397
+ position=input_dict["input_lens"],
398
+ )
399
+ share_mask_tril = share_mask_tril[None, None, ...]
400
+
401
+ input_dict["attention_mask"] = attention_mask
402
+ input_dict["share_mask_tril"] = ModelRunner.repeat_batch(
403
+ share_mask_tril, self.batch_size
404
+ )
405
+
406
+ next_tokens = sample(logits, **self.sampling_params)
407
+ torch.distributed.broadcast(next_tokens, src=0)
408
+ input_dict["input_ids"] = next_tokens
409
+ input_dict["generate_ids"] = torch.cat(
410
+ [input_dict["generate_ids"], next_tokens], dim=-1
411
+ )
inference/runner_config/tp1.yaml ADDED
@@ -0,0 +1,30 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+
4
+ model_name: "pangu_ultra_moe"
5
+ model_path: "./model"
6
+ exe_mode: "eager" # ["dynamo", "eager"]
7
+
8
+ model_config:
9
+ tokenizer_mode: default # ["default", "chat"]
10
+ mm_quant_mode: None
11
+ mla_backend: absorb # [native, absorb]
12
+ with_ckpt: 1 # [0, 1]
13
+ enable_profiler: 0 # [0, 1]
14
+
15
+ data_config:
16
+ input_max_len: 4096
17
+ max_new_tokens: 28000
18
+ batch_size: 1
19
+ max_position_embeddings: 32768
20
+
21
+ parallel_config:
22
+ attn_tp_size: 1
23
+ moe_tp_size: 1
24
+ embed_tp_size: 1
25
+
26
+ sampling_config:
27
+ top_n_sigma: 0.05
28
+ top_p: 1.0
29
+ temperature: 0.7
30
+ top_k: -1
inference/runner_config/tp32.yaml ADDED
@@ -0,0 +1,30 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+
4
+ model_name: "pangu_ultra_moe"
5
+ model_path: "./model"
6
+ exe_mode: "eager" # ["dynamo", "eager"]
7
+
8
+ model_config:
9
+ tokenizer_mode: default # ["default", "chat"]
10
+ mm_quant_mode: None
11
+ mla_backend: absorb # [native, absorb]
12
+ with_ckpt: 1 # [0, 1]
13
+ enable_profiler: 0 # [0, 1]
14
+
15
+ data_config:
16
+ input_max_len: 4096
17
+ max_new_tokens: 28000
18
+ batch_size: 1
19
+ max_position_embeddings: 32768
20
+
21
+ parallel_config:
22
+ attn_tp_size: 32
23
+ moe_tp_size: 32
24
+ embed_tp_size: 32
25
+
26
+ sampling_config:
27
+ top_n_sigma: 0.05
28
+ top_p: 1.0
29
+ temperature: 0.7
30
+ top_k: -1
inference/split_weight.py ADDED
@@ -0,0 +1,387 @@
+ # coding=utf-8
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+
+ import argparse
+ import logging
+ import os
+ import shutil
+ from threading import Thread
+
+ import numpy as np
+ import torch
+ import yaml
+ from torch import nn
+ from transformers import AutoModelForCausalLM
+
+ from model import PanguUltraMoEForCausalLM
+
+ root_logger = logging.getLogger()
+ root_logger.handlers.clear()
+ logging.basicConfig(
+     format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
+     level=logging.INFO,
+ )
+
+
+ def _to_parameter(data):
+     return nn.Parameter(data, requires_grad=False)
+
+
+ def split_w_dense(block, dst_model, i, local_rank):
+     up_weight_list = []
+     ffn_dim = dst_model.model.layers[i].mlp.intermediate_size_per_rank
+     gate_weight = block.mlp.gate_proj.weight[
+         local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
+     ].contiguous()
+     up_weight = block.mlp.up_proj.weight[
+         local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
+     ].contiguous()
+     up_weight_list.append(_to_parameter(torch.cat([gate_weight, up_weight], axis=0)))
+
+     if len(up_weight_list) == 1:
+         dst_model.model.layers[i].mlp.merge_up_gate_proj.weight = up_weight_list[0]
+     else:
+         dst_model.model.layers[i].mlp.merge_up_gate_proj.weight = _to_parameter(
+             torch.cat(up_weight_list, axis=0)
+         )
+     dst_model.model.layers[i].mlp.down_proj.weight.data = (
+         block.mlp.down_proj.weight.data[
+             :, local_rank * ffn_dim : (local_rank + 1) * ffn_dim
+         ].contiguous()
+     )
+
+
+ def split_w_moe(block, dst_model, i, local_rank):
+     shared_up_weight_list = []
+     ffn_dim = dst_model.model.layers[i].mlp.shared_experts.intermediate_size_per_rank
+     gate_weight = block.mlp.shared_experts.gate_proj.weight[
+         local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
+     ].contiguous()
+     up_weight = block.mlp.shared_experts.up_proj.weight[
+         local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
+     ].contiguous()
+     shared_up_weight_list.append(
+         _to_parameter(torch.cat([gate_weight, up_weight], axis=0))
+     )
+     if len(shared_up_weight_list) == 1:
+         dst_model.model.layers[i].mlp.shared_experts.merge_up_gate_proj.weight = (
+             shared_up_weight_list[0]
+         )
+     else:
+         dst_model.model.layers[i].mlp.shared_experts.merge_up_gate_proj.weight = (
+             _to_parameter(torch.cat(shared_up_weight_list, axis=0))
+         )
+     dst_model.model.layers[i].mlp.shared_experts.down_proj.weight.data = (
+         block.mlp.shared_experts.down_proj.weight.data[
+             :, local_rank * ffn_dim : (local_rank + 1) * ffn_dim
+         ].contiguous()
+     )
+     dst_model.model.layers[i].mlp.gate.weight.data = block.mlp.gate.weight.data
+
+     expert_num = block.mlp.num_routed_experts
+     gate_proj_list, down_proj_list, up_proj_list = [], [], []
+     for _, src_expert in enumerate(block.mlp.experts):
+         ffn_dim = dst_model.model.layers[i].mlp.experts.intermediate_size_per_rank
+         gate_proj_list.append(
+             src_expert.gate_proj.weight.data[
+                 local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
+             ].contiguous()
+         )
+         up_proj_list.append(
+             src_expert.up_proj.weight.data[
+                 local_rank * ffn_dim : (local_rank + 1) * ffn_dim, :
+             ].contiguous()
+         )
+         down_proj_list.append(
+             src_expert.down_proj.weight.data[
+                 :, local_rank * ffn_dim : (local_rank + 1) * ffn_dim
+             ].contiguous()
+         )
+
+     dst_model.model.layers[i].mlp.experts.group_w2.data = (
+         torch.cat(down_proj_list, dim=0).view(expert_num, -1, ffn_dim).contiguous()
+     )
+     group_gate_proj = (
+         torch.cat(gate_proj_list, dim=0).view(expert_num, ffn_dim, -1).contiguous()
+     )
+     group_up_proj = (
+         torch.cat(up_proj_list, dim=0).view(expert_num, ffn_dim, -1).contiguous()
+     )
+     dst_model.model.layers[i].mlp.experts.group_w1_w3.data = torch.cat(
+         [group_gate_proj, group_up_proj], dim=1
+     )
+
+
+ def split_w_attn(block, dst_model, i, local_rank):
+     q_dim = (
+         dst_model.model.layers[0].self_attn.num_heads_per_rank
+         * dst_model.model.layers[0].self_attn.q_head_dim
+     )
+     o_dim = (
+         dst_model.model.layers[0].self_attn.num_heads_per_rank
+         * dst_model.model.layers[0].self_attn.attention_v_dim
+     )
+
+     if dst_model.model.layers[i].self_attn.attention_q_lora_dim is None:
+         dst_model.model.layers[i].self_attn.q_proj.weight.data = (
+             block.self_attn.q_proj.weight.data[
+                 local_rank * q_dim : (local_rank + 1) * q_dim, :
+             ].contiguous()
+         )
+     else:
+         dst_model.model.layers[i].self_attn.q_a_proj.weight.data = (
+             block.self_attn.q_a_proj.weight.data
+         )
+         dst_model.model.layers[i].self_attn.q_a_layernorm.weight.data = (
+             block.self_attn.q_a_layernorm.weight.data
+         )
+         dst_model.model.layers[i].self_attn.q_b_proj.weight.data = (
+             block.self_attn.q_b_proj.weight.data[
+                 local_rank * q_dim : (local_rank + 1) * q_dim, :
+             ].contiguous()
+         )
+
+     dst_model.model.layers[i].self_attn.kv_a_proj_with_mqa.weight.data = (
+         block.self_attn.kv_a_proj_with_mqa.weight.data
+     )
+
+     dst_model.model.layers[i].self_attn.kv_a_layernorm.weight.data = (
+         block.self_attn.kv_a_layernorm.weight.data
+     )
+     dst_model.model.layers[i].self_attn.o_proj.weight.data = (
+         block.self_attn.o_proj.weight.data[
+             :, local_rank * o_dim : (local_rank + 1) * o_dim
+         ].contiguous()
+     )
+     dst_model.model.layers[i].input_layernorm.weight.data = (
+         block.input_layernorm.weight.data
+     )
+     dst_model.model.layers[i].post_attention_layernorm.weight.data = (
+         block.post_attention_layernorm.weight.data
+     )
+     dst_model.model.layers[i].pre_mlp_layernorm.weight.data = (
+         block.pre_mlp_layernorm.weight.data
+     )
+     dst_model.model.layers[i].post_mlp_layernorm.weight.data = (
+         block.post_mlp_layernorm.weight.data
+     )
+
+
+ def kv_low_rank_split(block, dst_model, i, local_rank):
+     k_dim = dst_model.model.layers[0].self_attn.num_heads_per_rank * (
+         dst_model.model.layers[0].self_attn.attention_qk_dim
+         + dst_model.model.layers[0].self_attn.attention_v_dim
+     )
+     kv_b_proj_weight_data = block.self_attn.kv_b_proj.weight.data[
+         local_rank * k_dim : (local_rank + 1) * k_dim, :
+     ].contiguous()
+     attention_qk_dim = dst_model.model.layers[i].self_attn.attention_qk_dim
+     num_heads_per_rank = dst_model.model.layers[i].self_attn.num_heads_per_rank
+     attention_kv_lora_dim = dst_model.model.layers[i].self_attn.attention_kv_lora_dim
+     attention_v_dim = dst_model.model.layers[i].self_attn.attention_v_dim
+
+     index_tensor = torch.arange(attention_qk_dim).repeat(
+         num_heads_per_rank
+     ) + torch.arange(num_heads_per_rank).repeat_interleave(attention_qk_dim) * (
+         attention_qk_dim + attention_v_dim
+     )
+     kv_b_proj_w_k = torch.index_select(kv_b_proj_weight_data, dim=0, index=index_tensor)
+     dst_model.model.layers[i].self_attn.kv_b_proj_w_k.data = kv_b_proj_w_k.view(
+         num_heads_per_rank, attention_qk_dim, attention_kv_lora_dim
+     ).contiguous()
+     index_tensor = torch.arange(
+         attention_qk_dim, attention_qk_dim + attention_v_dim
+     ).repeat(num_heads_per_rank) + torch.arange(num_heads_per_rank).repeat_interleave(
+         attention_v_dim
+     ) * (
+         attention_qk_dim + attention_v_dim
+     )
+     kv_b_proj_w_v = torch.index_select(kv_b_proj_weight_data, dim=0, index=index_tensor)
+     dst_model.model.layers[i].self_attn.kv_b_proj_w_v.data = (
+         kv_b_proj_w_v.view(num_heads_per_rank, attention_v_dim, attention_kv_lora_dim)
+         .transpose(1, 2)
+         .contiguous()
+     )
+
+
+ def split_layer(block, dst_model, i, local_rank, attn_tp_size, moe_tp_size):
+     # attn weights
+     local_rank_tp_attn = local_rank % attn_tp_size
+     split_w_attn(block, dst_model, i, local_rank_tp_attn)
+     kv_low_rank_split(block, dst_model, i, local_rank_tp_attn)
+
+     # moe experts weights
+     local_rank_tp_moe = local_rank % moe_tp_size
+     if i >= dst_model.config.num_dense_layers:
+         split_w_moe(block, dst_model, i, local_rank_tp_moe)
+     else:
+         split_w_dense(block, dst_model, i, local_rank_tp_moe)
+
+
+ def split_w(src_model, dst_model, local_rank, runner_config):
+     attn_tp_size = runner_config.get("parallel_config").get("attn_tp_size")
+     moe_tp_size = runner_config.get("parallel_config").get("moe_tp_size")
+     embed_tp_size = runner_config.get("parallel_config").get("embed_tp_size")
+
+     vocab_size = src_model.model.vocab_size // embed_tp_size
+     embed_tp_rank = local_rank % embed_tp_size
+
+     dst_model.lm_head.weight.data = src_model.lm_head.weight.data[
+         embed_tp_rank * vocab_size : (embed_tp_rank + 1) * vocab_size, :
+     ]
+     dst_model.model.embed_tokens.weight.data = src_model.model.embed_tokens.weight.data[
+         embed_tp_rank * vocab_size : (embed_tp_rank + 1) * vocab_size, :
+     ]
+
+     dst_model.model.norm.weight.data = src_model.model.norm.weight.data
+
+     layer_num = len(src_model.model.layers)
+
+     all_threads = []
+     for i in range(0, layer_num):
+         block = src_model.model.layers[i]
+         thread = Thread(
+             target=split_layer,
+             args=(block, dst_model, i, local_rank, attn_tp_size, moe_tp_size),
+         )
+         all_threads.append(thread)
+         thread.start()
+     for thread in all_threads:
+         thread.join()
+
+
+ def copy_files_with_prefix(src_dir, dst_dir, prefix):
+     for file in os.listdir(src_dir):
+         if file.startswith(prefix):
+             src_file = os.path.join(src_dir, file)
+             dst_file = os.path.join(dst_dir, file)
+             shutil.copy2(src_file, dst_file)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description="Split weight parameters with tensor parallel"
+     )
+     parser.add_argument("--model_path", type=str, help="Path of model weights")
+     parser.add_argument(
+         "--output_path",
+         type=str,
+         help="The output directory where the results are saved",
+     )
+     parser.add_argument(
+         "--origin_yaml_file_path", type=str, help="inference configurations"
+     )
+     parser.add_argument(
+         "--new_yaml_file_path", type=str, help="inference configurations"
+     )
+     parser.add_argument(
+         "--world_size", type=int, default=8, help="The parallel rank size of model"
+     )
+     parser.add_argument("--node_num", type=int, default=1, help="The parallel node num")
+     parser.add_argument(
+         "--node_rank", type=int, default=0, help="The parallel node rank"
+     )
+     parser_args = parser.parse_args()
+     return parser_args
+
+
+ def show_model_states(origin_model, model_name="src_model"):
+     src_param_size = 0
+     for name, params in origin_model.named_parameters():
+         size_per_param = np.prod(params.size())
+         src_param_size += size_per_param
+         logging.info(
+             "Param of %s tensor parallel: %s, %s, %s",
+             model_name,
+             name,
+             params.size(),
+             params.dtype,
+         )
+     logging.info(
+         "Total param size of %s tensor parallel: %s", model_name, src_param_size
+     )
+
+
+ def read_yaml(yaml_file_path):
+     try:
+         with open(yaml_file_path, "r", encoding="utf-8") as file:
+             data = yaml.safe_load(file)
+     except FileNotFoundError:
+         logging.error(f"No such yaml file: {yaml_file_path}")
+     except yaml.YAMLError as e:
+         logging.error(f"Load yaml file failed: {e}")
+     return data
+
+
+ def check_vars(world_size, runner_config):
+     attn_tp_size = runner_config.get("parallel_config").get("attn_tp_size")
+     moe_tp_size = runner_config.get("parallel_config").get("moe_tp_size")
+     embed_tp_size = runner_config.get("parallel_config").get("embed_tp_size")
+     if world_size % attn_tp_size != 0:
+         logging.error(
+             "world_size %s mod attn_tp_size %s must be 0", world_size, attn_tp_size
+         )
+         exit(1)
+     if world_size % moe_tp_size != 0:
+         logging.error(
+             "world_size %s mod moe_tp_size %s must be 0", world_size, moe_tp_size
+         )
+         exit(1)
+     if world_size % embed_tp_size != 0:
+         logging.error(
+             "world_size %s mod embed_tp_size %s must be 0", world_size, embed_tp_size
+         )
+         exit(1)
+
+
+ if __name__ == "__main__":
+     logging.info("Start to split weight...")
+     args = parse_args()
+     output_path = args.output_path
+
+     old_runner_config = read_yaml(args.origin_yaml_file_path)
+     new_runner_config = read_yaml(args.new_yaml_file_path)
+     world_size = args.world_size
+
+     if not os.path.exists(output_path):
+         os.makedirs(output_path)
+     origin_model = AutoModelForCausalLM.from_pretrained(
+         args.model_path,
+         trust_remote_code=True,
+         local_files_only=True,
+         ignore_mismatched_sizes=True,
+         low_cpu_mem_usage=True,
+         torch_dtype=torch.bfloat16,
+         attn_implementation="eager",
+     )
+     show_model_states(origin_model, "origin_model")
+
+     node_rank_id = args.node_rank
+     rank_num_per_node = world_size // args.node_num
+     start_rank = rank_num_per_node * node_rank_id
+     end_rank = rank_num_per_node * (node_rank_id + 1)
+
+     for rank_id in range(start_rank, end_rank):
+         logging.info("rank_id={} / rank_size={}".format(rank_id, world_size))
+         os.environ["LOCAL_RANK"] = str(rank_id)
+
+         save_path = os.path.join(output_path, f"rank_{rank_id}")
+         logging.info(
+             "Split weight for rank %s start, save path is: %s", rank_id, save_path
+         )
+
+         config = origin_model.config
+         part_model = PanguUltraMoEForCausalLM(config, new_runner_config)
+
+         split_w(origin_model, part_model, rank_id, new_runner_config)
+
+         show_model_states(part_model, "dst_model")
+
+         part_model.save_pretrained(save_path)
+         copy_files_with_prefix(args.model_path, save_path, "tokenizer")
+         copy_files_with_prefix(args.model_path, save_path, "tokenization")
+         logging.info(
+             "Split weight for rank %s finished, save path is: %s", rank_id, save_path
+         )
+
+         del part_model
inference/split_weight.sh ADDED
@@ -0,0 +1,13 @@
+ # coding=utf-8
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+
+ rm -rf ./model
+ mkdir ./model
+ python split_weight.py \
+     --model_path=../ \
+     --output_path=./model \
+     --origin_yaml_file_path=./runner_config/tp1.yaml \
+     --new_yaml_file_path=./runner_config/tp32.yaml \
+     --world_size=32 \
+     --node_num=1 \
+     --node_rank=0
inference/vllm_ascend/_build_info.py ADDED
@@ -0,0 +1,3 @@
+ # Auto-generated file
+ __soc_version__ = 'ASCEND910B1'
+ __sleep_mode_enabled__ = True
inference/vllm_ascend/attention/attention.py ADDED
@@ -0,0 +1,1220 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # This file is a part of the vllm-ascend project.
16
+ #
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, List, Optional, Tuple, Type
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch_npu
24
+ import torchair._contrib.custom_torch_ops # type: ignore # noqa: F401
25
+ from torch.nn.functional import scaled_dot_product_attention
26
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
27
+ AttentionLayer,
28
+ AttentionMetadata, AttentionType,
29
+ MLAAttentionImpl)
30
+ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
31
+ CommonMetadataBuilder,
32
+ compute_slot_mapping,
33
+ compute_slot_mapping_start_idx,
34
+ is_block_tables_empty)
35
+ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
36
+
37
+ from vllm_ascend.ascend_config import get_ascend_config
38
+ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
39
+ from vllm_ascend.ops.cache import concat_and_cache_mla
40
+ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
41
+ enable_custom_op, is_310p, nd_to_nz_2d)
42
+ from vllm_ascend.worker.model_runner import (
43
+ ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
44
+
45
+ _ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
46
+
47
+
48
+ class AscendAttentionBackend(AttentionBackend):
49
+
50
+ @staticmethod
51
+ def get_name() -> str:
52
+ return "ASCEND"
53
+
54
+ @staticmethod
55
+ def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
56
+ return AscendAttentionBackendImpl
57
+
58
+ @staticmethod
59
+ def get_metadata_cls() -> Type["AscendMetadata"]:
60
+ return AscendMetadata
61
+
62
+ @staticmethod
63
+ def get_state_cls() -> Type["CommonAttentionState"]:
64
+ return CommonAttentionState
65
+
66
+ @staticmethod
67
+ def get_kv_cache_shape(
68
+ num_blocks: int,
69
+ block_size: int,
70
+ num_kv_heads: int,
71
+ head_size: int,
72
+ ) -> Tuple[int, ...]:
73
+ if is_310p():
74
+ return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
75
+ 16)
76
+ else:
77
+ return (2, num_blocks, block_size, num_kv_heads, head_size)
78
+
79
+ @staticmethod
80
+ def swap_blocks(
81
+ src_kv_cache: List[torch.Tensor],
82
+ dst_kv_cache: List[torch.Tensor],
83
+ src_to_dst: torch.Tensor,
84
+ ) -> None:
85
+ src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
86
+ dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
87
+ src_indices = src_to_dst[:, 0]
88
+ dst_indices = src_to_dst[:, 1]
89
+
90
+ dst_key_cache[dst_indices] = src_key_cache[src_indices].to(
91
+ dst_key_cache.device)
92
+ dst_value_cache[dst_indices] = src_value_cache[src_indices].to(
93
+ dst_key_cache.device)
94
+
95
+ @staticmethod
96
+ def copy_blocks(
97
+ kv_caches: List[torch.Tensor],
98
+ src_to_dists: torch.Tensor,
99
+ ) -> None:
100
+ src_indices = src_to_dists[:, 0]
101
+ dst_indices = src_to_dists[:, 1]
102
+
103
+ for kv_cache in kv_caches:
104
+ key_caches = kv_cache[0]
105
+ value_caches = kv_cache[1]
106
+ key_caches[dst_indices] = key_caches[src_indices]
107
+ value_caches[dst_indices] = value_caches[src_indices]
108
+
109
+ @staticmethod
110
+ def get_builder_cls() -> Type["AscendMetadataBuilder"]:
111
+ return AscendMetadataBuilder
112
+
113
+ @classmethod
114
+ def make_metadata_builder(cls, *args, **kwargs) -> "AscendMetadataBuilder":
115
+ return cls.get_builder_cls()(*args, **kwargs)
116
+
117
+
118
+ class AscendMLAAttentionBackend(AscendAttentionBackend):
119
+
120
+ @staticmethod
121
+ def get_impl_cls() -> Type["AscendMLAAttentionBackendImpl"]:
122
+ return AscendMLAAttentionBackendImpl
123
+
124
+ @staticmethod
125
+ def get_kv_cache_shape(
126
+ num_blocks: int,
127
+ block_size: int,
128
+ num_kv_heads: int,
129
+ head_size: int,
130
+ ) -> Tuple[int, ...]:
131
+ return (num_blocks, block_size, num_kv_heads, head_size)
132
+
133
+
134
+ @dataclass
135
+ class AscendMetadata(AttentionMetadata):
136
+ """Metadata for Ascend backend.
137
+ * modified from XFormers backend
138
+ NOTE: Any python object stored here is not updated when it is
139
+ cuda-graph replayed. If you have values that need to be changed
140
+ dynamically, it should be stored in tensor. The tensor has to be
141
+ updated from `CUDAGraphRunner.forward` API.
142
+ """
143
+
144
+ # |---------- N-1 iteration --------|
145
+ # |---------------- N iteration ---------------------|
146
+ # |- tokenA -|......................|-- newTokens ---|
147
+ # |---------- context_len ----------|
148
+ # |-------------------- seq_len ----------------------|
149
+ # |-- query_len ---|
150
+
151
+ # FIXME: It is for flash attn.
152
+ # Maximum sequence length among prefill batch. 0 if there are decoding
153
+ # Avoid mypy error
154
+ # Total number of prefill requests.
155
+ num_prefills: int
156
+ # Number of prefill tokens.
157
+ num_prefill_tokens: int
158
+ # (num_tokens,). The indices of the token slots that input tokens will be
159
+ # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
160
+ # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
161
+ # in block 0, and 1st slot in block 1, respectively.
162
+ slot_mapping: torch.Tensor
163
+
164
+ # requests only.
165
+ max_prefill_seq_len: int
166
+ # Maximum sequence length among decode batch. 0 if there are prefill
167
+ # requests only.
168
+ max_decode_seq_len: int
169
+
170
+ chunked_prefill_enabled: bool
171
+
172
+ # (batch_size, max_blocks_per_seq).
173
+ # Block addresses per sequence. (Seq id -> list of physical block)
174
+ block_tables: Optional[torch.Tensor]
175
+
176
+ # seq_lens stored as a tensor.
177
+ seq_lens_tensor: Optional[torch.Tensor]
178
+
179
+ # (batch_size,). The sequence length per sequence. Sequence length means
180
+ # the computed tokens + new tokens None if it is a decoding.
181
+ seq_lens: Optional[List[int]] = None
182
+
183
+ # The query lengths of the input sequences
184
+ query_lens: Optional[List[int]] = None
185
+
186
+ # Maximum query length in the batch. None for decoding.
187
+ max_query_len: Optional[int] = None
188
+
189
+ # Self-attention prefill/decode metadata cache
190
+ _cached_prefill_metadata: Optional["AscendMetadata"] = None
191
+ _cached_decode_metadata: Optional["AscendMetadata"] = None
192
+
193
+ # Begin encoder attn & enc/dec cross-attn fields...
194
+
195
+ # Encoder sequence lengths representation
196
+ encoder_seq_lens: Optional[List[int]] = None
197
+ encoder_seq_lens_tensor: Optional[torch.Tensor] = None
198
+
199
+ # Maximum sequence length among encoder sequences
200
+ max_encoder_seq_len: Optional[int] = None
201
+
202
+ # Number of tokens input to encoder
203
+ num_encoder_tokens: Optional[int] = None
204
+
205
+ # Mask for normal situation
206
+ attn_mask: Optional[torch.Tensor] = None
207
+
208
+ # Mask for prefix caching
209
+ compress_mask: Optional[torch.Tensor] = None
210
+
211
+ # Mask for chunked prefill
212
+ chunk_mask: Optional[torch.Tensor] = None
213
+
214
+ # Cross-attention memory-mapping data structures: slot mapping
215
+ # and block tables
216
+ cross_slot_mapping: Optional[torch.Tensor] = None
217
+ cross_block_tables: Optional[torch.Tensor] = None
218
+
219
+ @property
220
+ def prefill_metadata(self) -> Optional["AscendMetadata"]:
221
+ if self.num_prefills == 0:
222
+ return None
223
+
224
+ if self._cached_prefill_metadata is not None:
225
+ # Recover cached prefill-phase attention
226
+ # metadata structure.
227
+ return self._cached_prefill_metadata
228
+
229
+ assert ((self.seq_lens is not None)
230
+ or (self.encoder_seq_lens is not None))
231
+
232
+ # Compute some attn_metadata fields which default to None.
233
+ slot_mapping = (None if self.slot_mapping is None else
234
+ self.slot_mapping[:self.num_prefill_tokens])
235
+ seq_lens = (None if self.seq_lens is None else
236
+ self.seq_lens[:self.num_prefills])
237
+ query_lens = (None if self.query_lens is None else
238
+ self.query_lens[:self.num_prefills])
239
+ block_tables = (None if self.block_tables is None else
240
+ self.block_tables[:self.num_prefills])
241
+
242
+ seq_lens_tensor = (None if self.seq_lens_tensor is None else
243
+ self.seq_lens_tensor[:self.num_prefills])
244
+
245
+ # Construct & cache prefill-phase attention metadata structure.
246
+ self._cached_prefill_metadata = AscendMetadata(
247
+ num_prefills=self.num_prefills,
248
+ num_prefill_tokens=self.num_prefill_tokens,
249
+ num_decode_tokens=0,
250
+ slot_mapping=slot_mapping,
251
+ seq_lens=seq_lens,
252
+ seq_lens_tensor=seq_lens_tensor,
253
+ query_lens=query_lens,
254
+ max_query_len=self.max_query_len,
255
+ max_prefill_seq_len=self.max_prefill_seq_len,
256
+ max_decode_seq_len=0,
257
+ chunked_prefill_enabled=self.chunked_prefill_enabled,
258
+ block_tables=block_tables,
259
+ # Begin encoder & cross attn fields below...
260
+ encoder_seq_lens=self.encoder_seq_lens,
261
+ encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
262
+ max_encoder_seq_len=self.max_encoder_seq_len,
263
+ multi_modal_placeholder_index_maps=self.
264
+ multi_modal_placeholder_index_maps,
265
+ cross_slot_mapping=self.cross_slot_mapping,
266
+ cross_block_tables=self.cross_block_tables,
267
+ enable_kv_scales_calculation=False)
268
+ return self._cached_prefill_metadata
269
+
270
+ @property
271
+ def decode_metadata(self) -> Optional["AscendMetadata"]:
272
+ if self.num_decode_tokens == 0:
273
+ return None
274
+
275
+ if self._cached_decode_metadata is not None:
276
+ # Recover cached decode-phase attention
277
+ # metadata structure.
278
+ return self._cached_decode_metadata
279
+
280
+ # Compute some attn_metadata fields which default to None.
281
+ slot_mapping = (None if self.slot_mapping is None else
282
+ self.slot_mapping[self.num_prefill_tokens:])
283
+ seq_lens = (None if self.seq_lens is None else
284
+ self.seq_lens[self.num_prefills:])
285
+ query_lens = (None if self.query_lens is None else
286
+ self.query_lens[self.num_prefills:])
287
+ block_tables = (None if self.block_tables is None else
288
+ self.block_tables[self.num_prefills:])
289
+ seq_lens_tensor = (None if self.seq_lens_tensor is None else
290
+ self.seq_lens_tensor[self.num_prefills:])
291
+ # Construct & cache decode-phase attention metadata structure.
292
+ self._cached_decode_metadata = AscendMetadata(
293
+ num_prefills=0,
294
+ num_prefill_tokens=0,
295
+ num_decode_tokens=self.num_decode_tokens,
296
+ slot_mapping=slot_mapping,
297
+ seq_lens=seq_lens,
298
+ seq_lens_tensor=seq_lens_tensor,
299
+ query_lens=query_lens,
300
+ max_query_len=self.max_query_len,
301
+ max_prefill_seq_len=0,
302
+ max_decode_seq_len=self.max_decode_seq_len,
303
+ chunked_prefill_enabled=self.chunked_prefill_enabled,
304
+ block_tables=block_tables,
305
+ # Begin encoder & cross attn fields below...
306
+ encoder_seq_lens=self.encoder_seq_lens,
307
+ encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
308
+ max_encoder_seq_len=self.max_encoder_seq_len,
309
+ multi_modal_placeholder_index_maps=self.
310
+ multi_modal_placeholder_index_maps,
311
+ cross_slot_mapping=self.cross_slot_mapping,
312
+ cross_block_tables=self.cross_block_tables,
313
+ enable_kv_scales_calculation=False)
314
+ return self._cached_decode_metadata
315
+
316
+ def advance_step(self,
317
+ model_input: "ModelInputForNPUWithSamplingMetadata",
318
+ sampled_token_ids: Optional[torch.Tensor],
319
+ block_size: int,
320
+ num_seqs: int,
321
+ num_queries: int,
322
+ turn_prefills_into_decodes: bool = False):
323
+ """
324
+ Update metadata in-place to advance one decode step.
325
+ """
326
+ # When using cudagraph, the num_seqs is padded to the next captured
327
+ # batch size, but num_queries tracks the actual number of requests in
328
+ # the batch. For --enforce-eager mode, num_seqs == num_queries
329
+ if num_seqs != num_queries:
330
+ assert num_seqs > num_queries
331
+
332
+ if turn_prefills_into_decodes:
333
+ # When Multi-Step is enabled with Chunked-Prefill, prefills and
334
+ # decodes are scheduled together. In the first step, all the
335
+ # prefills turn into decodes. This update reflects that
336
+ # conversion.
337
+ assert self.num_decode_tokens + self.num_prefills == num_seqs
338
+ self.num_decode_tokens += self.num_prefills
339
+ self.num_prefills = 0
340
+ self.num_prefill_tokens = 0
341
+ self.max_prefill_seq_len = 0
342
+ self.max_query_len = 1
343
+
344
+ self.slot_mapping = self.slot_mapping[:num_seqs]
345
+ else:
346
+ assert self.seq_lens is not None
347
+ assert self.max_decode_seq_len == max(self.seq_lens)
348
+
349
+ assert self.num_prefills == 0
350
+ assert self.num_prefill_tokens == 0
351
+ assert self.num_decode_tokens == num_seqs
352
+ assert self.slot_mapping.shape == (num_seqs, )
353
+
354
+ assert self.seq_lens is not None
355
+ assert len(self.seq_lens) == num_seqs
356
+ assert self.seq_lens_tensor is not None
357
+ assert self.seq_lens_tensor.shape == (num_seqs, )
358
+ assert self.max_query_len == 1
359
+ assert self.max_prefill_seq_len == 0
360
+
361
+ assert self.block_tables is not None
362
+ assert self.block_tables.shape[0] == num_seqs
363
+
364
+ # Update query lengths. Note that we update only queries and not seqs,
365
+ # since tensors may be padded due to captured cuda graph batch size
366
+ for i in range(num_queries):
367
+ self.seq_lens[i] += 1
368
+ self.max_decode_seq_len = max(self.seq_lens)
369
+ if enable_custom_op():
370
+ # Advance a step on NPU for existing inputs for a multi-step runner if custom ops are enabled
371
+ torch.ops._C.advance_step_flashattn_ascendc(
372
+ num_seqs=num_seqs,
373
+ num_queries=num_queries,
374
+ block_size=block_size,
375
+ input_tokens=model_input.input_tokens,
376
+ sampled_token_ids=sampled_token_ids,
377
+ input_positions=model_input.input_positions,
378
+ seq_lens=self.seq_lens_tensor,
379
+ slot_mapping=self.slot_mapping,
380
+ block_tables=self.block_tables)
381
+ else:
382
+ # use traditional Pytorch method for updating these tensors.
383
+ # update input_tokens
384
+ sampled_token_ids_list = sampled_token_ids[:
385
+ num_queries].squeeze( # type: ignore
386
+ -1)
387
+ model_input.input_tokens[:
388
+ num_queries] = sampled_token_ids_list # type: ignore
389
+
390
+ # get seq_lens and input_positions
391
+ seq_lens = self.seq_lens_tensor[:num_queries]
392
+ next_seq_lens = seq_lens + 1
393
+ next_input_pos = next_seq_lens - 1
394
+
395
+ # update seq_lens and input_positions
396
+ self.seq_lens_tensor[:num_queries] = next_seq_lens
397
+ model_input.input_positions[:
398
+ num_queries] = next_input_pos # type: ignore
399
+
400
+ # Compute block index and offset
401
+ block_idx = next_input_pos // block_size
402
+ block_offset = next_input_pos % block_size
403
+
404
+ current_block_table = self.block_tables.gather(
405
+ 1, block_idx.unsqueeze(-1)).squeeze(-1)
406
+ slot_num = current_block_table * block_size + block_offset
407
+
408
+ # update slot_mapping
409
+ self.slot_mapping[:num_queries] = slot_num
410
+
411
+
412
+ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
413
+
414
+ _attn_mask_builder = None # noqa
415
+
416
+ def __init__(self, input_builder: "ModelInputForNPUBuilder"):
417
+ self.input_builder = input_builder
418
+ self.runner = input_builder.runner
419
+ self.sliding_window = input_builder.sliding_window
420
+ self.block_size = input_builder.block_size
421
+
422
+ self.attn_mask = None
423
+ self.compress_mask = None
424
+ self.chunk_mask = None
425
+ if AscendMetadataBuilder._attn_mask_builder is None:
426
+ AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder(
427
+ 128, self.input_builder.runner.model_config.dtype)
428
+
429
+ def _add_seq_group(
430
+ self, inter_data: ModelInputForNPUBuilder.InterDataForSeqGroup,
431
+ chunked_prefill_enabled: bool):
432
+ """Add a sequence group to the metadata. Specifically update/append
433
+ 1. context length.
434
+ 2. block table.
435
+ 3. slot mapping.
436
+ """
437
+ is_prompt = inter_data.is_prompt
438
+ block_tables = inter_data.block_tables
439
+
440
+ for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
441
+ curr_sliding_window_block) in zip(
442
+ inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
443
+ inter_data.orig_seq_lens, inter_data.seq_lens,
444
+ inter_data.query_lens, inter_data.context_lens,
445
+ inter_data.curr_sliding_window_blocks):
446
+ self.context_lens.append(context_len)
447
+ if is_prompt:
448
+ self.num_prefills += 1
449
+ self.num_prefill_tokens += token_len
450
+ self.prefill_seq_lens.append(seq_len)
451
+ else:
452
+ self.num_decode_tokens += query_len
453
+ self.curr_seq_lens.append(curr_seq_len)
454
+
455
+ # Compute block table.
456
+ # TODO(sang): Combine chunked prefill and prefix caching by
457
+ # only allowing multiple of block_size chunk size.
458
+ # NOTE: This only works for oooooooxxx style attention.
459
+ block_table: List[int] = []
460
+ prefix_cache_hit = any([
461
+ inter_data.prefix_cache_hit
462
+ for inter_data in self.input_builder.inter_data_list
463
+ ])
464
+ if prefix_cache_hit:
465
+ # NOTE(woosuk): For flash-attn, the block table should
466
+ # include the entries for the incoming prefill tokens.
467
+ if block_tables is not None:
468
+ block_table = block_tables[seq_id]
469
+ elif ((chunked_prefill_enabled or not is_prompt)
470
+ and block_tables is not None):
471
+ if curr_sliding_window_block == 0:
472
+ block_table = block_tables[seq_id]
473
+ else:
474
+ block_table = block_tables[seq_id][
475
+ -curr_sliding_window_block:]
476
+ self.block_tables.append(block_table)
477
+
478
+ # Compute slot mapping.
479
+ is_profile_run = is_block_tables_empty(block_tables)
480
+ start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
481
+ context_len,
482
+ self.sliding_window)
483
+ compute_slot_mapping(
484
+ is_profile_run,
485
+ self.slot_mapping,
486
+ seq_id,
487
+ seq_len,
488
+ context_len,
489
+ start_idx,
490
+ self.block_size,
491
+ inter_data.block_tables,
492
+ )
493
+
494
+ def _get_graph_runner_block_tables(
495
+ self, num_seqs: int,
496
+ block_tables: List[List[int]]) -> torch.Tensor:
497
+ # The shape of graph_block_tables is
498
+ # [max batch size, max context len // block size].
499
+
500
+ max_batch_size, max_blocks = self.runner.graph_block_tables.shape
501
+ assert max_batch_size >= num_seqs
502
+
503
+ graph_block_tables = self.runner.graph_block_tables # [:num_seqs]
504
+ for i, block_table in enumerate(block_tables):
505
+ if block_table:
506
+ num_blocks = len(block_table)
507
+ if num_blocks <= max_blocks:
508
+ graph_block_tables[i, :num_blocks] = block_table
509
+ else:
510
+ graph_block_tables[
511
+ i, :max_blocks] = block_table[:max_blocks]
512
+
513
+ return torch.from_numpy(graph_block_tables).to(
514
+ device=self.runner.device, non_blocking=True)
515
+
516
+ def build(
517
+ self,
518
+ seq_lens: List[int],
519
+ query_lens: List[int],
520
+ graph_pad_size: int,
521
+ ):
522
+ """Build attention metadata with on-device tensors.
523
+
524
+ Args:
525
+ seq_lens: The maybe padded sequence lengths of the input sequences.
526
+ query_lens: The query lengths of the input sequences.
527
+ """
528
+ for inter_data in self.input_builder.inter_data_list:
529
+ self._add_seq_group(inter_data,
530
+ self.input_builder.chunked_prefill_enabled)
531
+
532
+ device = self.runner.device
533
+ dtype = self.runner.model_config.dtype
534
+ use_npu_graph = graph_pad_size != -1
535
+
536
+ max_query_len = max(query_lens)
537
+ max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
538
+ max_decode_seq_len = max(self.curr_seq_lens, default=0)
539
+ max_seq_len = max(max_prefill_seq_len, max_decode_seq_len)
540
+ num_decode_tokens = self.num_decode_tokens
541
+
542
+ if self.num_prefills == 0 and use_npu_graph:
543
+ num_seqs = len(seq_lens)
544
+ self.slot_mapping.extend([PAD_SLOT_ID] * graph_pad_size)
545
+ self.block_tables.extend([[]] * graph_pad_size)
546
+ block_tables = self._get_graph_runner_block_tables(
547
+ num_seqs, self.block_tables)
548
+ else:
549
+ block_tables = make_tensor_with_pad(
550
+ self.block_tables,
551
+ pad=0,
552
+ dtype=torch.int32,
553
+ device=device,
554
+ )
555
+
556
+ if self.num_prefills > 0:
557
+ if block_tables is None or block_tables.numel() == 0:
558
+ # normal mask
559
+ self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
560
+ max_prefill_seq_len, dtype, device)
561
+ if is_310p():
562
+ mask_nz = nd_to_nz_2d(self.attn_mask)
563
+ mask_nz = torch_npu.npu_format_cast(
564
+ mask_nz.contiguous(), ACL_FORMAT_FRACTAL_NZ)
565
+ self.attn_mask = mask_nz
566
+ elif self.num_decode_tokens == 0 and not self.input_builder.chunked_prefill_enabled:
567
+ # compress mask for prefix cache
568
+ self.compress_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
569
+ 128, dtype, device)
570
+ else:
571
+ # chunk_mask for chunk prefill
572
+ attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore
573
+ max_seq_len, dtype, device)
574
+ if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
575
+ # Do not use in-place multiplication to avoid modifying `attn_mask_cache`!
576
+ attn_mask = attn_mask * -10000
577
+ chunk_mask_list = []
578
+ for i, seq_len in enumerate(seq_lens):
579
+ context_len = self.context_lens[i]
580
+ chunk_mask_list.append(attn_mask[context_len:seq_len])
581
+ self.chunk_mask = torch.cat(chunk_mask_list, 0)
582
+ else:
583
+ self.attn_mask = None
584
+ self.compress_mask = None
585
+ self.chunk_mask = None
586
+
587
+ assert max_query_len > 0, "query_lens: {}".format(query_lens)
588
+
589
+ assert device is not None
590
+ slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32,
591
+ device, self.runner.pin_memory)
592
+ seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
593
+ self.runner.pin_memory)
594
+ placeholder_index_maps = {
595
+ modality: placeholder_map.index_map()
596
+ for modality, placeholder_map in
597
+ self.multimodal_placeholder_maps.items()
598
+ }
599
+
600
+ return AscendMetadata(
601
+ num_prefills=self.num_prefills,
602
+ slot_mapping=slot_mapping_tensor,
603
+ num_prefill_tokens=self.num_prefill_tokens,
604
+ num_decode_tokens=num_decode_tokens,
605
+ seq_lens=seq_lens,
606
+ multi_modal_placeholder_index_maps=placeholder_index_maps,
607
+ enable_kv_scales_calculation=True,
608
+ seq_lens_tensor=seq_lens_tensor,
609
+ query_lens=query_lens,
610
+ max_query_len=max_query_len,
611
+ max_prefill_seq_len=max_prefill_seq_len,
612
+ max_decode_seq_len=max_decode_seq_len,
613
+ block_tables=block_tables,
614
+ attn_mask=self.attn_mask,
615
+ compress_mask=self.compress_mask,
616
+ chunk_mask=self.chunk_mask,
617
+ chunked_prefill_enabled=self.input_builder.chunked_prefill_enabled,
618
+ )
619
+
620
+
621
+ class AscendAttentionBackendImpl(AttentionImpl):
622
+
623
+ def __init__(
624
+ self,
625
+ num_heads: int,
626
+ head_size: int,
627
+ scale: float,
628
+ num_kv_heads: int,
629
+ alibi_slopes: Optional[List[float]],
630
+ sliding_window: Optional[int],
631
+ kv_cache_dtype: str,
632
+ blocksparse_params: Optional[Dict[str, Any]] = None,
633
+ logits_soft_cap: Optional[float] = None,
634
+ attn_type: str = AttentionType.DECODER,
635
+ kv_sharing_target_layer_name: Optional[str] = None,
636
+ use_irope: bool = False,
637
+ ) -> None:
638
+ self.num_heads = num_heads
639
+ self.head_size = head_size
640
+ self.scale = float(scale)
641
+ self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
642
+ self.hidden_size = self.num_heads * self.head_size
643
+ self.kv_cache_dtype = kv_cache_dtype
644
+ self.sliding_window = sliding_window
645
+ if alibi_slopes is not None:
646
+ alibi_slopes = torch.tensor(alibi_slopes,
647
+ dtype=torch.float32,
648
+ device="npu")
649
+ self.alibi_slopes = alibi_slopes
650
+ self.attn_type = attn_type
651
+
652
+ assert self.num_heads % self.num_kv_heads == 0
653
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
654
+ self.seq_len_cpu_tensor = None
655
+ self.query_len_cpu_tensor = None
656
+ self.key_cache = None
657
+ self.value_cache = None
658
+
659
+ def forward(
660
+ self,
661
+ layer: AttentionLayer,
662
+ query: torch.Tensor,
663
+ key: torch.Tensor,
664
+ value: torch.Tensor,
665
+ kv_cache: torch.Tensor,
666
+ attn_metadata: AscendMetadata,
667
+ attn_type: str = AttentionType.DECODER,
668
+ output: Optional[torch.Tensor] = None,
669
+ ) -> torch.Tensor:
670
+ """Forward pass with Ascend attention.
671
+ Args:
672
+ query: shape = [num_tokens, num_heads * head_size]
673
+ num_tokens = batch_size * seq_len
674
+ key: shape = [num_tokens, num_kv_heads * head_size]
675
+ value: shape = [num_tokens, num_kv_heads * head_size]
676
+ kv_cache: shape = [2, num_blocks, block_size,
677
+ num_kv_heads, head_size]
678
+ key_cache = [num_blocks, block_size,
679
+ num_kv_heads, head_size]
680
+ value_cache = [num_blocks, block_size,
681
+ num_kv_heads, head_size]
682
+ attn_metadata: Metadata for attention.
683
+ Returns:
684
+ shape = [batch_size, seq_len * num_heads * head_size]
685
+ """
686
+ assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
687
+ # View q k v to BSH.
688
+ num_tokens = query.shape[0]
689
+ query = query.view(-1, self.num_heads, self.head_size)
690
+ key = key.view(-1, self.num_kv_heads, self.head_size)
691
+ value = value.view(-1, self.num_kv_heads, self.head_size)
692
+ # TODO: Remove this contiguous in the future.
693
+ value = value.contiguous()
694
+ attn_type = self.attn_type
695
+
696
+ output = torch.empty(num_tokens,
697
+ self.num_heads,
698
+ self.head_size,
699
+ dtype=query.dtype,
700
+ device=query.device)
701
+
702
+ if kv_cache.numel() > 0:
703
+ if self.key_cache is None:
704
+ self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
705
+ slots = attn_metadata.slot_mapping
706
+
707
+ if hasattr(layer, 'quant_method'):
708
+ isPrefill = True if attn_metadata.num_prefills > 0 else False
709
+ if isPrefill:
710
+ assert attn_metadata.prefill_metadata is not None
711
+ self.seq_lens_tensor_cpu = torch.from_numpy(
712
+ np.array(attn_metadata.prefill_metadata.seq_lens).astype(
713
+ np.int32))
714
+ else:
715
+ assert attn_metadata.decode_metadata is not None
716
+ self.seq_lens_tensor_cpu = torch.from_numpy(
717
+ np.array(attn_metadata.decode_metadata.seq_lens).astype(
718
+ np.int32))
719
+ block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
720
+ # Details of kv_cache arrangement in attention quantization
721
+ # are implemented by quant_method.
722
+ layer.quant_method.apply(
723
+ layer,
724
+ query,
725
+ key,
726
+ value,
727
+ self.key_cache,
728
+ self.value_cache,
729
+ self.scale,
730
+ block_tables,
731
+ isPrefill,
732
+ attn_metadata,
733
+ output,
734
+ seq_lens_tensor_cpu=self.seq_lens_tensor_cpu)
735
+ else:
736
+ if self.key_cache is not None:
737
+ torch_npu._npu_reshape_and_cache(key=key,
738
+ value=value,
739
+ key_cache=self.key_cache,
740
+ value_cache=self.value_cache,
741
+ slot_indices=slots)
742
+
743
+ if attn_metadata.num_prefills > 0:
744
+ # Prefix cache disabled and chunk prefill disabled or no prefix cache hit
745
+ if (attn_metadata.block_tables is None
746
+ or attn_metadata.block_tables.numel() == 0):
747
+ if attn_type == AttentionType.ENCODER_ONLY:
748
+ # TODO: change to use torch_npu encoder attention op, instead
749
+ # of torch sdpa
750
+ query = query.movedim(0, query.dim() - 2)
751
+ key = key.movedim(0, key.dim() - 2)
752
+ value = value.movedim(0, value.dim() - 2)
753
+
754
+ causal_attn = (attn_type == AttentionType.DECODER)
755
+ if attn_metadata.seq_lens is not None:
756
+ seq_lens_q = seq_lens_kv = attn_metadata.seq_lens
757
+ attn_masks = [None] * len(seq_lens_q)
758
+ start_q, start_kv = 0, 0
759
+ for seq_len_q, seq_len_kv, mask in zip(
760
+ seq_lens_q, seq_lens_kv, attn_masks):
761
+ end_q = start_q + seq_len_q
762
+ end_kv = start_kv + seq_len_kv
763
+ sub_out = scaled_dot_product_attention(
764
+ query[None, :, start_q:end_q, :],
765
+ key[None, :, start_kv:end_kv, :],
766
+ value[None, :, start_kv:end_kv, :],
767
+ attn_mask=mask,
768
+ dropout_p=0.0,
769
+ is_causal=causal_attn and mask is None,
770
+ scale=self.scale).squeeze(0).movedim(
771
+ query.dim() - 2, 0)
772
+ output[start_q:end_q, :, :] = sub_out
773
+ start_q, start_kv = end_q, end_kv
774
+ else:
775
+ assert attn_metadata.attn_mask is not None
776
+ mask = attn_metadata.attn_mask
777
+ assert attn_metadata.prefill_metadata is not None
778
+ self.seq_lens_tensor_cpu = torch.from_numpy(
779
+ np.array(attn_metadata.prefill_metadata.seq_lens).
780
+ astype(np.int32))
781
+ if is_310p():
782
+ # align q k v output tensors
783
+ query = aligned_16(query)
784
+ key = aligned_16(key)
785
+ value = aligned_16(value)
786
+ output = aligned_16(output)
787
+
788
+ # do reformat in case of broadcasted tensors
789
+ mask = mask.repeat(
790
+ self.seq_lens_tensor_cpu.size(0), 1, 1, 1)
791
+ mask = torch_npu.npu_format_cast(
792
+ mask.contiguous(), ACL_FORMAT_FRACTAL_NZ)
793
+ torch_npu._npu_flash_attention(
794
+ query=query,
795
+ key=key,
796
+ value=value,
797
+ mask=mask,
798
+ seq_len=self.seq_lens_tensor_cpu,
799
+ scale_value=self.scale,
800
+ num_heads=self.num_heads,
801
+ num_kv_heads=self.num_kv_heads,
802
+ out=output)
803
+ output = output[:num_tokens, :, :]
804
+ # Prefix cache only and cache hit
805
+ elif attn_metadata.num_decode_tokens == 0 and not attn_metadata.chunked_prefill_enabled:
806
+ assert kv_cache is not None
807
+ assert attn_metadata.prefill_metadata is not None
808
+ self.seq_lens_tensor_cpu = torch.from_numpy(
809
+ np.array(
810
+ attn_metadata.prefill_metadata.seq_lens).astype(
811
+ np.int32))
812
+ self.query_lens_tensor_cpu = torch.from_numpy(
813
+ np.array(
814
+ attn_metadata.prefill_metadata.query_lens).astype(
815
+ np.int32))
816
+ block_tables = attn_metadata.prefill_metadata.block_tables
817
+ assert attn_metadata.compress_mask is not None
818
+ compress_mask = attn_metadata.compress_mask
819
+ torch_npu._npu_flash_attention_qlens(
820
+ query=query,
821
+ key_cache=self.key_cache,
822
+ value_cache=self.value_cache,
823
+ block_table=block_tables,
824
+ mask=compress_mask,
825
+ seq_len=self.query_lens_tensor_cpu,
826
+ context_lens=self.seq_lens_tensor_cpu,
827
+ num_kv_heads=self.num_kv_heads,
828
+ num_heads=self.num_heads,
829
+ scale_value=self.scale,
830
+ out=output)
831
+ # Splitfuse
832
+ else:
833
+ assert kv_cache is not None
834
+ self.seq_lens_tensor_cpu = torch.from_numpy(
835
+ np.array(attn_metadata.seq_lens).astype(np.int32))
836
+ self.query_lens_tensor_cpu = torch.from_numpy(
837
+ np.array(attn_metadata.query_lens).astype(np.int32))
838
+ block_tables = attn_metadata.block_tables
839
+ assert attn_metadata.chunk_mask is not None
840
+ chunk_mask = attn_metadata.chunk_mask
841
+ torch_npu._npu_paged_attention_splitfuse(
842
+ query=query,
843
+ key_cache=self.key_cache,
844
+ value_cache=self.value_cache,
845
+ block_table=block_tables,
846
+ context_lens=self.seq_lens_tensor_cpu,
847
+ mask=chunk_mask,
848
+ seq_len=self.query_lens_tensor_cpu,
849
+ num_kv_heads=self.num_kv_heads,
850
+ num_heads=self.num_heads,
851
+ scale_value=self.scale,
852
+ out=output)
853
+ # Decode only
854
+ else:
855
+ assert self.key_cache is not None
856
+ assert self.value_cache is not None
857
+ assert attn_metadata.decode_metadata is not None
858
+ self.seq_lens_tensor_cpu = torch.from_numpy(
859
+ np.array(attn_metadata.decode_metadata.seq_lens).astype(
860
+ np.int32))
861
+ if is_310p():
862
+ # seq_lens_tensor needs to be transferred to the device for 310P
863
+ self.seq_lens_tensor_cpu = self.seq_lens_tensor_cpu.to(
864
+ device=self.key_cache.device)
865
+ block_tables = attn_metadata.decode_metadata.block_tables
866
+ torch_npu._npu_paged_attention(
867
+ query=query,
868
+ key_cache=self.key_cache,
869
+ value_cache=self.value_cache,
870
+ num_kv_heads=self.num_kv_heads,
871
+ num_heads=self.num_heads,
872
+ scale_value=self.scale,
873
+ block_table=block_tables,
874
+ context_lens=self.seq_lens_tensor_cpu,
875
+ out=output)
876
+
877
+ return output.view(num_tokens, self.hidden_size)
878
+
879
+
880
+ class AscendMLAAttentionBackendImpl(MLAAttentionImpl):
881
+
882
+ def __init__(
883
+ self,
884
+ num_heads: int,
885
+ head_size: int,
886
+ scale: float,
887
+ num_kv_heads: int,
888
+ alibi_slopes: Optional[List[float]],
889
+ sliding_window: Optional[int],
890
+ kv_cache_dtype: str,
891
+ blocksparse_params: Optional[Dict[str, Any]] = None,
892
+ logits_soft_cap: Optional[float] = None,
893
+ attn_type: str = AttentionType.DECODER,
894
+ kv_sharing_target_layer_name: Optional[str] = None,
895
+ **extra_impl_args,
896
+ ) -> None:
897
+ self.num_heads = num_heads
898
+ self.head_size = head_size
899
+ self.scale = float(scale)
900
+ self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
901
+ self.hidden_size = self.num_heads * self.head_size
902
+ self.kv_cache_dtype = kv_cache_dtype
903
+ self.sliding_window = sliding_window
904
+ if alibi_slopes is not None:
905
+ alibi_slopes = torch.tensor(alibi_slopes,
906
+ dtype=torch.float32,
907
+ device="npu")
908
+ self.alibi_slopes = alibi_slopes
909
+ self.attn_type = attn_type
910
+
911
+ assert self.num_heads % self.num_kv_heads == 0
912
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
913
+ self.seq_len_cpu_tensor = None
914
+
915
+ # MLA Args
916
+ self.q_lora_rank = extra_impl_args['q_lora_rank']
917
+ self.kv_lora_rank = extra_impl_args['kv_lora_rank']
918
+ self.qk_nope_head_dim = extra_impl_args['qk_nope_head_dim']
919
+ self.qk_rope_head_dim = extra_impl_args['qk_rope_head_dim']
920
+ self.qk_head_dim = extra_impl_args['qk_head_dim']
921
+ self.v_head_dim = extra_impl_args['v_head_dim']
922
+ self.rotary_emb = extra_impl_args['rotary_emb']
923
+ self.q_proj = extra_impl_args['q_proj']
924
+ self.kv_b_proj = extra_impl_args['kv_b_proj']
925
+ self.o_proj = extra_impl_args['o_proj']
926
+ self.kv_a_proj_with_mqa = extra_impl_args.get('kv_a_proj_with_mqa',
927
+ None)
928
+ self.kv_a_layernorm = extra_impl_args.get('kv_a_layernorm', None)
929
+ self.k_pe_cache = None
930
+ self.k_nope_cache = None
931
+ self.w_kc = None
932
+ self.w_vc = None
933
+
934
+ ascend_config = get_ascend_config()
935
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
936
+
937
+
938
+ def exec_kv(
939
+ self,
940
+ hidden_states: torch.Tensor,
941
+ cos: torch.Tensor,
942
+ sin: torch.Tensor,
943
+ kv_cache: Tuple,
944
+ slots: torch.Tensor,
945
+ ):
946
+ B = hidden_states.shape[0]
947
+ N = self.num_kv_heads
948
+ S = 1
949
+ kv = self.kv_a_proj_with_mqa(hidden_states)[0]
950
+ # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
951
+ kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
952
+
953
+ k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
954
+ kv,
955
+ self.kv_a_layernorm.weight,
956
+ cos,
957
+ sin,
958
+ slots.to(torch.int64),
959
+ kv_cache[1],
960
+ kv_cache[0],
961
+ epsilon=self.kv_a_layernorm.variance_epsilon,
962
+ cache_mode="PA",
963
+ )
964
+
965
+ return k_pe, k_nope
966
+
967
+ def apply_rotary_emb(
968
+ self,
969
+ x: torch.Tensor,
970
+ cos: torch.Tensor,
971
+ sin: torch.Tensor,
972
+ is_neox_style: bool,
973
+ ) -> torch.Tensor:
974
+ """
975
+ Args:
976
+ x: [num_tokens, num_heads, head_size]
977
+ cos: [num_tokens, head_size // 2]
978
+ sin: [num_tokens, head_size // 2]
979
+ is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
980
+ positional embeddings.
981
+ """
982
+ cos = cos.unsqueeze(-2).to(x.dtype)
983
+ sin = sin.unsqueeze(-2).to(x.dtype)
984
+ if is_neox_style:
985
+ x1, x2 = torch.chunk(x, 2, dim=-1)
986
+ else:
987
+ x1 = x[..., ::2]
988
+ x2 = x[..., 1::2]
989
+ o1 = x1 * cos - x2 * sin
990
+ o2 = x2 * cos + x1 * sin
991
+ if is_neox_style:
992
+ return torch.cat((o1, o2), dim=-1)
993
+ else:
994
+ return torch.stack((o1, o2), dim=-1).flatten(-2)
995
+
996
+ def rope_single(
997
+ self,
998
+ x: torch.Tensor,
999
+ cos: torch.Tensor,
1000
+ sin: torch.Tensor,
1001
+ ) -> torch.Tensor:
1002
+ B, N, D = x.shape
1003
+ S = 1
1004
+ x = x.view(B, N, S, D)
1005
+ x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
1006
+ return x.view(B, N, D)
1007
+
1008
+ def process_weights_after_loading(self, act_dtype: torch.dtype):
1009
+ if self.w_kc is None or self.w_vc is None:
1010
+ kv_b_proj_weight = self.kv_b_proj.weight.reshape(
1011
+ self.num_heads, self.qk_nope_head_dim + self.v_head_dim,
1012
+ self.kv_lora_rank)
1013
+ self.w_kc = kv_b_proj_weight[:, :self.
1014
+ qk_nope_head_dim, :].contiguous()
1015
+ self.w_vc = kv_b_proj_weight[:,
1016
+ self.qk_nope_head_dim:, :].transpose(
1017
+ 1, 2).contiguous()
1018
+
1019
+ def forward(
1020
+ self,
1021
+ layer: AttentionLayer,
1022
+ hidden_states_or_q_c: torch.Tensor,
1023
+ hidden_states_or_kv_c_normed: torch.Tensor,
1024
+ k_pe: torch.Tensor,
1025
+ kv_cache: torch.Tensor,
1026
+ attn_metadata: AscendMetadata,
1027
+ attn_type: str = AttentionType.DECODER,
1028
+ output: Optional[torch.Tensor] = None,
1029
+ ) -> torch.Tensor:
1030
+ """Forward pass with Ascend attention.
1031
+ Args:
1032
+ hidden_states_or_q_c: shape = [num_tokens, num_heads * head_size]
1033
+ num_tokens = batch_size * seq_len
1034
+ hidden_states_or_kv_c_normed: shape = [num_tokens, num_kv_heads * head_size]
1035
+ k_pe: shape = [num_tokens, num_kv_heads * head_size]
1036
+ kv_cache: shape = [1, num_blocks, block_size,
1037
+ num_kv_heads * head_size]
1038
+ attn_metadata: Metadata for attention.
1039
+ Returns:
1040
+ shape = [batch_size, seq_len * num_heads * head_size]
1041
+ """
1042
+ assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
1043
+ attn_type = self.attn_type
1044
+ if attn_type != AttentionType.DECODER:
1045
+ raise NotImplementedError("Encoder self-attention and "
1046
+ "encoder/decoder cross-attention "
1047
+ "are not implemented for "
1048
+ "AscendMLAAttentionBackendImpl")
1049
+
1050
+ if attn_metadata is None:
1051
+ # for profile run
1052
+ return hidden_states_or_q_c
1053
+
1054
+ num_tokens = hidden_states_or_q_c.shape[0]
1055
+ q = self.q_proj(hidden_states_or_q_c)[0].view(-1, self.num_heads,
1056
+ self.qk_head_dim)
1057
+ q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
1058
+ dim=-1)
1059
+ if k_pe is None and attn_metadata.decode_metadata:
1060
+ seq_len = self.rotary_emb.max_position_embeddings
1061
+
1062
+ cos = self.rotary_emb.cos_cached[:seq_len].to(dtype=q_pe.dtype)
1063
+ sin = self.rotary_emb.sin_cached[:seq_len].to(dtype=q_pe.dtype)
1064
+ cos = cos[attn_metadata.input_positions]
1065
+ sin = sin[attn_metadata.input_positions]
1066
+ cos = cos[:, None, None, :]
1067
+ sin = sin[:, None, None, :]
1068
+
1069
+ q_pe = self.rope_single(q_pe, cos, sin)
1070
+ k_pe, k_nope = self.exec_kv(hidden_states_or_kv_c_normed, cos, sin,
1071
+ kv_cache, attn_metadata.slot_mapping)
1072
+ else:
1073
+ if k_pe is None:
1074
+ # NOTE: k_pe is None when graph mode enabled
1075
+ kv_c, k_pe = self.kv_a_proj_with_mqa(
1076
+ hidden_states_or_kv_c_normed)[0].split(
1077
+ [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
1078
+ kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
1079
+ else:
1080
+ kv_c_normed = hidden_states_or_kv_c_normed
1081
+ k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
1082
+ if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
1083
+ # NOTE: Used when no rope scaling is specified
1084
+ ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
1085
+ q_pe = q_pe.reshape(num_tokens, -1)
1086
+ k_pe = k_pe.reshape(num_tokens, -1)
1087
+ q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
1088
+ q_pe, k_pe)
1089
+ q_pe = q_pe.view(ori_q_pe_shape)
1090
+ k_pe = k_pe.view(ori_k_pe_shape)
1091
+ else:
1092
+ q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions,
1093
+ q_pe, k_pe)
1094
+
1095
+ if attn_metadata.num_prefills > 0:
1096
+ kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens,
1097
+ self.num_heads, -1)
1098
+ k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim],
1099
+ dim=-1)
1100
+ else:
1101
+ q_nope_t = torch.transpose(q_nope, 0, 1)
1102
+ q_nope_out = torch.bmm(q_nope_t, self.w_kc)
1103
+ q_nope = torch.transpose(q_nope_out, 0, 1)
1104
+
1105
+ query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens,
1106
+ self.num_heads, -1)
1107
+
1108
+ # TODO: Replace the env with more flexible expressions
1109
+ if self.torchair_graph_enabled:
1110
+ if len(kv_cache) > 0 and kv_cache[0].numel(
1111
+ ) > 0 and attn_metadata.num_prefills > 0:
1112
+ slots = attn_metadata.slot_mapping
1113
+ # NOTE: Separate the kv cache in advance to avoid OOM or other issues
1114
+ torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
1115
+ num_tokens, self.num_kv_heads, -1),
1116
+ value=k_pe,
1117
+ key_cache=kv_cache[0],
1118
+ value_cache=kv_cache[1],
1119
+ slot_indices=slots)
1120
+ elif kv_cache.numel() > 0:
1121
+ # TODO: replace this naive implementation with a fusion kernel
1122
+ concat_and_cache_mla(kv_c_normed, k_pe, kv_cache,
1123
+ attn_metadata.slot_mapping)
1124
+
1125
+ if attn_metadata.num_prefills > 0:
1126
+ attn_output = torch.empty(num_tokens,
1127
+ self.num_heads,
1128
+ self.v_head_dim,
1129
+ dtype=query.dtype,
1130
+ device=query.device)
1131
+ if (attn_metadata.block_tables is None
1132
+ or attn_metadata.block_tables.numel() == 0):
1133
+ assert attn_metadata.attn_mask is not None
1134
+ assert attn_metadata.prefill_metadata is not None
1135
+ assert attn_metadata.prefill_metadata.seq_lens is not None
1136
+ mask = attn_metadata.attn_mask
1137
+ self.seq_lens_tensor_cpu = torch.from_numpy(
1138
+ np.array(attn_metadata.prefill_metadata.seq_lens).astype(
1139
+ np.int32))
1140
+ k_pe = k_pe.repeat(1, self.num_heads, 1)
1141
+ key = torch.cat(
1142
+ [k_nope.view(num_tokens, self.num_heads, -1), k_pe], dim=2)
1143
+ torch_npu._npu_flash_attention(
1144
+ query=query,
1145
+ key=key,
1146
+ value=value,
1147
+ mask=mask,
1148
+ seq_len=self.seq_lens_tensor_cpu,
1149
+ scale_value=self.scale,
1150
+ num_heads=self.num_heads,
1151
+ num_kv_heads=self.num_heads,
1152
+ out=attn_output)
1153
+ else:
1154
+ # TODO: Will support prefix cache and chunked prefill soon.
1155
+ raise RuntimeError(
1156
+ "Prefix cache and chunked prefill are currently not supported."
1157
+ )
1158
+ elif attn_metadata.decode_metadata:
1159
+ assert kv_cache is not None
1160
+ if self.torchair_graph_enabled:
1161
+ # shape of query for npu graph mode should be:
1162
+ # [bs, num_heads_per_rank, seq_len, dim]
1163
+ q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
1164
+ q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
1165
+ # shape of knope/k_pe for npu graph mode should be:
1166
+ # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
1167
+ block_size = kv_cache[0].shape[1]
1168
+ k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
1169
+ self.kv_lora_rank)
1170
+ k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
1171
+ self.qk_rope_head_dim)
1172
+ attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
1173
+ q_nope,
1174
+ k_nope,
1175
+ k_nope,
1176
+ query_rope=q_pe,
1177
+ key_rope=k_pe,
1178
+ num_heads=self.num_heads,
1179
+ num_key_value_heads=self.num_kv_heads,
1180
+ input_layout="BNSD",
1181
+ atten_mask=attn_metadata.attn_mask,
1182
+ scale=self.scale,
1183
+ antiquant_mode=0,
1184
+ antiquant_scale=None,
1185
+ block_table=attn_metadata.block_tables,
1186
+ block_size=block_size,
1187
+ actual_seq_lengths_kv=attn_metadata.seq_lens,
1188
+ )
1189
+ attn_output = attn_output.view(num_tokens, -1,
1190
+ self.kv_lora_rank).transpose(
1191
+ 0, 1)
1192
+ attn_output = torch.bmm(attn_output, self.w_vc).transpose(0, 1)
1193
+ else:
1194
+ # if torch.empty is used here, the preemptive scheduling case of
1195
+ # test_mtp_correctness.py will fail to run.
1196
+ attn_output = torch.randn(
1197
+ [num_tokens, self.num_heads, self.kv_lora_rank],
1198
+ dtype=query.dtype,
1199
+ device=query.device)
1200
+ self.seq_lens_tensor_cpu = torch.from_numpy(
1201
+ np.array(attn_metadata.decode_metadata.seq_lens).astype(
1202
+ np.int32))
1203
+ block_tables = attn_metadata.decode_metadata.block_tables
1204
+ torch_npu._npu_paged_attention_mla(
1205
+ query=query,
1206
+ key_cache=kv_cache,
1207
+ num_kv_heads=self.num_kv_heads,
1208
+ num_heads=self.num_heads,
1209
+ scale_value=self.scale,
1210
+ block_table=block_tables,
1211
+ context_lens=self.seq_lens_tensor_cpu,
1212
+ mla_vheadsize=self.kv_lora_rank,
1213
+ out=attn_output)
1214
+ attn_output_t = torch.transpose(attn_output, 0, 1)
1215
+ attn_output_t = torch.bmm(attn_output_t, self.w_vc)
1216
+ attn_output = torch.transpose(attn_output_t, 0, 1)
1217
+
1218
+ output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))
1219
+
1220
+ return output
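The decode branch above relies on the MLA weight-absorption trick prepared in process_weights_after_loading: the two halves of kv_b_proj (w_kc and w_vc) are folded into the query and output sides so attention can run directly on the cached latent KV. A minimal standalone sketch of that idea, with toy dimensions and random tensors standing in for the real projections and attention output (illustration only, not the vllm-ascend API):

import torch

num_heads, qk_nope_dim, v_dim, kv_lora_rank, num_tokens = 4, 16, 16, 32, 3
w_kc = torch.randn(num_heads, qk_nope_dim, kv_lora_rank)  # query-side absorbed weight
w_vc = torch.randn(num_heads, kv_lora_rank, v_dim)        # output-side absorbed weight

# Fold the key up-projection into the query: (N, T, P) x (N, P, L) -> (N, T, L)
q_nope = torch.randn(num_tokens, num_heads, qk_nope_dim)
q_latent = torch.bmm(q_nope.transpose(0, 1), w_kc).transpose(0, 1)

# Attention then runs directly against the cached latent KV (kv_lora_rank wide);
# a random tensor stands in for that attention result here.
attn_latent = torch.randn(num_tokens, num_heads, kv_lora_rank)

# Fold the value up-projection into the output: (N, T, L) x (N, L, V) -> (N, T, V)
out = torch.bmm(attn_latent.transpose(0, 1), w_vc).transpose(0, 1)
print(out.shape)  # torch.Size([3, 4, 16])

The payoff is that the paged KV cache only ever stores the narrow latent vectors, and the per-head up-projections are applied once per query/output rather than per cached token.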
inference/vllm_ascend/attention/mla_v1.py ADDED
@@ -0,0 +1,1224 @@
 
1
+ from dataclasses import dataclass
2
+ from typing import TYPE_CHECKING, Any, Optional, Tuple, Type, TypeVar
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch_npu
7
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
8
+ AttentionMetadata,
9
+ MLAAttentionImpl)
10
+ from vllm.attention.backends.utils import PAD_SLOT_ID
11
+ from vllm.config import get_current_vllm_config
12
+ from vllm.distributed import get_tensor_model_parallel_world_size
13
+ from vllm.model_executor.layers.linear import (LinearBase,
14
+ UnquantizedLinearMethod)
15
+ from vllm.utils import cdiv, round_down
16
+
17
+ from vllm_ascend.ascend_config import get_ascend_config
18
+ from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
19
+ from vllm_ascend.attention.attention_v1 import AscendAttentionState
20
+ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
21
+ from vllm_ascend.multistream.context import get_multistream_comm_context
22
+ from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
23
+ from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
24
+ from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
25
+ from vllm_ascend.worker.npu_input_batch import InputBatch
26
+
27
+ if TYPE_CHECKING:
28
+ from vllm.v1.core.sched.output import SchedulerOutput
29
+
30
+
31
+ @dataclass
32
+ class CommonAttentionMetadata:
33
+ """
34
+ Attention metadata attributes that can be shared by layers in different KV
35
+ cache groups and thus may have different block tables.
36
+ """
37
+
38
+ query_start_loc: torch.Tensor
39
+ """(batch_size + 1,), the start location of each request in query Tensor"""
40
+ seq_lens: torch.Tensor
41
+ """(batch_size,), the length of each request including both computed tokens
42
+ and newly scheduled tokens"""
43
+
44
+
45
+ class AscendMLABackend(AttentionBackend):
46
+
47
+ accept_output_buffer: bool = True
48
+
49
+ @staticmethod
50
+ def get_name() -> str:
51
+ return "VLLM_ASCEND_MLA"
52
+
53
+ @staticmethod
54
+ def get_metadata_cls() -> type["AttentionMetadata"]:
55
+ return AscendMLAMetadata
56
+
57
+ @staticmethod
58
+ def get_builder_cls():
59
+ return AscendMLAMetadataBuilder
60
+
61
+ @staticmethod
62
+ def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
63
+ head_size: int) -> tuple[int, ...]:
64
+ return (num_blocks, block_size, num_kv_heads, head_size)
65
+
66
+ @staticmethod
67
+ def get_impl_cls() -> Type["MLAAttentionImpl"]:
68
+ return AscendMLAImpl
69
+
70
+
71
+ @dataclass
72
+ class AscendMLAPrefillMetadata:
73
+ """ Prefill Specific Metadata for Ascend"""
74
+
75
+ @dataclass
76
+ class ChunkedContextMetadata:
77
+ # New for MLA (compared to FlashAttention)
78
+ # For handling chunked prefill
79
+ cu_seq_lens: torch.Tensor
80
+ starts: torch.Tensor
81
+ seq_tot: list[int]
82
+ max_seq_lens: list[int]
83
+ workspace: torch.Tensor
84
+ chunk_seq_lens: torch.Tensor
85
+
86
+ attn_mask: torch.Tensor
87
+ query_lens: list[int]
88
+ seq_lens: list[int]
89
+ context_lens: torch.Tensor
90
+ input_positions: torch.Tensor
91
+ query_start_loc: torch.Tensor
92
+ block_table: torch.Tensor
93
+ max_query_len: int
94
+ max_seq_lens: int
95
+ chunked_context: Optional[ChunkedContextMetadata] = None
96
+
97
+
98
+ @dataclass
99
+ class AscendMLADecodeMetadata:
100
+ # Input positions for rotary embeddings since for MLA the rotary
101
+ # position embeddings are applied inside the attention backend
102
+ input_positions: torch.Tensor
103
+ block_table: torch.Tensor
104
+ seq_lens: torch.Tensor
105
+ max_seq_lens: int
106
+ seq_lens_list: list[int]
107
+ attn_mask: Optional[torch.Tensor] = None
108
+
109
+
110
+ @dataclass
111
+ class AscendMLAMetadata:
112
+ """Metadata for MLACommon.
113
+
114
+ NOTE: Please read the comment at the top of the file before trying to
115
+ understand this class
116
+ """
117
+ # NOTE(sang): Definition of context_len, query_len, and seq_len.
118
+ # |---------- N-1 iteration --------|
119
+ # |---------------- N iteration ---------------------|
120
+ # |- tokenA -|......................|-- newTokens ---|
121
+ # |---------- context_len ----------|
122
+ # |-------------------- seq_len ---------------------|
123
+ # |-- query_len ---|
124
+
125
+ num_actual_tokens: int # Number of tokens excluding padding.
126
+ slot_mapping: torch.Tensor
127
+ query_start_loc: torch.Tensor
128
+ seq_lens: torch.Tensor
129
+ block_tables: torch.Tensor
130
+
131
+ # New for MLA (compared to FlashAttention)
132
+ # For handling prefill decode split
133
+ num_decodes: int
134
+ num_decode_tokens: int
135
+ num_prefills: int
136
+
137
+ # For logging.
138
+ num_input_tokens: int = 0 # Number of tokens including padding.
139
+
140
+ max_num_tokens_across_dp: int = 0
141
+ with_prefill_across_dp: bool = False
142
+
143
+ query_lens: Optional[list[int]] = None
144
+ # The dimension of the attention heads
145
+ head_dim: Optional[int] = None
146
+ attn_mask: torch.Tensor = None
147
+ # chunked prefill by default if no attn_states passed
148
+ attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
149
+
150
+ decode: Optional[AscendMLADecodeMetadata] = None
151
+ prefill: Optional[AscendMLAPrefillMetadata] = None
152
+
153
+ def __post_init__(self):
154
+ pass
155
+ # supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
156
+ # if self.head_dim is not None and self.head_dim \
157
+ # not in supported_head_sizes:
158
+ # raise ValueError(
159
+ # f"Only {supported_head_sizes} are supported for head_dim,",
160
+ # f"received {self.head_dim}.")
161
+
162
+ def split_metadata_for_multistream(
163
+ self,
164
+ ms_split_config: MSAttentionMetadataSplitConfig,
165
+ ) -> list["AscendMLAMetadata"]:
166
+ """Split metadata for multi-stream with AscendMLAMetadata"""
167
+ return model_input_split_v1_mla_attn(
168
+ ms_split_config=ms_split_config,
169
+ attn_metadata=self,
170
+ _metadata_cls=AscendMLAMetadata,
171
+ )
172
+
173
+
174
+ M = TypeVar("M", bound=AscendMLAMetadata)
175
+
176
+
177
+ class AscendMLAMetadataBuilder:
178
+ """
179
+ NOTE: Please read the comment at the top of the file before trying to
180
+ understand this class
181
+ """
182
+
183
+ # _attn_mask_builder = None
184
+ def __init__(self,
185
+ runner,
186
+ metadata_cls: Optional[AscendMLAMetadata] = None):
187
+ self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
188
+ if metadata_cls is not None else AscendMLAMetadata # type: ignore
189
+ self.runner = runner
190
+ scheduler_config = runner.scheduler_config
191
+ model_config = runner.model_config
192
+ self.block_size = runner.block_size
193
+ self.chunked_prefill_enabled = runner.chunked_prefill_enabled
194
+ if self.chunked_prefill_enabled:
195
+ self.chunked_prefill_workspace_size = min(
196
+ # Make sure there is enough for 8 full-length requests or at least
197
+ # 4 pages of cache per request
198
+ max(8 * model_config.max_model_len,
199
+ 4 * scheduler_config.max_num_seqs * self.block_size),
200
+ # For long-context models try not to over-allocate limiting
201
+ # kv-cache space, limiting it to 64k tokens,
202
+ # which would result in the workspace being:
203
+ # 2*(576)*(64*1024) = 144mb
204
+ # (assuming 576 MLA head dim, and fp16)
205
+ # which would result in up-projected context being
206
+ # 2*(192*128)*(64*1024) = 3gb
207
+ # (assuming 192 QK head dim, 128 heads, and fp16)
208
+ 128 * 1024)
209
+ assert self.chunked_prefill_workspace_size >= \
210
+ scheduler_config.max_num_seqs * self.block_size
211
+ self.chunked_prefill_workspace = torch.empty(
212
+ (self.chunked_prefill_workspace_size,
213
+ model_config.get_head_size()),
214
+ dtype=model_config.dtype,
215
+ device=runner.device,
216
+ )
217
+ ascend_config = get_ascend_config()
218
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
219
+
220
+ def reorder_batch(self, input_batch: "InputBatch",
221
+ scheduler_output: "SchedulerOutput") -> bool:
222
+ # We now want to reorder the batch so that the "decode" requests are at
223
+ # the front and the "prefill" requests are at the back, using the least
224
+ # amount of swaps possible. (NOTE for now we loosely use "decode" to mean requests
225
+ # where attention is likely memory-bound and "prefill" to mean requests
226
+ # where attention is likely compute-bound, TODO(lucas): figure out a
227
+ # better naming here)
228
+ decodes = []
229
+ prefills = []
230
+ num_decode_tokens = 0
231
+ num_prefill_tokens = 0
232
+
233
+ for i, req_id in enumerate(input_batch.req_ids):
234
+ num_tokens = scheduler_output.num_scheduled_tokens[req_id]
235
+ num_spec_tokens = len(
236
+ scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
237
+ # For torch air graph mode we treat spec decoding as decode.
238
+ if self.torchair_graph_enabled:
239
+ if num_tokens - num_spec_tokens == 1:
240
+ decodes.append(i)
241
+ num_decode_tokens += num_tokens
242
+ else:
243
+ prefills.append(i)
244
+ num_prefill_tokens += num_tokens
245
+ # For eager mode we treat spec decoding as chunked prefill.
246
+ else:
247
+ if num_tokens == 1:
248
+ decodes.append(i)
249
+ num_decode_tokens += num_tokens
250
+ else:
251
+ prefills.append(i)
252
+ num_prefill_tokens += num_tokens
253
+
254
+ # We hope that this is fairly minimal since decodes
255
+ # should be around for a number of iterations so hopefully they are
256
+ # relatively stationary (and new requests are generally appended to the
257
+ # persistent batch so already should be at the back)
258
+ # To achieve this we loop over the decodes in descending order and
259
+ # the prefills in ascending order. We swap decodes from the "back"
260
+ # i.e. past where the last decode should be in the reordered batch, with
261
+ # prefills from the front of the batch.
262
+ # `decodes` and `prefills` are already in ascending order just based on
263
+ # the above loop
264
+ num_decodes = len(decodes)
265
+ num_prefills = len(prefills)
266
+ first_prefill = 0
267
+ modified_batch = False
268
+
269
+ for i in range(1, min(num_decodes, num_prefills) + 1):
270
+ # If the decode is at the "back" of the batch, i, we can swap it
271
+ # with the prefill closest to the front of the batch
272
+ if decodes[num_decodes - i] >= num_decodes:
273
+ input_batch.swap_states(prefills[first_prefill],
274
+ decodes[num_decodes - i])
275
+ first_prefill += 1
276
+ modified_batch = True
277
+ else:
278
+ break
279
+
280
+ # Save for next `build` call
281
+ # TODO(lucas): this is a bit of a hack, we should probably have a
282
+ # better way of doing this
283
+ self._num_decodes = num_decodes
284
+ self._num_prefills = num_prefills
285
+ self._num_decode_tokens = num_decode_tokens
286
+ self._num_prefill_tokens = num_prefill_tokens
287
+
288
+ return modified_batch
289
+
290
+ def _get_graph_runner_block_tables(
291
+ self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
292
+
293
+ max_batch_size, max_blocks = self.runner.graph_block_tables.shape
294
+ assert max_batch_size >= num_seqs
295
+
296
+ if isinstance(self.runner.graph_block_tables, np.ndarray):
297
+ graph_block_tables = torch.zeros((max_batch_size, max_blocks),
298
+ dtype=block_tables.dtype,
299
+ device=block_tables.device)
300
+ else:
301
+ graph_block_tables = self.runner.graph_block_tables.to(
302
+ device=block_tables.device, dtype=block_tables.dtype)
303
+
304
+ num_blocks = block_tables.size(1)
305
+ if num_blocks <= max_blocks:
306
+ graph_block_tables[:num_seqs, :
307
+ num_blocks] = block_tables[:num_seqs, :
308
+ num_blocks]
309
+ else:
310
+ graph_block_tables[:num_seqs, :
311
+ max_blocks] = block_tables[:num_seqs, :
312
+ max_blocks]
313
+
314
+ return graph_block_tables[:num_seqs, :max_blocks]
315
+
316
+ def build_dummy(self, num_reqs: int,
317
+ num_actual_tokens: int) -> AscendMLAMetadata:
318
+ device = self.runner.device
319
+ _, max_blocks = self.runner.graph_block_tables.shape
320
+ block_table = torch.zeros((num_reqs, max_blocks),
321
+ dtype=torch.int32,
322
+ device=device)
323
+ block_table = self._get_graph_runner_block_tables(
324
+ num_reqs, block_table)
325
+ seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
326
+ input_positions = torch.zeros(num_reqs,
327
+ dtype=torch.int32,
328
+ device=device).long()
329
+ slot_mapping = torch.full((num_reqs, ),
330
+ PAD_SLOT_ID,
331
+ dtype=torch.int32,
332
+ device=device)
333
+ query_start_loc = torch.full((num_reqs, ),
334
+ -1,
335
+ dtype=torch.int32,
336
+ device=device)
337
+ decode_metadata = AscendMLADecodeMetadata(
338
+ input_positions=input_positions,
339
+ block_table=block_table,
340
+ seq_lens=seq_lens,
341
+ seq_lens_list=seq_lens.tolist(),
342
+ max_seq_lens=1,
343
+ attn_mask=self.runner.spec_attn_mask)
344
+ return self.metadata_cls( # type: ignore
345
+ num_input_tokens=num_actual_tokens,
346
+ num_actual_tokens=num_actual_tokens,
347
+ slot_mapping=slot_mapping,
348
+ head_dim=self.runner.model_config.get_head_size(),
349
+ num_decodes=1,
350
+ num_decode_tokens=1,
351
+ num_prefills=0,
352
+ attn_mask=self.runner.attn_mask,
353
+ attn_state=AscendAttentionState.DecodeOnly,
354
+ prefill=None,
355
+ decode=decode_metadata,
356
+ query_start_loc=query_start_loc,
357
+ seq_lens=seq_lens,
358
+ block_tables=block_table,
359
+ )
360
+
361
+ def build(
362
+ self,
363
+ num_reqs: int,
364
+ num_actual_tokens: int,
365
+ max_query_len: int,
366
+ common_attn_metadata: CommonAttentionMetadata,
367
+ common_prefix_len: Optional[int] = None,
368
+ graph_pad_size: int = -1,
369
+ max_num_tokens_across_dp: int = 0,
370
+ with_prefill_across_dp: bool = False,
371
+ ) -> AscendMLAMetadata:
372
+ assert self._num_decodes + self._num_prefills == num_reqs
373
+
374
+ # Note(simon): be careful about the CPU <> GPU memory movement in this
375
+ # function. We should avoid GPU -> CPU sync as much as possible because
376
+ # it blocks on all previous kernels.
377
+ device = self.runner.device
378
+
379
+ block_table = (self.runner.input_batch.block_table[0].
380
+ get_device_tensor()[:num_reqs])
381
+ slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
382
+ device, non_blocking=True)
383
+ input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
384
+ device, non_blocking=True).long()
385
+
386
+ seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
387
+ query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
388
+ num_reqs]
389
+ seq_lens = seq_lens_cpu
390
+ max_query_len = query_lens.max().item()
391
+ max_seq_lens = seq_lens.max().item()
392
+ query_start_loc = common_attn_metadata.query_start_loc
393
+
394
+ prefill_metadata = None
395
+ chunked_context_metadata = None
396
+ if self._num_prefills > 0:
397
+ reqs_start = self._num_decodes # prefill_start
398
+ tokens_start = self._num_decode_tokens
399
+ max_query_len = query_lens[tokens_start:].max().item()
400
+ max_seq_lens = seq_lens[tokens_start:].max().item()
401
+ query_start_loc = common_attn_metadata.query_start_loc
402
+ prefill_query_start_loc = query_start_loc[
403
+ reqs_start:] - query_start_loc[reqs_start]
404
+
405
+ context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
406
+ reqs_start:num_reqs]
407
+ max_context_len_cpu = context_lens_cpu.max().item()
408
+ num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
409
+ if self.chunked_prefill_enabled and max_context_len_cpu > 0:
410
+ max_context_chunk = (self.chunked_prefill_workspace_size //
411
+ num_prefills_with_context_cpu)
412
+ max_context_chunk = round_down(max_context_chunk,
413
+ self.block_size)
414
+
415
+ assert max_context_chunk > 0
416
+ num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
417
+ chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
418
+ .unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
419
+ chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
420
+ chunk_starts + max_context_chunk)
421
+ chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
422
+ cu_seq_lens_cpu = torch.zeros(num_chunks,
423
+ self._num_prefills + 1,
424
+ dtype=torch.int32,
425
+ pin_memory=True)
426
+ torch.cumsum(chunk_seq_lens,
427
+ dim=1,
428
+ out=cu_seq_lens_cpu[:, 1:],
429
+ dtype=torch.int32)
430
+ chunked_context_metadata = \
431
+ AscendMLAPrefillMetadata.ChunkedContextMetadata(
432
+ cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
433
+ starts=chunk_starts.to(device, non_blocking=True),
434
+ seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
435
+ max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
436
+ chunk_seq_lens=chunk_seq_lens,
437
+ workspace=self.chunked_prefill_workspace,
438
+ )
439
+
440
+ prefill_metadata = AscendMLAPrefillMetadata(
441
+ attn_mask=self.runner.attn_mask,
442
+ query_lens=query_lens[tokens_start:],
443
+ seq_lens=seq_lens,
444
+ context_lens=seq_lens[tokens_start:],
445
+ input_positions=input_positions[tokens_start:],
446
+ block_table=block_table[reqs_start:, ...],
447
+ max_query_len=max_query_len,
448
+ max_seq_lens=max_seq_lens,
449
+ query_start_loc=prefill_query_start_loc,
450
+ chunked_context=chunked_context_metadata,
451
+ )
452
+
453
+ decode_metadata = None
454
+ use_torchair_graph = graph_pad_size != -1
455
+ if self._num_decodes > 0:
456
+ max_seq_lens = seq_lens[:self._num_decodes].max().item()
457
+ seq_lens = seq_lens[:self._num_decode_tokens]
458
+ input_positions = input_positions[:self._num_decode_tokens]
459
+ block_table = block_table[:self._num_decode_tokens, ...]
460
+ if use_torchair_graph and self.runner.attn_state in [
461
+ AscendAttentionState.DecodeOnly,
462
+ AscendAttentionState.SpecDecoding
463
+ ]:
464
+ num_seqs = len(seq_lens)
465
+ if graph_pad_size != 0:
466
+ pad_value = 1
467
+ padded_seq_lens = seq_lens.tolist() + [pad_value
468
+ ] * graph_pad_size
469
+ else:
470
+ padded_seq_lens = seq_lens.tolist()
471
+
472
+ seq_lens = torch.from_numpy(
473
+ np.array(padded_seq_lens).astype(np.int32))
474
+ padding = torch.full((graph_pad_size, ),
475
+ PAD_SLOT_ID,
476
+ dtype=slot_mapping.dtype,
477
+ device=slot_mapping.device)
478
+ slot_mapping = torch.cat([slot_mapping, padding])
479
+ block_table_padding = torch.zeros(
480
+ (graph_pad_size, ) + block_table.shape[1:],
481
+ dtype=block_table.dtype,
482
+ device=block_table.device)
483
+ block_table = torch.cat([block_table, block_table_padding],
484
+ dim=0)
485
+ block_table = self._get_graph_runner_block_tables(
486
+ num_seqs + graph_pad_size, block_table)
487
+ padding_0 = torch.zeros(graph_pad_size,
488
+ dtype=input_positions.dtype,
489
+ device=input_positions.device)
490
+ input_positions = torch.cat([input_positions, padding_0])
491
+
492
+ decode_metadata = AscendMLADecodeMetadata(
493
+ input_positions=input_positions,
494
+ block_table=block_table,
495
+ seq_lens=seq_lens,
496
+ seq_lens_list=seq_lens.tolist(),
497
+ max_seq_lens=max_seq_lens,
498
+ attn_mask=self.runner.spec_attn_mask)
499
+
500
+ return self.metadata_cls( # type: ignore
501
+ num_actual_tokens=num_actual_tokens,
502
+ query_lens=query_lens.tolist(),
503
+ slot_mapping=slot_mapping,
504
+ head_dim=self.runner.model_config.get_head_size(),
505
+ num_decodes=self._num_decodes,
506
+ num_decode_tokens=self._num_decode_tokens,
507
+ num_prefills=self._num_prefills,
508
+ attn_mask=self.runner.attn_mask,
509
+ attn_state=self.runner.attn_state,
510
+ prefill=prefill_metadata,
511
+ decode=decode_metadata,
512
+ query_start_loc=query_start_loc,
513
+ block_tables=block_table,
514
+ seq_lens=seq_lens,
515
+ max_num_tokens_across_dp=max_num_tokens_across_dp,
516
+ with_prefill_across_dp=with_prefill_across_dp,
517
+ )
518
+
519
+
520
+ class AscendMLAImpl(MLAAttentionImpl):
521
+ """
522
+ NOTE: Please read the comment at the top of the file before trying to
523
+ understand this class
524
+ """
525
+
526
+ def __init__(
527
+ self,
528
+ num_heads: int,
529
+ head_size: int,
530
+ scale: float,
531
+ num_kv_heads: int,
532
+ alibi_slopes: Optional[list[float]],
533
+ sliding_window: Optional[int],
534
+ kv_cache_dtype: str,
535
+ blocksparse_params: Optional[dict[str, Any]],
536
+ logits_soft_cap: Optional[float],
537
+ attn_type: str,
538
+ kv_sharing_target_layer_name: Optional[str] = None,
539
+ **kwargs,
540
+ ) -> None:
541
+ self.num_heads = num_heads
542
+ self.head_size = head_size
543
+ self.scale = float(scale)
544
+ self.num_kv_heads = num_kv_heads
545
+ self.kv_cache_dtype = kv_cache_dtype
546
+
547
+ # MLA Args
548
+ self.q_lora_rank = kwargs['q_lora_rank']
549
+ self.kv_lora_rank = kwargs['kv_lora_rank']
550
+ self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
551
+ self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
552
+ self.qk_head_dim = kwargs['qk_head_dim']
553
+ self.v_head_dim = kwargs['v_head_dim']
554
+ self.rotary_emb = kwargs['rotary_emb']
555
+ self.q_proj = kwargs['q_proj']
556
+ self.kv_b_proj = kwargs['kv_b_proj']
557
+ self.o_proj = kwargs['o_proj']
558
+ self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
559
+ self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
560
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
561
+ self.tp_size = get_tensor_model_parallel_world_size()
562
+
563
+ ascend_config = get_ascend_config()
564
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
565
+ self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
566
+
567
+ # Adapt torch air graph mode with spec decoding.
568
+ speculative_config = get_current_vllm_config().speculative_config
569
+ if speculative_config is not None:
570
+ self.spec_token_num = speculative_config.num_speculative_tokens
571
+ assert self.spec_token_num > 0
572
+ self.SHARE_MASK_TRIL_SPARSE = ~torch.tril(torch.ones((2048, 2048), dtype=torch.bool)).npu()
573
+
574
+ def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
575
+ # Convert from (B, N, L) to (N, B, L)
576
+ x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
577
+ # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
578
+ x = torch.bmm(x, self.W_UV)
579
+ # Convert from (N, B, V) to (B, N * V)
580
+ x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
581
+ MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
582
+ npu_prefetch(self.o_proj.weight,
583
+ x,
584
+ max_size=MAX_O_PROJ_PREFETCH_SIZE,
585
+ enabled=enable_multistream_mla)
586
+ return self.o_proj(x, is_prefill=False)[0]
587
+
588
+ # Return `ql_nope`, `q_pe`
589
+ def _q_proj_and_k_up_proj(self, x):
590
+ q_nope, q_pe = self.q_proj(x)[0]\
591
+ .view(-1, self.num_heads, self.qk_head_dim)\
592
+ .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
593
+
594
+ # Convert from (B, N, P) to (N, B, P)
595
+ q_nope = q_nope.transpose(0, 1)
596
+ # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
597
+ ql_nope = torch.bmm(q_nope, self.W_UK_T)
598
+ # Convert from (N, B, L) to (B, N, L)
599
+ return ql_nope.transpose(0, 1), q_pe
600
+
601
+ def process_weights_after_loading(self, act_dtype: torch.dtype):
602
+
603
+ def get_layer_weight(layer):
604
+ WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
605
+ for attr in WEIGHT_NAMES:
606
+ if hasattr(layer, attr):
607
+ return getattr(layer, attr)
608
+ raise AttributeError(
609
+ f"Layer '{layer}' has no recognized weight attribute:"
610
+ f" {WEIGHT_NAMES}.")
611
+
612
+ def get_and_maybe_dequant_weights(layer: LinearBase):
613
+ if not isinstance(layer.quant_method, UnquantizedLinearMethod):
614
+ # NOTE: This should only be used offline, since it's O(N^3)
615
+ eye = torch.eye(layer.input_size_per_partition,
616
+ dtype=act_dtype,
617
+ device=get_layer_weight(layer).device)
618
+ dequant_weights = layer.quant_method.apply(layer,
619
+ eye,
620
+ bias=None)
621
+ del eye
622
+ # standardize to (output, input)
623
+ return dequant_weights.T
624
+ return layer.weight
625
+
626
+ # we currently do not have quantized bmm's which are needed for
627
+ # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
628
+ # the bmm's in 16-bit, the extra memory overhead of this is fairly low
629
+ kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
630
+ assert kv_b_proj_weight.shape == (
631
+ self.kv_lora_rank,
632
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
633
+ f"{kv_b_proj_weight.shape=}, "
634
+ f"{self.kv_lora_rank=}, "
635
+ f"{self.num_heads=}, "
636
+ f"{self.qk_nope_head_dim=}, "
637
+ f"{self.v_head_dim=}")
638
+ kv_b_proj_weight = kv_b_proj_weight.view(
639
+ self.kv_lora_rank,
640
+ self.num_heads,
641
+ self.qk_nope_head_dim + self.v_head_dim,
642
+ )
643
+
644
+ W_UK, W_UV = kv_b_proj_weight.split(
645
+ [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
646
+
647
+ # Convert from (L, N, V) to (N, L, V)
648
+ self.W_UV = W_UV.transpose(0, 1).contiguous()
649
+ # Convert from (L, N, P) to (N, P, L)
650
+ self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
651
+
652
+ # Waiting for BMM NZ support
653
+ # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
654
+ # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
655
+
656
+ def _compute_prefill_context(
657
+ self,
658
+ query: torch.Tensor,
659
+ kv_c_and_k_pe_cache: torch.Tensor,
660
+ rope_dim: int,
661
+ attn_metadata: AscendMLAMetadata,
662
+ prefix_output: torch.Tensor,
663
+ prefix_lse: torch.Tensor,
664
+ ):
665
+ prefill_metadata = attn_metadata.prefill
666
+ if prefill_metadata is None or prefill_metadata.chunked_context is None:
667
+ return prefix_output, prefix_lse
668
+
669
+ iters = len(prefill_metadata.chunked_context.seq_tot)
670
+ q_pe = query[..., self.qk_nope_head_dim:]
671
+ q_nope = query[..., :self.qk_nope_head_dim]
672
+
673
+ seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
674
+ latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
675
+ cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
676
+ cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
677
+ for i in range(iters):
678
+ toks = prefill_metadata.chunked_context.seq_tot[i]
679
+
680
+ seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
681
+ seq_len = torch.stack([seq_len1, seq_len2])
682
+ kv_c_normed = torch.empty(toks,
683
+ kv_c_and_k_pe_cache.size(2),
684
+ latent_kv_dim,
685
+ dtype=query.dtype,
686
+ device=query.device)
687
+ k_pe = torch.empty(toks,
688
+ kv_c_and_k_pe_cache.size(2),
689
+ rope_dim,
690
+ dtype=query.dtype,
691
+ device=query.device)
692
+
693
+ torch_npu.atb.npu_paged_cache_load(
694
+ cache_kv_c,
695
+ cache_k_pe,
696
+ prefill_metadata.block_table,
697
+ seq_len2.to(query.device),
698
+ seq_starts=prefill_metadata.chunked_context.starts[i],
699
+ key=kv_c_normed,
700
+ value=k_pe,
701
+ )
702
+
703
+ kv_c_normed = kv_c_normed.squeeze()
704
+ kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
705
+ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
706
+ k_nope, v = kv_nope\
707
+ .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
708
+ k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
709
+ mask = torch.triu(
710
+ torch.ones(512, 512, device=query.device, dtype=query.dtype),
711
+ 1)
712
+ torch_npu.atb.npu_ring_mla(
713
+ q_nope=q_nope,
714
+ q_rope=q_pe,
715
+ k_nope=k_nope,
716
+ k_rope=k_pe,
717
+ value=v,
718
+ mask=mask,
719
+ seqlen=seq_len,
720
+ head_num=self.num_heads,
721
+ kv_head_num=self.num_heads,
722
+ pre_out=prefix_output,
723
+ prev_lse=prefix_lse,
724
+ qk_scale=self.scale,
725
+ kernel_type="kernel_type_high_precision",
726
+ mask_type="no_mask",
727
+ input_layout="type_bsnd",
728
+ calc_type="calc_type_default",
729
+ output=prefix_output,
730
+ softmax_lse=prefix_lse)
731
+ return prefix_output, prefix_lse
732
+
733
+ def _forward_prefill(
734
+ self,
735
+ query: torch.Tensor,
736
+ kv_c_normed: torch.Tensor,
737
+ k_pe: torch.Tensor,
738
+ kv_c_and_k_pe_cache: torch.Tensor,
739
+ attn_metadata: AscendMLAMetadata,
740
+ ) -> torch.Tensor:
741
+ assert attn_metadata.prefill is not None
742
+
743
+ num_tokens = query.size(0)
744
+ attn_output = torch.empty(num_tokens,
745
+ self.num_heads,
746
+ self.v_head_dim,
747
+ dtype=query.dtype,
748
+ device=query.device)
749
+ k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
750
+ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
751
+ [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
752
+ k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
753
+ # Only two kinds of input are possible here: ChunkedPrefill or PrefillNoCache
754
+ ascend_config = get_ascend_config()
755
+
756
+ if attn_metadata.attn_state in [
757
+ AscendAttentionState.ChunkedPrefill,
758
+ AscendAttentionState.SpecDecoding,
759
+ AscendAttentionState.PrefillCacheHit
760
+ ] and not ascend_config.chunked_prefill_for_mla:
761
+ attn_output_torch = torch.empty(num_tokens,
762
+ self.num_heads * self.v_head_dim,
763
+ dtype=query.dtype,
764
+ device=query.device)
765
+ # The current request is chunked in prefill; disable flash attention with chunked prefill
766
+ vanilla_chunked_prefill_mla(
767
+ output=attn_output_torch,
768
+ query=query,
769
+ kv_cache=kv_c_and_k_pe_cache,
770
+ block_tables=attn_metadata.prefill.block_table,
771
+ query_lens=attn_metadata.prefill.query_lens,
772
+ context_lens=attn_metadata.prefill.context_lens,
773
+ kv_b_proj=self.kv_b_proj,
774
+ max_query_len=attn_metadata.prefill.max_query_len,
775
+ max_context_len=attn_metadata.prefill.max_seq_lens,
776
+ nope_dim=self.qk_nope_head_dim,
777
+ rope_dim=self.qk_rope_head_dim,
778
+ v_head_dim=self.v_head_dim,
779
+ scale=self.scale,
780
+ alibi_slopes=None,
781
+ causal=True)
782
+ elif attn_metadata.attn_state in [
783
+ AscendAttentionState.ChunkedPrefill,
784
+ AscendAttentionState.SpecDecoding,
785
+ AscendAttentionState.PrefillCacheHit
786
+ ]:
787
+ attn_lse = torch.empty(self.num_heads,
788
+ num_tokens,
789
+ dtype=torch.float32,
790
+ device=query.device)
791
+ q_pe = query[..., self.qk_nope_head_dim:]
792
+ q_nope = query[..., :self.qk_nope_head_dim]
793
+ mask = torch.triu(
794
+ torch.ones(512, 512, device=query.device, dtype=query.dtype),
795
+ 1)  # 512: the attention mask only supports size 512
796
+ if attn_metadata.num_prefills > 1:
797
+ mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
798
+ 1)
799
+ torch_npu.atb.npu_ring_mla(
800
+ q_nope=q_nope,
801
+ q_rope=q_pe,
802
+ k_nope=k_nope,
803
+ k_rope=k_pe,
804
+ value=value,
805
+ mask=mask,
806
+ seqlen=torch.tensor(attn_metadata.prefill.query_lens,
807
+ dtype=torch.int32),
808
+ head_num=self.num_heads,
809
+ kv_head_num=self.num_heads,
810
+ pre_out=None,
811
+ prev_lse=None,
812
+ qk_scale=self.scale,
813
+ kernel_type="kernel_type_high_precision",
814
+ mask_type="mask_type_triu",
815
+ input_layout="type_bsnd",
816
+ calc_type="calc_type_first_ring",
817
+ output=attn_output,
818
+ softmax_lse=attn_lse)
819
+ attn_output, attn_lse = self._compute_prefill_context( \
820
+ query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
821
+
822
+ elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
823
+ key = torch.cat((k_nope, k_pe), dim=-1)
824
+ context_lens_list = torch.cumsum(attn_metadata.prefill.context_lens, dim=0).tolist()
825
+ attn_output = torch_npu.npu_fused_infer_attention_score(
826
+ query,
827
+ key,
828
+ value,
829
+ num_heads=self.num_heads,
830
+ input_layout="TND",
831
+ scale=self.scale,
832
+ sparse_mode=3,
833
+ atten_mask=self.SHARE_MASK_TRIL_SPARSE,
834
+ actual_seq_lengths=context_lens_list,
835
+ actual_seq_lengths_kv=context_lens_list,
836
+ inner_precise=0)[0]
837
+ attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
838
+ else:
839
+ raise RuntimeError(
840
+ "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
841
+ )
842
+ attn_output = attn_output.reshape(
843
+ [num_tokens, self.num_heads * self.v_head_dim])
844
+ if attn_metadata.attn_state in [
845
+ AscendAttentionState.ChunkedPrefill,
846
+ AscendAttentionState.SpecDecoding,
847
+ AscendAttentionState.PrefillCacheHit
848
+ ] and not ascend_config.chunked_prefill_for_mla:
849
+ attn_output = attn_output_torch
850
+
851
+ current_ms_metadata = get_multistream_comm_context()
852
+ if current_ms_metadata is None:
853
+ return self.o_proj(attn_output, is_prefill=True)[0]
854
+ else:
855
+ current_ms_metadata.before_comm_event.record()
856
+ with torch.npu.stream(current_ms_metadata.comm_stream):
857
+ current_ms_metadata.before_comm_event.wait()
858
+ return self.o_proj(attn_output, is_prefill=True)[0]
859
+
860
+ def exec_kv(
861
+ self,
862
+ hidden_states: torch.Tensor,
863
+ cos: torch.Tensor,
864
+ sin: torch.Tensor,
865
+ kv_cache: Tuple,
866
+ slots: torch.Tensor,
867
+ ):
868
+
869
+ B = hidden_states.shape[0]
870
+ N = self.num_kv_heads
871
+ S = 1
872
+ kv = self.kv_a_proj_with_mqa(hidden_states)[0]
873
+ # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
874
+ kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
875
+ cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
876
+ k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
877
+ kv,
878
+ self.kv_a_layernorm.weight,
879
+ cos,
880
+ sin,
881
+ slots.to(torch.int64),
882
+ kv_cache[1],
883
+ kv_cache[0],
884
+ epsilon=self.kv_a_layernorm.variance_epsilon,
885
+ cache_mode=cache_mode,
886
+ )
887
+ return k_pe, k_nope, kv
888
+
889
+ def exec_kv_prefill(
890
+ self,
891
+ hidden_states: torch.Tensor,
892
+ cos: torch.Tensor,
893
+ sin: torch.Tensor,
894
+ kv_cache: Tuple,
895
+ slots: torch.Tensor,
896
+ ):
897
+
898
+ B = hidden_states.shape[0]
899
+ N = self.num_kv_heads
900
+ S = 1
901
+ kv = self.kv_a_proj_with_mqa(hidden_states)[0]
902
+ # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
903
+ kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
904
+ cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
905
+ _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
906
+ kv,
907
+ self.kv_a_layernorm.weight,
908
+ cos,
909
+ sin,
910
+ slots.to(torch.int64),
911
+ kv_cache[1],
912
+ kv_cache[0],
913
+ epsilon=self.kv_a_layernorm.variance_epsilon,
914
+ cache_mode=cache_mode,
915
+ is_output_kv=True,
916
+ )
917
+ return k_pe, k_nope
918
+
919
+ def rope_single(
920
+ self,
921
+ x: torch.Tensor,
922
+ cos: torch.Tensor,
923
+ sin: torch.Tensor,
924
+ ) -> torch.Tensor:
925
+ B, N, D = x.shape
926
+ S = 1
927
+ x = x.view(B, N, S, D)
928
+ x = torch_npu.npu_interleave_rope(x, cos, sin)
929
+ return x.view(B, N, D)
930
+
931
+ def _forward_decode(
932
+ self,
933
+ q_nope: torch.Tensor,
934
+ q_pe: torch.Tensor,
935
+ k_nope: torch.Tensor,
936
+ k_pe: torch.Tensor,
937
+ kv_c_and_k_pe_cache: torch.Tensor,
938
+ attn_metadata: AscendMLAMetadata,
939
+ enable_multistream_mla: bool = False,
940
+ ) -> torch.Tensor:
941
+ decode_meta = attn_metadata.decode
942
+ assert decode_meta is not None
943
+
944
+ q = torch.cat([q_nope, q_pe], dim=-1)
945
+ num_tokens = q.size(0)
946
+ attn_output = torch.empty(
947
+ [num_tokens, self.num_heads, self.kv_lora_rank],
948
+ dtype=q.dtype,
949
+ device=q.device)
950
+ if self.running_in_graph:
951
+ # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
952
+ if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
953
+ assert num_tokens % self.spec_token_num == 0
954
+ q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
955
+ self.spec_token_num + 1, self.num_heads,
956
+ -1)
957
+ q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
958
+ self.spec_token_num + 1, self.num_heads, -1)
959
+ if not self.enable_kv_nz:
960
+ q_nope = q_nope.transpose(1, 2).contiguous()
961
+ q_pe = q_pe.transpose(1, 2).contiguous()
962
+ sparse_mode = 3
963
+ spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
964
+ else:
965
+ if self.enable_kv_nz:
966
+ q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
967
+ q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
968
+ else:
969
+ q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
970
+ q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
971
+ sparse_mode = 0
972
+ spec_attn_mask = None
973
+ # shape of knope/k_pe for npu graph mode should be:
974
+ # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
975
+ block_size = kv_c_and_k_pe_cache[0].shape[1]
976
+ if self.enable_kv_nz:
977
+ k_nope = k_nope.view(-1, self.num_kv_heads,
978
+ self.kv_lora_rank // 16, block_size, 16)
979
+ k_pe = k_pe.view(-1, self.num_kv_heads,
980
+ self.qk_rope_head_dim // 16, block_size, 16)
981
+ input_layout = "BSND"
982
+ else:
983
+ k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
984
+ self.kv_lora_rank)
985
+ k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
986
+ self.qk_rope_head_dim)
987
+ input_layout = "BNSD"
988
+
989
+ attn_output, _ = torch_npu.npu_fused_infer_attention_score(
990
+ q_nope,
991
+ k_nope,
992
+ k_nope,
993
+ query_rope=q_pe,
994
+ key_rope=k_pe,
995
+ num_heads=self.num_heads,
996
+ num_key_value_heads=self.num_kv_heads,
997
+ input_layout=input_layout,
998
+ atten_mask=spec_attn_mask,
999
+ sparse_mode=sparse_mode,
1000
+ scale=self.scale,
1001
+ antiquant_mode=0,
1002
+ antiquant_scale=None,
1003
+ block_table=decode_meta.block_table,
1004
+ block_size=block_size,
1005
+ actual_seq_lengths_kv=decode_meta.seq_lens_list,
1006
+ )
1007
+ else:
1008
+ torch_npu._npu_paged_attention_mla(
1009
+ query=q,
1010
+ key_cache=kv_c_and_k_pe_cache,
1011
+ num_kv_heads=self.num_kv_heads,
1012
+ num_heads=self.num_heads,
1013
+ scale_value=self.scale,
1014
+ block_table=attn_metadata.decode.block_table, # type:ignore
1015
+ context_lens=attn_metadata.decode.seq_lens, # type:ignore
1016
+ mla_vheadsize=self.kv_lora_rank,
1017
+ out=attn_output)
1018
+ current_ms_metadata = get_multistream_comm_context()
1019
+ if current_ms_metadata is None:
1020
+ return self._v_up_proj_and_o_proj(attn_output,
1021
+ enable_multistream_mla)
1022
+ else:
1023
+ current_ms_metadata.before_comm_event.record()
1024
+ with torch.npu.stream(current_ms_metadata.comm_stream):
1025
+ current_ms_metadata.before_comm_event.wait()
1026
+ return self._v_up_proj_and_o_proj(attn_output)
1027
+
1028
+ def forward(
1029
+ self,
1030
+ layer: AttentionLayer,
1031
+ hidden_states_or_q_c: torch.Tensor, # query in unified attn
1032
+ hidden_states_or_kv_c_normed: torch.Tensor, # key in unified attn
1033
+ k_pe: torch.Tensor, # value in unified attn
1034
+ kv_cache: torch.Tensor,
1035
+ attn_metadata: M,
1036
+ output: Optional[torch.Tensor] = None,
1037
+ enable_multistream_mla: bool = False,
1038
+ ckq: Optional[torch.Tensor] = None,
1039
+ ) -> torch.Tensor:
1040
+ assert output is not None, "Output tensor must be provided."
1041
+ if attn_metadata is None:
1042
+ # Profiling run.
1043
+ return output
1044
+ self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
1045
+ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
1046
+ ]
1047
+ num_actual_toks = attn_metadata.num_actual_tokens
1048
+ if k_pe is None and not self.running_in_graph:
1049
+ if not self.torchair_graph_enabled:
1050
+ kv_c, k_pe = self.kv_a_proj_with_mqa(
1051
+ hidden_states_or_kv_c_normed)[0].split(
1052
+ [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
1053
+ kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
1054
+ else:
1055
+ kv_c_normed = hidden_states_or_kv_c_normed
1056
+ assert attn_metadata.num_decodes is not None and \
1057
+ attn_metadata.num_prefills is not None and \
1058
+ attn_metadata.num_decode_tokens is not None
1059
+ has_decode = attn_metadata.num_decodes > 0
1060
+ has_prefill = attn_metadata.num_prefills > 0
1061
+ num_decode_tokens = attn_metadata.num_decode_tokens
1062
+ if not self.running_in_graph:
1063
+ # Inputs and outputs may be padded for CUDA graphs
1064
+ output_padded = output
1065
+ output = output[:num_actual_toks, ...]
1066
+ if not self.torchair_graph_enabled:
1067
+ kv_c_normed = kv_c_normed[:num_actual_toks, ...]
1068
+ prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
1069
+ if not self.running_in_graph:
1070
+ hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
1071
+ prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
1072
+ if not self.torchair_graph_enabled:
1073
+ decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
1074
+ k_pe = k_pe[:num_actual_toks, ...]
1075
+ k_pe = k_pe.unsqueeze(1)
1076
+ decode_k_pe = k_pe[:num_decode_tokens]
1077
+ prefill_k_pe = k_pe[num_decode_tokens:]
1078
+ else:
1079
+ decode_hs_or_q_c = hidden_states_or_q_c
1080
+ if has_decode:
1081
+ decode_k_nope = None
1082
+ assert attn_metadata.decode is not None
1083
+ if self.running_in_graph:
1084
+ seq_len = self.rotary_emb.max_position_embeddings * \
1085
+ getattr(self.rotary_emb, "scaling_factor", 1)
1086
+ cos = self.rotary_emb.cos_cached[:seq_len].to(
1087
+ dtype=decode_hs_or_q_c.dtype)
1088
+ sin = self.rotary_emb.sin_cached[:seq_len].to(
1089
+ dtype=decode_hs_or_q_c.dtype)
1090
+ cos = cos[attn_metadata.decode.input_positions]
1091
+ sin = sin[attn_metadata.decode.input_positions]
1092
+ cos = cos[:, None, None, :]
1093
+ sin = sin[:, None, None, :]
1094
+ with npu_stream_switch("mla_secondary",
1095
+ 0,
1096
+ enabled=enable_multistream_mla):
1097
+ npu_wait_tensor(hidden_states_or_kv_c_normed,
1098
+ ckq,
1099
+ enabled=enable_multistream_mla)
1100
+ decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
1101
+ hidden_states_or_kv_c_normed, cos, sin, kv_cache,
1102
+ attn_metadata.slot_mapping)
1103
+ # Without explicitly controlling the order, IndexByTensor operations
1104
+ # would be placed after `matmul W_KV_T` hindering the overlapping of
1105
+ # KvRmsNormRopeCache and SingleRope.
1106
+ npu_wait_tensor(decode_hs_or_q_c,
1107
+ cos,
1108
+ enabled=enable_multistream_mla)
1109
+ npu_wait_tensor(decode_hs_or_q_c,
1110
+ sin,
1111
+ enabled=enable_multistream_mla)
1112
+ npu_wait_tensor(decode_hs_or_q_c,
1113
+ decode_kv,
1114
+ enabled=enable_multistream_mla)
1115
+
1116
+ decode_ql_nope, decode_q_pe = \
1117
+ self._q_proj_and_k_up_proj(decode_hs_or_q_c)
1118
+ if self.running_in_graph:
1119
+ with npu_stream_switch("mla_secondary",
1120
+ 0,
1121
+ enabled=enable_multistream_mla):
1122
+ npu_wait_tensor(decode_q_pe,
1123
+ decode_k_pe,
1124
+ enabled=enable_multistream_mla)
1125
+ decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
1126
+ else:
1127
+ decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
1128
+ attn_metadata.decode.input_positions,
1129
+ decode_q_pe.contiguous(),
1130
+ decode_k_pe,
1131
+ max_seq_len=attn_metadata.decode.max_seq_lens)
1132
+ if has_prefill:
1133
+ assert attn_metadata.prefill is not None
1134
+ prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
1135
+ .view(-1, self.num_heads, self.qk_head_dim)
1136
+ prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
1137
+ prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
1138
+ if self.torchair_graph_enabled:
1139
+ num_tokens = prefill_hs_or_q_c.shape[0]
1140
+ seq_len = self.rotary_emb.max_position_embeddings * \
1141
+ getattr(self.rotary_emb, "scaling_factor", 1)
1142
+ cos = self.rotary_emb.cos_cached[:seq_len].to(
1143
+ dtype=prefill_q_pe.dtype)
1144
+ sin = self.rotary_emb.sin_cached[:seq_len].to(
1145
+ dtype=prefill_q_pe.dtype)
1146
+ cos = cos[attn_metadata.prefill.input_positions]
1147
+ sin = sin[attn_metadata.prefill.input_positions]
1148
+ cos = cos[:, None, None, :]
1149
+ sin = sin[:, None, None, :]
1150
+
1151
+ prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
1152
+ prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
1153
+ hidden_states_or_kv_c_normed, cos, sin, kv_cache,
1154
+ attn_metadata.slot_mapping)
1155
+
1156
+ kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
1157
+ prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
1158
+ prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
1159
+ -1)
1160
+ prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
1161
+ else:
1162
+ prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
1163
+ attn_metadata.prefill.input_positions,
1164
+ prefill_q_pe.contiguous(),
1165
+ prefill_k_pe,
1166
+ max_seq_len=attn_metadata.prefill.max_seq_lens)
1167
+ if self.torchair_graph_enabled:
1168
+ if len(kv_cache) > 0 and kv_cache[0].numel(
1169
+ ) > 0 and attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
1170
+ slots = attn_metadata.slot_mapping
1171
+ # NOTE: Separate the kv cache in advance to avoid OOM or other issues
1172
+ torch_npu._npu_reshape_and_cache(key=kv_c_normed.view(
1173
+ num_tokens, self.num_kv_heads, -1),
1174
+ value=prefill_k_pe,
1175
+ key_cache=kv_cache[0],
1176
+ value_cache=kv_cache[1],
1177
+ slot_indices=slots)
1178
+ elif kv_cache.numel() > 0:
1179
+ key = torch.cat([
1180
+ kv_c_normed.view([num_actual_toks, self.num_kv_heads, -1]),
1181
+ k_pe
1182
+ ],
1183
+ dim=2)
1184
+ torch_npu._npu_reshape_and_cache_siso(
1185
+ key=key,
1186
+ key_cache=kv_cache,
1187
+ slot_indices=attn_metadata.slot_mapping.flatten())
1188
+ if has_prefill:
1189
+ # FIX: the aicore move should also be placed on the comm stream in dbo,
1190
+ # otherwise it may affect the accuracy
1191
+ # TODO: use an elegant way to overlap
1192
+ output_prefill = self._forward_prefill(prefill_q,
1193
+ prefill_k_c_normed,
1194
+ prefill_k_pe, kv_cache,
1195
+ attn_metadata)
1196
+ current_ms_metadata = get_multistream_comm_context()
1197
+ if current_ms_metadata is not None:
1198
+ with torch.npu.stream(current_ms_metadata.comm_stream):
1199
+ output[num_decode_tokens:] = output_prefill
1200
+ current_ms_metadata.after_comm_event.record()
1201
+ else:
1202
+ output[num_decode_tokens:] = output_prefill
1203
+
1204
+ if has_decode:
1205
+ if self.running_in_graph:
1206
+ return self._forward_decode(decode_ql_nope, decode_q_pe,
1207
+ decode_k_nope, decode_k_pe,
1208
+ kv_cache, attn_metadata,
1209
+ enable_multistream_mla)
1210
+ else:
1211
+ output_decode = self._forward_decode(decode_ql_nope,
1212
+ decode_q_pe,
1213
+ decode_k_nope,
1214
+ decode_k_pe, kv_cache,
1215
+ attn_metadata)
1216
+ current_ms_metadata = get_multistream_comm_context()
1217
+ if current_ms_metadata is not None:
1218
+ with torch.npu.stream(current_ms_metadata.comm_stream):
1219
+ output[:num_decode_tokens] = output_decode
1220
+ current_ms_metadata.after_comm_event.record()
1221
+ else:
1222
+ output[:num_decode_tokens] = output_decode
1223
+
1224
+ return output_padded
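AscendMLAMetadataBuilder.reorder_batch above moves decode requests to the front of the persistent batch with as few swaps as possible. A small pure-Python sketch of the same swap strategy (a toy helper, not part of vllm-ascend; it only returns the index pairs that would be swapped):

def plan_swaps(is_decode):
    # is_decode[i] is True if request i is a single-token decode.
    decodes = [i for i, d in enumerate(is_decode) if d]
    prefills = [i for i, d in enumerate(is_decode) if not d]
    num_decodes = len(decodes)
    swaps = []
    first_prefill = 0
    # Walk the decodes from the back: any decode sitting beyond the decode
    # region swaps with the earliest prefill, mirroring reorder_batch.
    for i in range(1, min(num_decodes, len(prefills)) + 1):
        if decodes[num_decodes - i] >= num_decodes:
            swaps.append((prefills[first_prefill], decodes[num_decodes - i]))
            first_prefill += 1
        else:
            break
    return swaps

print(plan_swaps([False, True, True, False, True]))  # [(0, 4)] -> decodes end up first

Because decodes tend to stay in the batch across many steps and new requests are appended at the back, this usually touches only a handful of entries per scheduling step.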
inference/vllm_ascend/entrypoints/openai/reasoning_parsers/__init__.py ADDED
@@ -0,0 +1,6 @@
 
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2
+ from .pangu_reasoning_parser import PanguReasoningParser
3
+
4
+ __all__ = [
5
+ "PanguReasoningParser"
6
+ ]
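The parser exported here splits generated text on the [unused16]/[unused17] markers: everything between them is reasoning content, and everything after [unused17] is the final answer. A minimal non-streaming sketch of that splitting rule (toy code, not the vLLM ReasoningParser API; the PanguReasoningParser added below also handles token-by-token streaming deltas):

START, END = "[unused16]", "[unused17]"

def split_reasoning(text: str):
    if END not in text:
        # Reasoning never closed: treat everything after the start marker as reasoning.
        return text.replace(START, "", 1), None
    head, _, tail = text.partition(END)
    return head.replace(START, "", 1), (tail or None)

print(split_reasoning("[unused16]think step by step[unused17]The answer is 4."))
# ('think step by step', 'The answer is 4.')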
inference/vllm_ascend/entrypoints/openai/reasoning_parsers/pangu_reasoning_parser.py ADDED
@@ -0,0 +1,171 @@
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
4
+
5
+ from collections.abc import Sequence
6
+ from typing import Optional, Union
7
+
8
+ from transformers import PreTrainedTokenizerBase
9
+
10
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
11
+ DeltaMessage)
12
+ from vllm.logger import init_logger
13
+ from vllm.reasoning import ReasoningParser, ReasoningParserManager
14
+
15
+ logger = init_logger(__name__)
16
+
17
+
18
+ @ReasoningParserManager.register_module("pangu")
19
+ class PanguReasoningParser(ReasoningParser):
20
+ """
21
+ Reasoning parser for Pangu model.
22
+
23
+ The Pangu model uses [unused16]...[unused17] tokens to denote reasoning
24
+ text. This parser extracts the reasoning content from the model output.
25
+ """
26
+
27
+ start_token_id: int
28
+ end_token_id: int
29
+
30
+ start_token: str = "[unused16]"
31
+ end_token: str = "[unused17]"
32
+
33
+ def __init__(self, tokenizer: PreTrainedTokenizerBase):
34
+ super().__init__(tokenizer)
35
+
36
+ if not self.model_tokenizer:
37
+ raise ValueError(
38
+ "The model tokenizer must be passed to the ReasoningParser "
39
+ "constructor during construction.")
40
+
41
+ self.start_token_id = self.vocab.get(self.start_token)
42
+ self.end_token_id = self.vocab.get(self.end_token)
43
+ if self.start_token_id is None or self.end_token_id is None:
44
+ raise RuntimeError(
45
+ "Pangu reasoning parser could not locate think start/end "
46
+ "tokens in the tokenizer!")
47
+
48
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
49
+ return self.end_token_id in input_ids
50
+
51
+ def extract_content_ids(self, input_ids: list[int]) -> list[int]:
52
+ """
53
+ Extract the content after the end tokens
54
+ """
55
+ if self.end_token_id not in input_ids[:-1]:
56
+ return []
57
+ else:
58
+ return input_ids[input_ids.index(self.end_token_id) + 1:]
59
+
60
+ def extract_reasoning_content_streaming(
61
+ self,
62
+ previous_text: str,
63
+ current_text: str,
64
+ delta_text: str,
65
+ previous_token_ids: Sequence[int],
66
+ current_token_ids: Sequence[int],
67
+ delta_token_ids: Sequence[int],
68
+ ) -> Union[DeltaMessage, None]:
69
+ """
70
+ Extract reasoning content from a delta message.
71
+ Handles streaming output where previous + delta = current.
72
+ Uses token IDs for faster processing.
73
+ For text [unused16]abc[unused17]xyz:
74
+ - 'abc' goes to reasoning_content
75
+ - 'xyz' goes to content
76
+ """
77
+ # Skip single special tokens
78
+ if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
79
+ self.start_token_id, self.end_token_id
80
+ ]):
81
+ return None
82
+
83
+ # Check if [unused16] is present in previous or delta.
84
+ # Keep compatibility with models that don't generate [unused16] tokens.
85
+ if self.start_token_id in previous_token_ids:
86
+ if self.end_token_id in delta_token_ids:
87
+ # [unused16] in previous, [unused17] in delta,
88
+ # extract reasoning content
89
+ end_index = delta_text.find(self.end_token)
90
+ reasoning_content = delta_text[:end_index]
91
+ content = delta_text[end_index + len(self.end_token):]
92
+ return DeltaMessage(
93
+ reasoning_content=reasoning_content,
94
+ content=content if content else None,
95
+ )
96
+ elif self.end_token_id in previous_token_ids:
97
+ # [unused16] in previous, [unused17] in previous,
98
+ # reasoning content continues
99
+ return DeltaMessage(content=delta_text)
100
+ else:
101
+ # [unused16] in previous, no [unused17] in previous or delta,
102
+ # reasoning content continues
103
+ return DeltaMessage(reasoning_content=delta_text)
104
+ elif self.start_token_id in delta_token_ids:
105
+ if self.end_token_id in delta_token_ids:
106
+ # [unused16] in delta, [unused17] in delta, extract reasoning content
107
+ start_index = delta_text.find(self.start_token)
108
+ end_index = delta_text.find(self.end_token)
109
+ reasoning_content = delta_text[start_index +
110
+ len(self.start_token):end_index]
111
+ content = delta_text[end_index + len(self.end_token):]
112
+ return DeltaMessage(
113
+ reasoning_content=reasoning_content,
114
+ content=content if content else None,
115
+ )
116
+ else:
117
+ # [unused16] in delta, no [unused17] in delta,
118
+ # reasoning content continues
119
+ return DeltaMessage(reasoning_content=delta_text)
120
+ else:
121
+ # No [unused16] in previous or delta, also need to check for [unused17].
122
+ # Because the model may have generated [unused17] without [unused16]
123
+ if self.end_token_id in delta_token_ids:
124
+ # [unused17] in delta with more tokens,
125
+ # extract reasoning content and content
126
+ end_index = delta_text.find(self.end_token)
127
+ reasoning_content = delta_text[:end_index]
128
+ content = delta_text[end_index + len(self.end_token):]
129
+ return DeltaMessage(
130
+ reasoning_content=reasoning_content,
131
+ content=content if content else None,
132
+ )
133
+ elif self.end_token_id in previous_token_ids:
134
+ # [unused17] in previous, thinking content ends
135
+ return DeltaMessage(content=delta_text)
136
+ else:
137
+ # no [unused17] in previous or delta, reasoning content continues
138
+ return DeltaMessage(reasoning_content=delta_text)
139
+
140
+ def extract_reasoning_content(
141
+ self, model_output: str, request: ChatCompletionRequest
142
+ ) -> tuple[Optional[str], Optional[str]]:
143
+ """
144
+ Extract reasoning content from the model output.
145
+
146
+ For text [unused16]abc[unused17]xyz:
147
+ - 'abc' goes to reasoning_content
148
+ - 'xyz' goes to content
149
+
150
+ Returns:
151
+ tuple[Optional[str], Optional[str]]: reasoning content and content
152
+ """
153
+
154
+ # Check if the start token is present in the model output, remove it
155
+ # if it is present.
156
+ model_output_parts = model_output.partition(self.start_token)
157
+ model_output = model_output_parts[2] if model_output_parts[
158
+ 1] else model_output_parts[0]
159
+
160
+ # Thus we assume the reasoning content is always at the start.
161
+ if self.end_token not in model_output:
162
+ return model_output, None
163
+ else:
164
+ reasoning_content, _, content = model_output.partition(
165
+ self.end_token)
166
+ # If the end token is not found, return the model output as is.
167
+ # It should not happen since we already checked for the presence
168
+ # of the end token.
169
+ # If generation stops right after end-of-think, return null content
170
+ final_content = content or None
171
+ return reasoning_content, final_content
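The parser above treats everything between [unused16] and [unused17] as reasoning and everything after [unused17] as the final answer. A minimal sketch of the non-streaming path, assuming the openPangu tokenizer (which defines the [unused16]/[unused17] special tokens) is available locally and vllm-ascend is installed; the model path is illustrative:

from transformers import AutoTokenizer
from vllm_ascend.entrypoints.openai.reasoning_parsers import PanguReasoningParser

tokenizer = AutoTokenizer.from_pretrained("./openPangu-Ultra-MoE-718B", trust_remote_code=True)
parser = PanguReasoningParser(tokenizer)

reasoning, answer = parser.extract_reasoning_content(
    "[unused16]Let me reason step by step.[unused17]The answer is 42.", request=None)
# reasoning == "Let me reason step by step.", answer == "The answer is 42."
# With no [unused17] in the output, the whole text is returned as reasoning_content.

When serving through the OpenAI-compatible API, the parser registered under the name "pangu" is typically selected with vLLM's --reasoning-parser pangu option.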
inference/vllm_ascend/entrypoints/openai/tool_parsers/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2
+ from .pangu_tool_parser import PanguToolParser
3
+
4
+ __all__ = [
5
+ "PanguToolParser"
6
+ ]
inference/vllm_ascend/entrypoints/openai/tool_parsers/pangu_tool_parser.py ADDED
@@ -0,0 +1,300 @@
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2
+ # Copyright 2023 The vLLM team.
3
+
4
+ import json
5
+ import re
6
+ from json import JSONDecodeError, JSONDecoder
7
+ from typing import Dict, List, Sequence, Union, Optional
8
+ from pydantic import Field
9
+ import partial_json_parser
10
+ from partial_json_parser.core.options import Allow
11
+ from transformers import PreTrainedTokenizerBase
12
+
13
+ from vllm.entrypoints.chat_utils import random_tool_call_id
14
+ from vllm.entrypoints.openai.tool_parsers.utils import (
15
+ extract_intermediate_diff)
16
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
17
+ DeltaFunctionCall, DeltaMessage,
18
+ DeltaToolCall,
19
+ ExtractedToolCallInformation,
20
+ FunctionCall, ToolCall,
21
+ )
22
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
23
+ ToolParser, ToolParserManager)
24
+ from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
25
+ is_complete_json)
26
+ from vllm.logger import init_logger
27
+ import os
28
+
29
+ logger = init_logger(__name__)
30
+
31
+
32
+ @ToolParserManager.register_module("pangu")
33
+ class PanguToolParser(ToolParser):
34
+
35
+ def __init__(self, tokenizer: PreTrainedTokenizerBase, enable_reasoning=False):
36
+ super().__init__(tokenizer)
37
+
38
+ # initialize properties used for state when parsing tool calls in
39
+ # streaming mode
40
+ self.prev_tool_call_arr: List[Dict] = []
41
+ self.current_tool_id: int = -1
42
+ self.current_tool_name_sent: bool = False
43
+ self.streamed_args_for_tool: List[str] = [
44
+ ] # map what has been streamed for each tool so far to a list
45
+
46
+ self.tool_call_start_token = "[unused11]"
47
+ self.tool_call_end_token = "[unused12]"
48
+ self.pattern = re.escape(self.tool_call_start_token) \
49
+ + "(.*?)" + re.escape(self.tool_call_end_token)
50
+ self.tool_call_regex = re.compile(self.pattern, re.DOTALL)
51
+
52
+
53
+ self.tool_call_start_token_id = self.vocab.get(
54
+ self.tool_call_start_token)
55
+ self.tool_call_end_token_id = self.vocab.get(
56
+ self.tool_call_end_token)
57
+
58
+
59
+ if (self.tool_call_start_token_id is None
60
+ or self.tool_call_end_token_id is None):
61
+ raise RuntimeError(
62
+ "Pangu Tool parser could not locate tool calls start/end "
63
+ "tokens in the tokenizer!")
64
+ self.is_complete = []
65
+ self.text_after_start_token = ""
66
+
67
+
68
+ def extract_tool_calls(
69
+ self, model_output: str,
70
+ request: ChatCompletionRequest
71
+ ) -> ExtractedToolCallInformation:
72
+ """
73
+ Extract the tool calls from a complete model response.
74
+ """
75
+ # case -- if a tool call token is not present, return a text response
76
+ if not (self.tool_call_start_token in model_output and \
77
+ model_output.find(self.tool_call_end_token) != -1):
78
+ return ExtractedToolCallInformation(tools_called=False,
79
+ tool_calls=[],
80
+ content=model_output)
81
+
82
+ try:
83
+ raw_function_calls = []
84
+ # use a regex to find the tool call between the tags
85
+ function_call_tuples = self.tool_call_regex.findall(model_output)
86
+
87
+
88
+ # load the JSON, and then use it to build the Function and
89
+ # Tool Call
90
+ for function_call_str in function_call_tuples:
91
+ function_call = json.loads(function_call_str)
92
+ raw_function_calls.extend(function_call)
93
+
94
+
95
+ tool_calls: List[ToolCall] = [
96
+ ToolCall(
97
+ type="function",
98
+ function=FunctionCall(
99
+ name=function_call["name"],
100
+ # function call args are JSON but as a string
101
+ arguments=json.dumps(function_call["arguments"] \
102
+ if "arguments" in function_call \
103
+ else function_call["parameters"], ensure_ascii=False)))
104
+ for function_call in raw_function_calls
105
+ ]
106
+ content = model_output[:model_output.
107
+ find(self.tool_call_start_token)]
108
+
109
+ # get any content before the tool call
110
+ ret = ExtractedToolCallInformation(tools_called=True,
111
+ tool_calls=tool_calls,
112
+ content=content if content else None)
113
+
114
+ return ret
115
+
116
+ except Exception:
117
+ logger.exception("Error in extracting tool call from response.")
118
+ # return information to just treat the tool call as regular JSON
119
+ return ExtractedToolCallInformation(tools_called=False,
120
+ tool_calls=[],
121
+ content=model_output)
122
+
123
+ def extract_tool_calls_streaming(
124
+ self,
125
+ previous_text: str,
126
+ current_text: str,
127
+ delta_text: str,
128
+ previous_token_ids: Sequence[int],
129
+ current_token_ids: Sequence[int],
130
+ delta_token_ids: Sequence[int],
131
+ request: ChatCompletionRequest,
132
+ ) -> Union[DeltaMessage, None]:
133
+
134
+ if (self.tool_call_end_token_id in delta_token_ids
135
+ and len(delta_token_ids) == 1):
136
+ # if it's the only token, return None, so we don't send a chat
137
+ # completion and don't send a control token
138
+ return None
139
+
140
+ if (self.tool_call_end_token in current_text
141
+ and self.tool_call_end_token not in delta_text):
142
+ return DeltaMessage(content=delta_text)
143
+
144
+ if self.tool_call_start_token not in current_text:
145
+ return DeltaMessage(content=delta_text)
146
+
147
+ if self.tool_call_start_token in delta_text:
148
+ texts = delta_text.split(self.tool_call_start_token)
149
+ text_before_start_token = texts[0]
150
+ if text_before_start_token:
151
+ return DeltaMessage(content=text_before_start_token)
152
+
153
+ if (self.tool_call_start_token_id in delta_token_ids
154
+ and len(delta_token_ids) == 1):
155
+ # if it's the only token, return None, so we don't send a chat
156
+ # completion and don't send a control token
157
+ return None
158
+
159
+ # bit mask flags for partial JSON parsing. If the name hasn't been
160
+ # sent yet, don't allow sending
161
+ # an incomplete string since OpenAI only ever (as far as I have
162
+ # seen) allows sending the entire tool/ function name at once.
163
+ flags = Allow.ALL if self.current_tool_name_sent \
164
+ else Allow.ALL & ~Allow.STR
165
+ try:
166
+
167
+ tool_call_portion = current_text.split(
168
+ self.tool_call_start_token)[-1].split(self.tool_call_end_token)[0]
169
+ try:
170
+ tool_call_arr: list[dict] = partial_json_parser.loads(
171
+ tool_call_portion, flags)
172
+
173
+ self.is_complete.append(
174
+ is_complete_json(tool_call_portion))
175
+ except partial_json_parser.core.exceptions.MalformedJSON:
176
+ logger.debug('not enough tokens to parse into JSON yet')
177
+ return None
178
+
179
+ # select as the current tool call the one we're on the state at
180
+ current_tool_call: dict = tool_call_arr[self.current_tool_id] \
181
+ if len(tool_call_arr) > 0 else {}
182
+
183
+ # case -- if no tokens have been streamed for the tool, e.g.
184
+ # only the array brackets, stream nothing
185
+ if len(tool_call_arr) == 0:
186
+ return None
187
+
188
+ # case: we are starting a new tool in the array
189
+ # -> array has > 0 length AND length has moved past cursor
190
+ elif (len(tool_call_arr) > 0
191
+ and len(tool_call_arr) > self.current_tool_id + 1):
192
+
193
+ # if we're moving on to a new call, first make sure we
194
+ # haven't missed anything in the previous one that was
195
+ # auto-generated due to JSON completions, but wasn't
196
+ # streamed to the client yet.
197
+ if self.current_tool_id >= 0:
198
+ cur_arguments = current_tool_call.get("arguments")
199
+ if cur_arguments:
200
+ cur_args_json = json.dumps(cur_arguments,
201
+ ensure_ascii=False)
202
+ sent = len(
203
+ self.streamed_args_for_tool[self.current_tool_id])
204
+ argument_diff = cur_args_json[sent:]
205
+
206
+ logger.debug("got arguments diff: %s", argument_diff)
207
+ delta = DeltaMessage(tool_calls=[
208
+ DeltaToolCall(index=self.current_tool_id,
209
+ function=DeltaFunctionCall(
210
+ arguments=argument_diff).
211
+ model_dump(exclude_none=True))
212
+ ])
213
+ self.streamed_args_for_tool[
214
+ self.current_tool_id] += argument_diff
215
+ else:
216
+ delta = None
217
+ else:
218
+ delta = None
219
+ # re-set stuff pertaining to progress in the current tool
220
+ self.current_tool_id = len(tool_call_arr) - 1
221
+ self.current_tool_name_sent = False
222
+ self.streamed_args_for_tool.append("")
223
+ self.is_complete = []
224
+ logger.debug("starting on new tool %d", self.current_tool_id)
225
+ return delta
226
+
227
+ # if the current tool name hasn't been sent, send if available
228
+ # - otherwise send nothing
229
+ elif not self.current_tool_name_sent:
230
+ function_name = current_tool_call.get("name")
231
+ if function_name:
232
+ delta = DeltaMessage(tool_calls=[
233
+ DeltaToolCall(index=self.current_tool_id,
234
+ type="function",
235
+ id=random_tool_call_id(),
236
+ function=DeltaFunctionCall(
237
+ name=function_name).model_dump(
238
+ exclude_none=True))
239
+ ])
240
+ self.current_tool_name_sent = True
241
+ else:
242
+ delta = None
243
+
244
+ # now we know we're on the same tool call and we're streaming
245
+ # arguments
246
+ else:
247
+ cur_arguments = current_tool_call.get("arguments")
248
+ delta = None
249
+ if (self.is_complete[-1] and not cur_arguments
250
+ and not self.streamed_args_for_tool[-1]):
251
+ argument_diff = "{}"
252
+ delta = DeltaMessage(tool_calls=[
253
+ DeltaToolCall(index=self.current_tool_id,
254
+ function=DeltaFunctionCall(
255
+ arguments=argument_diff).
256
+ model_dump(exclude_none=True))
257
+ ])
258
+ self.streamed_args_for_tool[
259
+ self.current_tool_id] += argument_diff
260
+
261
+ if cur_arguments:
262
+ sent = len(
263
+ self.streamed_args_for_tool[self.current_tool_id])
264
+ cur_args_json = json.dumps(cur_arguments,
265
+ ensure_ascii=False)
266
+ prev_arguments = self.prev_tool_call_arr[
267
+ self.current_tool_id].get("arguments")
268
+
269
+ argument_diff = None
270
+ if self.is_complete[-1]:
271
+ argument_diff = cur_args_json[sent:]
272
+ elif prev_arguments:
273
+ prev_args_json = json.dumps(prev_arguments,
274
+ ensure_ascii=False)
275
+ if cur_args_json != prev_args_json:
276
+
277
+ prefix = find_common_prefix(
278
+ prev_args_json, cur_args_json)
279
+ argument_diff = prefix[sent:]
280
+
281
+ if argument_diff is not None:
282
+ delta = DeltaMessage(tool_calls=[
283
+ DeltaToolCall(index=self.current_tool_id,
284
+ function=DeltaFunctionCall(
285
+ arguments=argument_diff).
286
+ model_dump(exclude_none=True))
287
+ ])
288
+ self.streamed_args_for_tool[
289
+ self.current_tool_id] += argument_diff
290
+
291
+
292
+ self.prev_tool_call_arr = tool_call_arr
293
+ return delta
294
+
295
+ except Exception:
296
+ logger.exception("Error trying to handle streaming tool call.")
297
+ logger.debug(
298
+ "Skipping chunk as a result of tool streaming extraction "
299
+ "error")
300
+ return None
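The tool parser expects the model to emit a JSON array of calls between [unused11] and [unused12]. A minimal sketch of the non-streaming path, reusing the tokenizer from the previous example; the tool name and arguments are made up for illustration:

from vllm_ascend.entrypoints.openai.tool_parsers import PanguToolParser

tool_parser = PanguToolParser(tokenizer)
result = tool_parser.extract_tool_calls(
    'Let me check.[unused11][{"name": "get_weather", "arguments": {"city": "Shenzhen"}}][unused12]',
    request=None)
# result.tools_called is True
# result.content == "Let me check."            (text before [unused11])
# result.tool_calls[0].function.name == "get_weather"
# result.tool_calls[0].function.arguments == '{"city": "Shenzhen"}'

In streaming mode, extract_tool_calls_streaming incrementally parses the partial JSON with partial_json_parser, emits the tool name once, and then streams argument diffs. For serving, the parser is typically enabled with vLLM's --enable-auto-tool-choice --tool-call-parser pangu options.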
inference/vllm_ascend/envs.py ADDED
@@ -0,0 +1,153 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # This file is a part of the vllm-ascend project.
4
+ #
5
+ # This file is mainly Adapted from vllm-project/vllm/vllm/envs.py
6
+ # Copyright 2023 The vLLM team.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ #
20
+
21
+ import os
22
+ from typing import Any, Callable, Dict
23
+
24
+ # The begin-* and end* here are used by the documentation generator
25
+ # to extract the used env vars.
26
+
27
+ # begin-env-vars-definition
28
+
29
+ env_variables: Dict[str, Callable[[], Any]] = {
30
+ # max compile thread number for package building. Usually, it is set to
31
+ # the number of CPU cores. If not set, the default value is None, which
32
+ # means all number of CPU cores will be used.
33
+ "MAX_JOBS":
34
+ lambda: os.getenv("MAX_JOBS", None),
35
+ # The build type of the package. It can be one of the following values:
36
+ # Release, Debug, RelWithDebugInfo. If not set, the default value is Release.
37
+ "CMAKE_BUILD_TYPE":
38
+ lambda: os.getenv("CMAKE_BUILD_TYPE"),
39
+ # Whether to compile custom kernels. If not set, the default value is True.
40
+ # If set to False, the custom kernels will not be compiled. Please note that
41
+ # the sleep mode feature will be disabled as well if custom kernels are not
42
+ # compiled.
43
+ "COMPILE_CUSTOM_KERNELS":
44
+ lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
45
+ # The CXX compiler used for compiling the package. If not set, the default
46
+ # value is None, which means the system default CXX compiler will be used.
47
+ "CXX_COMPILER":
48
+ lambda: os.getenv("CXX_COMPILER", None),
49
+ # The C compiler used for compiling the package. If not set, the default
50
+ # value is None, which means the system default C compiler will be used.
51
+ "C_COMPILER":
52
+ lambda: os.getenv("C_COMPILER", None),
53
+ # The version of the Ascend chip. If not set, the default value is
54
+ # ASCEND910B1. It's used for package building. Please make sure that the
55
+ # version is correct.
56
+ "SOC_VERSION":
57
+ lambda: os.getenv("SOC_VERSION", "ASCEND910B1"),
58
+ # If set, vllm-ascend will print verbose logs during compilation
59
+ "VERBOSE":
60
+ lambda: bool(int(os.getenv('VERBOSE', '0'))),
61
+ # The home path for CANN toolkit. If not set, the default value is
62
+ # /usr/local/Ascend/ascend-toolkit/latest
63
+ "ASCEND_HOME_PATH":
64
+ lambda: os.getenv("ASCEND_HOME_PATH", None),
65
+ # The path for HCCN Tool, the tool will be called by disaggregated prefilling
66
+ # case.
67
+ "HCCN_PATH":
68
+ lambda: os.getenv("HCCN_PATH", "/usr/local/Ascend/driver/tools/hccn_tool"),
69
+ # The path for HCCL library, it's used by pyhccl communicator backend. If
70
+ # not set, the default value is libhccl.so。
71
+ "HCCL_SO_PATH":
72
+ # The prefill device id for disaggregated prefilling case.
73
+ lambda: os.environ.get("HCCL_SO_PATH", None),
74
+ "PROMPT_DEVICE_ID":
75
+ lambda: os.getenv("PROMPT_DEVICE_ID", None),
76
+ # The decode device id for disaggregated prefilling case.
77
+ "DECODE_DEVICE_ID":
78
+ lambda: os.getenv("DECODE_DEVICE_ID", None),
79
+ # The port number for llmdatadist communication. If not set, the default
80
+ # value is 26000.
81
+ "LLMDATADIST_COMM_PORT":
82
+ lambda: os.getenv("LLMDATADIST_COMM_PORT", "26000"),
83
+ # The wait time for llmdatadist sync cache. If not set, the default value is
84
+ # 5000ms.
85
+ "LLMDATADIST_SYNC_CACHE_WAIT_TIME":
86
+ lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000"),
87
+ # The version of vllm is installed. This value is used for developers who
88
+ # installed vllm from source locally. In this case, the version of vllm is
89
+ # usually changed. For example, if the version of vllm is "0.9.0", but when
90
+ # it's installed from source, the version of vllm is usually set to "0.9.1".
91
+ # In this case, developers need to set this value to "0.9.0" to make sure
92
+ # that the correct package is installed.
93
+ "VLLM_VERSION":
94
+ lambda: os.getenv("VLLM_VERSION", None),
95
+ # Whether to enable the trace recompiles from pytorch.
96
+ "VLLM_ASCEND_TRACE_RECOMPILES":
97
+ lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
98
+ # Whether to enable fused_experts_allgather_ep. MoeInitRoutingV3 and
99
+ # GroupedMatmulFinalizeRouting operators are combined to implement EP.
100
+ "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
101
+ lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
102
+ ),
103
+ "VLLM_ASCEND_ENABLE_DBO":
104
+ lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
105
+ # Whether to enable the model execute time observe profile. Disable it when
106
+ # running vllm ascend in production environment.
107
+ "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
108
+ lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
109
+ ),
110
+ # MOE_ALL2ALL_BUFFER:
111
+ # 0: default, normal init.
112
+ # 1: enable moe_all2all_buffer.
113
+ "MOE_ALL2ALL_BUFFER":
114
+ lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))),
115
+ # Some models are optimized by vllm ascend. While in some case, e.g. rlhf
116
+ # training, the optimized model may not be suitable. In this case, set this
117
+ # value to False to disable the optimized model.
118
+ "USE_OPTIMIZED_MODEL":
119
+ lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
120
+ # SELECT_GATING_TOPK_SOTFMAX_EXPERTS is the equivalent of select_experts in non-quantized scenarios.
121
+ # In theory, it should have better performance than select_experts.
122
+ # Subsequent versions will remove the SELECT_GATING_TOPK_SOTFMAX_EXPERTS tag and use it as the default mode.
123
+ "SELECT_GATING_TOPK_SOTFMAX_EXPERTS":
124
+ lambda: bool(int(os.getenv("SELECT_GATING_TOPK_SOTFMAX_EXPERTS", '0'))),
125
+ # The tolerance of the kv cache size, if the difference between the
126
+ # actual kv cache size and the cached kv cache size is less than this value,
127
+ # then the cached kv cache size will be used.
128
+ "VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE":
129
+ lambda: int(
130
+ os.getenv("VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE", 64)),
131
+ # Whether to enable the topk optimization. It's disabled by default for experimental support
132
+ # We'll make it enabled by default in the future.
133
+ "VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION":
134
+ lambda: bool(
135
+ int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION", '0'))),
136
+ # Whether to enable top n sigma sampling
137
+ "VLLM_ASCEND_ENABLE_TOP_N_SIGMA":
138
+ lambda: bool(
139
+ int(os.getenv("VLLM_ASCEND_ENABLE_TOP_N_SIGMA", '0'))),
140
+ }
141
+
142
+ # end-env-vars-definition
143
+
144
+
145
+ def __getattr__(name: str):
146
+ # lazy evaluation of environment variables
147
+ if name in env_variables:
148
+ return env_variables[name]()
149
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
150
+
151
+
152
+ def __dir__():
153
+ return list(env_variables.keys())
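All variables above are evaluated lazily: the module-level __getattr__ re-runs the corresponding lambda on every attribute access, so a change to the process environment is picked up immediately. A minimal sketch of how they are consumed (assumes vllm-ascend is installed; the values shown just mirror the defaults documented above):

import os
os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"

import vllm_ascend.envs as envs_ascend

print(envs_ascend.VLLM_ASCEND_ENABLE_DBO)  # True, parsed via bool(int(...))
print(envs_ascend.SOC_VERSION)             # "ASCEND910B1" unless SOC_VERSION is set
print(dir(envs_ascend))                    # lists every supported variable name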
inference/vllm_ascend/models/__init__.py ADDED
@@ -0,0 +1,68 @@
1
+ from vllm import ModelRegistry
2
+
3
+ import vllm_ascend.envs as envs
4
+
5
+
6
+ def register_model():
7
+ from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401
8
+ from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
9
+ from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
10
+ from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401
11
+ from .open_pangu import PanguUltraMoEForCausalLM # noqa: F401
12
+ from .open_pangu import PanguEmbeddedForCausalLM # noqa: F401
13
+ from .qwen2_5_vl import \
14
+ AscendQwen2_5_VLForConditionalGeneration # noqa: F401
15
+ from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
16
+
17
+ ModelRegistry.register_model(
18
+ "DeepSeekMTPModel",
19
+ "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
20
+
21
+ ModelRegistry.register_model(
22
+ "Qwen2VLForConditionalGeneration",
23
+ "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
24
+
25
+ if envs.USE_OPTIMIZED_MODEL:
26
+ ModelRegistry.register_model(
27
+ "Qwen2_5_VLForConditionalGeneration",
28
+ "vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration"
29
+ )
30
+ else:
31
+ ModelRegistry.register_model(
32
+ "Qwen2_5_VLForConditionalGeneration",
33
+ "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
34
+ )
35
+
36
+ if envs.VLLM_ASCEND_ENABLE_DBO:
37
+ ModelRegistry.register_model(
38
+ "DeepseekV2ForCausalLM",
39
+ "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
40
+
41
+ ModelRegistry.register_model(
42
+ "DeepseekV3ForCausalLM",
43
+ "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
44
+
45
+ else:
46
+ ModelRegistry.register_model(
47
+ "DeepseekV2ForCausalLM",
48
+ "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
49
+
50
+ ModelRegistry.register_model(
51
+ "DeepseekV3ForCausalLM",
52
+ "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
53
+
54
+ ModelRegistry.register_model(
55
+ "Qwen3MoeForCausalLM",
56
+ "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
57
+
58
+ ModelRegistry.register_model(
59
+ "PanguProMoEForCausalLM",
60
+ "vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
61
+
62
+ ModelRegistry.register_model(
63
+ "PanguUltraMoEForCausalLM",
64
+ "vllm_ascend.models.open_pangu:PanguUltraMoEForCausalLM")
65
+
66
+ ModelRegistry.register_model(
67
+ "PanguEmbeddedForCausalLM",
68
+ "vllm_ascend.models.open_pangu:PanguEmbeddedForCausalLM")
inference/vllm_ascend/models/open_pangu.py ADDED
@@ -0,0 +1,1127 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # Copyright 2023 The vLLM team.
4
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
7
+ # and OPT implementations in this library. It has been modified from its
8
+ # original forms to accommodate minor architectural differences compared
9
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
24
+ import torch
25
+ import torch_npu
26
+ import vllm.envs as envs
27
+ from torch import nn
28
+ from transformers import PretrainedConfig
29
+ from vllm.compilation.decorators import support_torch_compile
30
+ from vllm.attention import Attention, AttentionMetadata, AttentionType
31
+ from vllm.config import CacheConfig, ModelConfig, VllmConfig
32
+ from vllm.distributed import (get_tensor_model_parallel_rank,
33
+ get_tensor_model_parallel_world_size,
34
+ get_tp_group, split_tensor_along_last_dim,
35
+ tensor_model_parallel_all_gather,
36
+ tensor_model_parallel_all_reduce,
37
+ tensor_model_parallel_reduce_scatter)
38
+ from vllm.distributed.parallel_state import get_dp_group
39
+ from vllm.forward_context import get_forward_context
40
+ from vllm.model_executor.layers.activation import SiluAndMul
41
+ from vllm.model_executor.layers.layernorm import RMSNorm
42
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
43
+ MergedColumnParallelLinear,
44
+ ReplicatedLinear,
45
+ RowParallelLinear,
46
+ UnquantizedLinearMethod,
47
+ QKVParallelLinear)
48
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
49
+ from vllm.model_executor.layers.quantization import QuantizationConfig
50
+ from vllm.model_executor.layers.rotary_embedding import get_rope, _rotate_gptj
51
+ from vllm.model_executor.layers.sampler import get_sampler
52
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
53
+ ParallelLMHead, VocabParallelEmbedding)
54
+ from vllm.model_executor.model_loader.weight_utils import (
55
+ default_weight_loader, maybe_remap_kv_scale_name)
56
+ from vllm.model_executor.models.utils import (
57
+ make_layers, maybe_prefix, extract_layer_index)
58
+ from vllm_ascend.ascend_config import get_ascend_config
59
+ from vllm_ascend.distributed.parallel_state import get_ep_group
60
+ from vllm_ascend.ops.fused_moe import AscendFusedMoE
61
+ from vllm_ascend.quantization.quant_config import AscendLinearMethod
62
+ from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
63
+ from vllm_ascend.utils import dispose_tensor, npu_prefetch, get_fused_moe_state, FusedMoEState
64
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
65
+
66
+
67
+ class OpenPanguMergedReplicatedLinear(ReplicatedLinear):
68
+
69
+ def __init__(
70
+ self,
71
+ input_size: int,
72
+ output_sizes: list[int],
73
+ bias: bool = True,
74
+ quant_config: Optional[QuantizationConfig] = None,
75
+ prefix: str = "",
76
+ ):
77
+ self.output_sizes = output_sizes
78
+ super().__init__(input_size,
79
+ sum(output_sizes),
80
+ bias=bias,
81
+ quant_config=quant_config,
82
+ prefix=prefix)
83
+
84
+ def weight_loader(self, param: torch.nn.Parameter,
85
+ loaded_weight: torch.Tensor, loaded_shard_id: int):
86
+ # GGUF format is not supported yet.
87
+ if getattr(param, "is_gguf_weight", False) or getattr(param, "is_gguf_weight_type", False):
88
+ raise ValueError('GGUF format is not supported yet.')
89
+ if loaded_shard_id >= len(self.output_sizes):
90
+ raise ValueError(f'loaded_shard_id {loaded_shard_id} >= len(self.output_sizes) {len(self.output_sizes)}.')
91
+ shard_offset = sum(self.output_sizes[:loaded_shard_id])
92
+ shard_size = self.output_sizes[loaded_shard_id]
93
+ shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
94
+ if shard.size() != loaded_weight.size():
95
+ raise ValueError(f"Tried to load weights of size {loaded_weight.size()} "
96
+ f"to a parameter shard of id {loaded_shard_id} size {shard.size()}.")
97
+ shard.copy_(loaded_weight)
98
+
99
+
100
+ class OpenPanguRowParallelLinearReplaceAllreduce(RowParallelLinear):
101
+
102
+ def forward(
103
+ self,
104
+ input_,
105
+ is_prefill=True
106
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
107
+ if self.input_is_parallel:
108
+ input_parallel = input_
109
+ else:
110
+ tp_rank = get_tensor_model_parallel_rank()
111
+ splitted_input = split_tensor_along_last_dim(
112
+ input_, num_partitions=self.tp_size)
113
+ input_parallel = splitted_input[tp_rank].contiguous()
114
+
115
+ # Matrix multiply.
116
+ if self.quant_method is None:
117
+ raise ValueError('self.quant_method is None.')
118
+ # Only fuse bias add into GEMM for rank 0 (this ensures that
119
+ # bias will not get added more than once in TP>1 case)
120
+ bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
121
+ output_parallel = self.quant_method.apply(self,
122
+ input_parallel,
123
+ bias=bias_)
124
+ if self.reduce_results and self.tp_size > 1:
125
+ if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
126
+ output = tensor_model_parallel_reduce_scatter(output_parallel,
127
+ dim=0)
128
+ else:
129
+ output = tensor_model_parallel_all_reduce(output_parallel)
130
+ else:
131
+ output = output_parallel
132
+
133
+ output_bias = self.bias if self.skip_bias_add else None
134
+
135
+ if not self.return_bias:
136
+ return output
137
+ return output, output_bias
138
+
139
+
140
+ class OpenPanguRowParallelLinear(RowParallelLinear):
141
+
142
+ def forward(
143
+ self,
144
+ input_,
145
+ is_prefill=True
146
+ ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
147
+ return super().forward(input_)
148
+
149
+
150
+ class OpenPanguRotaryEmbedding(nn.Module):
151
+ def __init__(self,
152
+ head_size: int,
153
+ rotary_dim: int,
154
+ max_position_embeddings: int,
155
+ base: float,
156
+ ):
157
+ super().__init__()
158
+ self.dim = rotary_dim
159
+ self.max_position_embeddings = max_position_embeddings
160
+ self.base = base
161
+ self._set_cos_sin_cache(
162
+ seq_len=max_position_embeddings,
163
+ device='npu',
164
+ dtype=torch.get_default_dtype(),
165
+ )
166
+
167
+ def _set_cos_sin_cache(self,
168
+ seq_len: int,
169
+ device: str,
170
+ dtype: torch.dtype
171
+ ):
172
+ self.max_seq_len = seq_len
173
+ inv_freq = 1.0 / (
174
+ self.base
175
+ ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device='npu') / self.dim)
176
+ )
177
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
178
+ t = torch.arange(seq_len, device='npu', dtype=torch.float32)
179
+ freqs = torch.outer(t, inv_freq)
180
+ emb = torch.cat((freqs, freqs), dim=-1)
181
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
182
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
183
+
184
+ def forward(self,
185
+ positions: torch.Tensor,
186
+ query: torch.Tensor,
187
+ key: torch.Tensor,
188
+ offsets: Optional[torch.Tensor] = None,
189
+ max_seq_len: Optional[int] = None,
190
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
191
+ if max_seq_len is not None and max_seq_len > self.max_seq_len:
192
+ self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
193
+ idx = torch.add(positions, offsets) if offsets is not None else positions
194
+ cos = self.cos_cached[idx]
195
+ sin = self.sin_cached[idx]
196
+ # Adapt: adapt cos and sin shape
197
+ cos = cos.view(-1, 1, cos.shape[-1])
198
+ sin = sin.view(-1, 1, sin.shape[-1])
199
+ # Adapt end.
200
+ query_rot = query * cos + _rotate_gptj(query) * sin
201
+ if key is not None:
202
+ key_rot = key * cos + _rotate_gptj(key) * sin
203
+ return query_rot, key_rot
204
+
205
+
206
+ class OpenPanguSiluAndMul(SiluAndMul):
207
+
208
+ def __init__(self,
209
+ *,
210
+ weight_scale: Optional[Callable[[], torch.Tensor]] = None):
211
+ super().__init__()
212
+ self.weight_scale = weight_scale
213
+
214
+ def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
215
+ torch.Tensor]]):
216
+ if isinstance(x, tuple):
217
+ if self.weight_scale is None:
218
+ raise ValueError('self.weight_scale is None.')
219
+ quantized_x, dynamic_scale = x
220
+ return torch_npu.npu_dequant_swiglu_quant(
221
+ x=quantized_x,
222
+ weight_scale=self.weight_scale(),
223
+ activation_scale=dynamic_scale,
224
+ activate_left=True,
225
+ quant_mode=1)
226
+ else:
227
+ return super().forward_oot(x)
228
+
229
+
230
+ def check_ffn_act_fn(act_fn: str):
231
+ if act_fn != "silu":
232
+ raise ValueError(
233
+ f"Unsupported activation: {act_fn}. Only silu is supported for now.")
234
+
235
+
236
+ class OpenPanguMLP(nn.Module):
237
+
238
+ def __init__(
239
+ self,
240
+ hidden_size: int,
241
+ intermediate_size: int,
242
+ hidden_act: str,
243
+ quant_config: Optional[QuantizationConfig] = None,
244
+ bias: bool = False,
245
+ reduce_results: bool = True,
246
+ force_replicate: bool = False,
247
+ prefix: str = "",
248
+ ) -> None:
249
+ super().__init__()
250
+ if not force_replicate:
251
+ self.gate_up_proj = MergedColumnParallelLinear(
252
+ hidden_size, [intermediate_size] * 2,
253
+ bias=bias,
254
+ quant_config=quant_config,
255
+ prefix=f"{prefix}.gate_up_proj")
256
+ self.down_proj = RowParallelLinear(intermediate_size,
257
+ hidden_size,
258
+ bias=bias,
259
+ quant_config=quant_config,
260
+ reduce_results=reduce_results,
261
+ prefix=f"{prefix}.down_proj")
262
+ else:
263
+ self.gate_up_proj = OpenPanguMergedReplicatedLinear(
264
+ hidden_size, [intermediate_size] * 2,
265
+ bias=bias,
266
+ quant_config=quant_config,
267
+ prefix=f"{prefix}.gate_up_proj")
268
+ self.down_proj = ReplicatedLinear(intermediate_size,
269
+ hidden_size,
270
+ bias=bias,
271
+ quant_config=quant_config,
272
+ prefix=f"{prefix}.down_proj")
273
+
274
+ check_ffn_act_fn(hidden_act)
275
+
276
+ quant_method = self.gate_up_proj.quant_method
277
+ if isinstance(quant_method, UnquantizedLinearMethod):
278
+ self.act_fn = OpenPanguSiluAndMul()
279
+ elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
280
+ quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
281
+ # TODO(sdmyzlp): Currently preserved as before:
282
+ # 1. The only quantization supported for silu is W8A8Dynamic
283
+ # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
284
+ #
285
+ # Maybe one can implement a better and more general configuration
286
+ # scheme, e.g. by somehow passing around the tweaked `quant_config`
287
+ self.act_fn = OpenPanguSiluAndMul(
288
+ # Use lazy binding, for `weight_scale_fp32` is accessible
289
+ # only after `process_weights_after_loading`.
290
+ weight_scale=lambda: self.gate_up_proj.weight_scale_fp32)
291
+ # To be consumed by AscendW8A8DynamicLinearMethod.apply()
292
+ self.gate_up_proj._ascend_quant_config = {
293
+ "output_dtype": torch.int32,
294
+ "pertoken_scale": False,
295
+ "return_scale": True,
296
+ }
297
+ self.down_proj._ascend_quant_config = {
298
+ "output_dtype": torch.bfloat16,
299
+ "pertoken_scale": True,
300
+ "return_scale": False,
301
+ }
302
+ else:
303
+ raise NotImplementedError(
304
+ f"Quantization with [{type(quant_method)}] is NOT supported")
305
+
306
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
307
+ return self.down_proj(self.act_fn(self.gate_up_proj(x)[0]))[0]
308
+
309
+
310
+ class OpenPanguMoE(nn.Module):
311
+
312
+ top_k: int
313
+
314
+ def __init__(
315
+ self,
316
+ config: PretrainedConfig,
317
+ quant_config: Optional[QuantizationConfig] = None,
318
+ prefix: str = "",
319
+ ):
320
+ super().__init__()
321
+ ascend_config = get_ascend_config()
322
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
323
+ self.enable_multistream_moe = \
324
+ ascend_config.torchair_graph_config.enable_multistream_moe
325
+ self.routed_scaling_factor = config.routed_scaling_factor
326
+ check_ffn_act_fn(config.hidden_act)
327
+
328
+ self.gate = ReplicatedLinear(config.hidden_size,
329
+ config.num_routed_experts,
330
+ bias=False,
331
+ quant_config=None,
332
+ prefix=f"{prefix}.gate")
333
+
334
+ self.experts = AscendFusedMoE(
335
+ num_experts=config.num_routed_experts,
336
+ top_k=config.num_experts_per_tok,
337
+ hidden_size=config.hidden_size,
338
+ intermediate_size=config.moe_intermediate_size,
339
+ reduce_results=False,
340
+ renormalize=config.norm_topk_prob,
341
+ quant_config=quant_config,
342
+ use_grouped_topk=True,
343
+ num_expert_group=1,
344
+ topk_group=1,
345
+ prefix=f"{prefix}.experts",
346
+ scoring_func='sigmoid',
347
+ e_score_correction_bias=None)
348
+
349
+ if config.num_shared_experts is not None:
350
+ self.all_reduce_merge = self.experts.all_reduce_merge
351
+ reduce_results = not self.all_reduce_merge
352
+ intermediate_size = (config.moe_intermediate_size * config.num_shared_experts)
353
+ self.shared_experts = OpenPanguMLP(
354
+ hidden_size=config.hidden_size,
355
+ intermediate_size=intermediate_size,
356
+ hidden_act=config.hidden_act,
357
+ quant_config=quant_config,
358
+ reduce_results=reduce_results,
359
+ force_replicate=self.enable_multistream_moe,
360
+ prefix=f"{prefix}.shared_experts",
361
+ )
362
+ else:
363
+ self.shared_experts = None # type: ignore
364
+
365
+ self.tp_size = get_tensor_model_parallel_world_size()
366
+ self.dp_size = get_dp_group().world_size
367
+ self.tp_group = get_tp_group().device_group
368
+ self.tp_rank = get_tp_group().rank_in_group
369
+ self.ep_group = get_ep_group()
370
+
371
+ self.params_dtype = torch.get_default_dtype()
372
+ self.rm_router_logits = self.experts.rm_router_logits
373
+
374
+ self.__class__.top_k = config.num_experts_per_tok
375
+
376
+ def forward(self,
377
+ hidden_states: torch.Tensor,
378
+ attn_metadata: Optional[AttentionMetadata] = None,
379
+ replace_allreduce: bool = False) -> torch.Tensor:
380
+
381
+ if attn_metadata is None:
382
+ attn_metadata = get_forward_context().attn_metadata
383
+ # when profile runs, force experts to load balanced tokens
384
+ # to avoid high memory consumption on a single rank.
385
+ # TODO: need a better flag to indicate whether in profile run or not.
386
+ if attn_metadata is None:
387
+ # for profile run
388
+ is_prefill = True
389
+ fused_moe_state = get_fused_moe_state(self.ep_group.world_size, is_prefill, True)
390
+ enable_force_load_balance = fused_moe_state != FusedMoEState.AllGatherEP
391
+ else:
392
+ is_prefill = attn_metadata.num_prefills > 0
393
+ enable_force_load_balance = False
394
+ if hasattr(attn_metadata, 'with_prefill_across_dp'):
395
+ is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
396
+ fused_moe_state = get_fused_moe_state(self.ep_group.world_size, is_prefill, True)
397
+
398
+ # router_logits: (num_tokens, n_experts)
399
+ router_logits = None
400
+ if not self.rm_router_logits or fused_moe_state == FusedMoEState.All2All:
401
+ router_logits, _ = self.gate(hidden_states.float())
402
+
403
+ routed_hidden_states, shared_hidden_states = self.experts(
404
+ hidden_states=hidden_states,
405
+ router_logits=router_logits,
406
+ is_prefill=is_prefill,
407
+ top_k=self.__class__.top_k,
408
+ enable_force_load_balance=enable_force_load_balance,
409
+ shared_experts=self.shared_experts,
410
+ gate=self.gate,
411
+ replace_allreduce=replace_allreduce)
412
+
413
+ if self.all_reduce_merge and fused_moe_state == FusedMoEState.All2All:
414
+ shared_hidden_states = tensor_model_parallel_all_reduce(shared_hidden_states)
415
+ hidden_states = routed_hidden_states * self.routed_scaling_factor + shared_hidden_states
416
+ if self.all_reduce_merge and fused_moe_state != FusedMoEState.All2All:
417
+ # When all_reduce_merge is enabled, shared_experts skips the all_reduce inside the MLP; a single all_reduce is done after the shared and routed expert outputs are combined.
418
+ hidden_states = tensor_model_parallel_all_reduce(hidden_states)
419
+
420
+ return hidden_states
421
+
422
+
423
+ class OpenPanguMLAAttention(nn.Module):
424
+
425
+ def __init__(
426
+ self,
427
+ config: PretrainedConfig,
428
+ hidden_size: int,
429
+ num_heads: int,
430
+ attention_qk_dim: int,
431
+ attention_qk_rope_dim: int,
432
+ attention_v_dim: int,
433
+ attention_q_lora_dim: Optional[int],
434
+ attention_kv_lora_dim: int,
435
+ rope_theta: float = 10000,
436
+ max_position_embeddings: int = 8192,
437
+ cache_config: Optional[CacheConfig] = None,
438
+ quant_config: Optional[QuantizationConfig] = None,
439
+ prefix: str = "",
440
+ ) -> None:
441
+ super().__init__()
442
+ ascend_config = get_ascend_config()
443
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
444
+ self.enable_multistream_mla = ascend_config.torchair_graph_config.enable_multistream_mla
445
+
446
+ self.hidden_size = hidden_size
447
+ self.num_heads = num_heads
448
+ self.attention_qk_dim = attention_qk_dim
449
+ self.attention_qk_rope_dim = attention_qk_rope_dim
450
+ self.qk_head_dim = attention_qk_dim + attention_qk_rope_dim
451
+ self.attention_v_dim = attention_v_dim
452
+ self.attention_q_lora_dim = attention_q_lora_dim
453
+ self.attention_kv_lora_dim = attention_kv_lora_dim
454
+ self.rope_theta = rope_theta
455
+
456
+ tp_size = get_tensor_model_parallel_world_size()
457
+ if num_heads % tp_size != 0:
458
+ raise ValueError(f'num_heads {num_heads} is not divisible by tp_size {tp_size}.')
459
+ self.num_local_heads = num_heads // tp_size
460
+
461
+ self.scaling = self.qk_head_dim**-0.5
462
+ self.max_position_embeddings = max_position_embeddings
463
+
464
+ self.prefix = prefix
465
+ self.debug_layer_idx = int(self.prefix.split(".")[-2])
466
+
467
+ if self.attention_q_lora_dim is not None:
468
+ self.q_a_proj = ReplicatedLinear(self.hidden_size,
469
+ self.attention_q_lora_dim,
470
+ bias=False,
471
+ quant_config=quant_config,
472
+ prefix=f"{prefix}.q_a_proj")
473
+ self.q_a_layernorm = RMSNorm(self.attention_q_lora_dim, eps=config.rms_norm_eps)
474
+ self.q_b_proj = ColumnParallelLinear(attention_q_lora_dim,
475
+ self.num_heads * self.qk_head_dim,
476
+ bias=False,
477
+ quant_config=quant_config,
478
+ prefix=f"{prefix}.q_b_proj")
479
+ else:
480
+ self.q_proj = ColumnParallelLinear(self.hidden_size,
481
+ self.num_heads * self.qk_head_dim,
482
+ bias=False,
483
+ quant_config=quant_config,
484
+ prefix=f"{prefix}.q_proj")
485
+
486
+ self.kv_a_proj_with_mqa = ReplicatedLinear(
487
+ self.hidden_size,
488
+ self.attention_kv_lora_dim + self.attention_qk_rope_dim,
489
+ bias=False,
490
+ quant_config=quant_config,
491
+ prefix=f"{prefix}.kv_a_proj_with_mqa")
492
+ self.kv_a_layernorm = RMSNorm(self.attention_kv_lora_dim,
493
+ eps=config.rms_norm_eps)
494
+ self.kv_b_proj = ColumnParallelLinear(
495
+ self.attention_kv_lora_dim,
496
+ self.num_heads * (self.attention_qk_dim + self.attention_v_dim),
497
+ bias=False,
498
+ quant_config=quant_config,
499
+ prefix=f"{prefix}.kv_b_proj")
500
+ if (config.num_routed_experts is not None
501
+ and self.debug_layer_idx >= config.num_dense_layers and
502
+ ascend_config.torchair_graph_config.enable_multistream_moe):
503
+ self.o_proj = OpenPanguRowParallelLinearReplaceAllreduce(
504
+ self.num_heads * self.attention_v_dim,
505
+ self.hidden_size,
506
+ bias=False,
507
+ quant_config=quant_config,
508
+ prefix=f"{prefix}.o_proj")
509
+ else:
510
+ self.o_proj = OpenPanguRowParallelLinear(
511
+ self.num_heads * self.attention_v_dim,
512
+ self.hidden_size,
513
+ bias=False,
514
+ quant_config=quant_config,
515
+ prefix=f"{prefix}.o_proj")
516
+
517
+ self.rotary_emb = OpenPanguRotaryEmbedding(attention_qk_rope_dim,
518
+ rotary_dim=attention_qk_rope_dim,
519
+ max_position_embeddings=max_position_embeddings,
520
+ base=rope_theta)
521
+
522
+ self.mla_attn = Attention(
523
+ num_heads=self.num_local_heads,
524
+ head_size=self.attention_kv_lora_dim + self.attention_qk_rope_dim,
525
+ scale=self.scaling,
526
+ num_kv_heads=1,
527
+ cache_config=cache_config,
528
+ quant_config=quant_config,
529
+ prefix=f"{prefix}.attn",
530
+ use_mla=True,
531
+ # MLA Args
532
+ q_lora_rank=self.attention_q_lora_dim,
533
+ kv_lora_rank=self.attention_kv_lora_dim,
534
+ qk_nope_head_dim=self.attention_qk_dim,
535
+ qk_rope_head_dim=self.attention_qk_rope_dim,
536
+ qk_head_dim=self.qk_head_dim,
537
+ v_head_dim=self.attention_v_dim,
538
+ rotary_emb=self.rotary_emb,
539
+ q_proj=self.q_proj if self.attention_q_lora_dim is None else self.q_b_proj,
540
+ kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
541
+ kv_a_layernorm=self.kv_a_layernorm,
542
+ kv_b_proj=self.kv_b_proj,
543
+ o_proj=self.o_proj,
544
+ )
545
+
546
+ def forward(
547
+ self,
548
+ positions: torch.Tensor,
549
+ hidden_states: torch.Tensor,
550
+ kv_cache: Optional[torch.Tensor] = None,
551
+ attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
552
+ enable_multistream_mla = (self.enable_multistream_mla
553
+ and attn_metadata is not None
554
+ and not attn_metadata.with_prefill_across_dp
555
+ and attn_metadata.num_decodes > 0)
556
+ forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
557
+ if self.attention_q_lora_dim is not None:
558
+ npu_prefetch(self.q_a_proj.weight,
559
+ hidden_states,
560
+ enabled=enable_multistream_mla)
561
+ ckq = self.q_a_proj(hidden_states)[0]
562
+ hidden_states_or_q_c = self.q_a_layernorm(ckq)
563
+ forward_kwargs['ckq'] = ckq
564
+ else:
565
+ hidden_states_or_q_c = hidden_states
566
+ if self.torchair_graph_enabled:
567
+ if envs.VLLM_USE_V1:
568
+ output_shape = hidden_states.shape
569
+ output = torch.empty(output_shape,
570
+ dtype=hidden_states_or_q_c.dtype,
571
+ device=hidden_states_or_q_c.device)
572
+ forward_kwargs['output'] = output
573
+
574
+ output = self.mla_attn.impl.forward(self.mla_attn,
575
+ hidden_states_or_q_c,
576
+ hidden_states, None, kv_cache,
577
+ attn_metadata,
578
+ **forward_kwargs)
579
+ if envs.VLLM_USE_V1:
580
+ output = output.view(-1, output_shape[-1])
581
+ return output
582
+ else:
583
+ kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
584
+ [self.attention_kv_lora_dim, self.attention_qk_rope_dim], dim=-1)
585
+ kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
586
+ return self.mla_attn(hidden_states_or_q_c,
587
+ kv_c_normed,
588
+ k_pe,
589
+ output_shape=hidden_states.shape)
590
+
591
+
592
+ class OpenPanguEmbeddedAttention(nn.Module):
593
+
594
+ def __init__(
595
+ self,
596
+ config: PretrainedConfig,
597
+ hidden_size: int,
598
+ num_heads: int,
599
+ num_kv_heads: int,
600
+ rope_theta: float = 10000,
601
+ rope_scaling: Optional[dict[str, Any]] = None,
602
+ max_position_embeddings: int = 8192,
603
+ quant_config: Optional[QuantizationConfig] = None,
604
+ bias: bool = False,
605
+ bias_o_proj: bool = False,
606
+ cache_config: Optional[CacheConfig] = None,
607
+ prefix: str = "",
608
+ attn_type: str = AttentionType.DECODER,
609
+ ) -> None:
610
+ super().__init__()
611
+ layer_idx = extract_layer_index(prefix)
612
+ self.hidden_size = hidden_size
613
+ tp_size = get_tensor_model_parallel_world_size()
614
+ self.total_num_heads = num_heads
615
+ if self.total_num_heads % tp_size != 0:
616
+ raise ValueError(f'total_num_heads {self.total_num_heads} is not divisible by tp_size {tp_size}.')
617
+ self.num_heads = self.total_num_heads // tp_size
618
+ self.total_num_kv_heads = num_kv_heads
619
+ if self.total_num_kv_heads >= tp_size and self.total_num_kv_heads % tp_size != 0:
620
+ # Number of KV heads is greater than TP size, so we partition
621
+ # the KV heads across multiple tensor parallel NPUs.
622
+ raise ValueError(f'Number of KV heads is greater than or equal to TP size, but total_num_kv_heads {self.total_num_kv_heads} '
623
+ f'is not divisible by tp_size {tp_size}.')
624
+ elif self.total_num_kv_heads < tp_size and tp_size % self.total_num_kv_heads != 0:
625
+ # Number of KV heads is less than TP size, so we replicate
626
+ # the KV heads across multiple tensor parallel NPUs.
627
+ raise ValueError(f'Number of KV heads is less than TP size, but tp_size {tp_size} '
628
+ f'is not divisible by total_num_kv_heads {self.total_num_kv_heads}.')
629
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
630
+ # MistralConfig has an optional head_dim introduced by Mistral-Nemo
631
+ head_dim = getattr(config, "head_dim", None)
632
+ if head_dim is None:
633
+ head_dim = self.hidden_size // self.total_num_heads
634
+ self.head_dim = head_dim
635
+ # Phi models introduced a partial_rotary_factor parameter in the config
636
+ self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
637
+ self.q_size = self.num_heads * self.head_dim
638
+ self.kv_size = self.num_kv_heads * self.head_dim
639
+ self.scaling = self.head_dim**-0.5
640
+ self.rope_theta = rope_theta
641
+ self.max_position_embeddings = max_position_embeddings
642
+
643
+ self.qkv_proj = QKVParallelLinear(
644
+ hidden_size=hidden_size,
645
+ head_size=self.head_dim,
646
+ total_num_heads=self.total_num_heads,
647
+ total_num_kv_heads=self.total_num_kv_heads,
648
+ bias=bias,
649
+ quant_config=quant_config,
650
+ prefix=f"{prefix}.qkv_proj",
651
+ )
652
+
653
+ self.o_proj = RowParallelLinear(
654
+ input_size=self.total_num_heads * self.head_dim,
655
+ output_size=hidden_size,
656
+ bias=bias_o_proj,
657
+ quant_config=quant_config,
658
+ prefix=f"{prefix}.o_proj",
659
+ )
660
+
661
+ self._init_rotary_emb(config,
662
+ rope_scaling=rope_scaling,
663
+ quant_config=quant_config)
664
+
665
+ if hasattr(config, "interleaved_sliding_window"):
666
+ interleaved_sliding_window = config.interleaved_sliding_window
667
+ if isinstance(interleaved_sliding_window, int):
668
+ sliding_window = interleaved_sliding_window
669
+ elif isinstance(interleaved_sliding_window, list):
670
+ sw_idx = layer_idx % len(interleaved_sliding_window)
671
+ sliding_window = interleaved_sliding_window[sw_idx]
672
+ else:
673
+ raise ValueError(
674
+ f"{type(interleaved_sliding_window)} is not supported.")
675
+ else:
676
+ sliding_window = None
677
+
678
+ self.attn = Attention(
679
+ self.num_heads,
680
+ self.head_dim,
681
+ self.scaling,
682
+ num_kv_heads=self.num_kv_heads,
683
+ cache_config=cache_config,
684
+ quant_config=quant_config,
685
+ per_layer_sliding_window=sliding_window,
686
+ attn_type=attn_type,
687
+ prefix=f"{prefix}.attn",
688
+ )
689
+
690
+ def forward(
691
+ self,
692
+ positions: torch.Tensor,
693
+ hidden_states: torch.Tensor,
694
+ kv_cache: Optional[torch.Tensor] = None,
695
+ attn_metadata: Optional[AttentionMetadata] = None
696
+ ) -> torch.Tensor:
697
+ qkv, _ = self.qkv_proj(hidden_states)
698
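+ # The fused QKV projection returns a per-rank tensor of width
+ # q_size + 2 * kv_size, split back into query, key and value before the
+ # rotary embedding is applied.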
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
699
+ q, k = self.rotary_emb(positions, q, k)
700
+ attn_output = self.attn(q, k, v)
701
+ output, _ = self.o_proj(attn_output)
702
+ return output
703
+
704
+ def _init_rotary_emb(self, config: PretrainedConfig,
705
+ rope_scaling: Optional[dict[str, Any]],
706
+ quant_config: Optional[QuantizationConfig]) -> None:
707
+ is_neox_style = True
708
+ is_gguf = quant_config and quant_config.get_name() == "gguf"
709
+ if is_gguf and config.model_type == "Pangu":
710
+ is_neox_style = False
711
+
712
+ self.rotary_emb = get_rope(
713
+ self.head_dim,
714
+ rotary_dim=self.head_dim,
715
+ max_position=self.max_position_embeddings,
716
+ base=self.rope_theta,
717
+ rope_scaling=rope_scaling,
718
+ is_neox_style=is_neox_style,
719
+ #partial_rotary_factor=self.partial_rotary_factor,
720
+ )
721
+
722
+
723
+ class OpenPanguDecoderLayer(nn.Module):
724
+
725
+ def __init__(
726
+ self,
727
+ config: PretrainedConfig,
728
+ prefix: str,
729
+ model_config: ModelConfig,
730
+ cache_config: Optional[CacheConfig] = None,
731
+ quant_config: Optional[QuantizationConfig] = None,
732
+ ) -> None:
733
+ super().__init__()
734
+ self.hidden_size = config.hidden_size
735
+ rope_theta = getattr(config, "rope_theta", 10000)
736
+ max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
737
+
738
+ layer_idx = int(prefix.split(sep='.')[-1])
739
+ self.layer_idx = layer_idx
740
+ self.layers = config.num_hidden_layers
741
+ self.tp_size = get_tensor_model_parallel_world_size()
742
+ self.tp_rank = get_tp_group().rank_in_group
743
+ ascend_config = get_ascend_config()
744
+
745
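+ # MLA (multi-head latent attention) is used when the config provides the
+ # latent-attention dimensions checked below; otherwise the layer falls back to
+ # standard grouped-query attention (OpenPanguEmbeddedAttention).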
+ self.use_mla = hasattr(config, 'attention_qk_dim') and hasattr(config, 'attention_qk_rope_dim') \
746
+ and hasattr(config, 'attention_v_dim') and hasattr(config, 'attention_kv_lora_dim')
747
+ if self.use_mla:
748
+ self.self_attn = OpenPanguMLAAttention(
749
+ config=config,
750
+ hidden_size=self.hidden_size,
751
+ num_heads=config.num_attention_heads,
752
+ attention_qk_dim=config.attention_qk_dim,
753
+ attention_qk_rope_dim=config.attention_qk_rope_dim,
754
+ attention_v_dim=config.attention_v_dim,
755
+ attention_q_lora_dim=config.attention_q_lora_dim
756
+ if hasattr(config, "attention_q_lora_dim") else None,
757
+ attention_kv_lora_dim=config.attention_kv_lora_dim,
758
+ rope_theta=rope_theta,
759
+ max_position_embeddings=max_position_embeddings,
760
+ cache_config=cache_config,
761
+ quant_config=quant_config,
762
+ prefix=f"{prefix}.self_attn",
763
+ )
764
+ else:
765
+ attention_bias = getattr(config, "attention_bias", False) or getattr(
766
+ config, "bias", False)
767
+ bias_o_proj = attention_bias
768
+ if hasattr(config, 'qkv_bias'):
769
+ attention_bias = config.qkv_bias
770
+ # By default, PanguEmbedded uses causal attention as it is a decoder-only model.
771
+ # You can override the HF config with `is_causal=False` to enable
772
+ # bidirectional attention, which is used in some embedding models
773
+ if getattr(config, "is_causal", True):
774
+ attn_type = AttentionType.DECODER
775
+ else:
776
+ attn_type = AttentionType.ENCODER_ONLY
777
+ self.self_attn = OpenPanguEmbeddedAttention(
778
+ config=config,
779
+ hidden_size=self.hidden_size,
780
+ num_heads=config.num_attention_heads,
781
+ num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads),
782
+ rope_theta=rope_theta,
783
+ rope_scaling=getattr(config, "rope_scaling", None),
784
+ max_position_embeddings=max_position_embeddings,
785
+ quant_config=quant_config,
786
+ bias=attention_bias,
787
+ bias_o_proj=bias_o_proj,
788
+ cache_config=cache_config,
789
+ prefix=f"{prefix}.self_attn",
790
+ attn_type=attn_type,
791
+ )
792
+
793
+ if getattr(config, 'num_routed_experts', None) is not None and layer_idx >= config.num_dense_layers:
794
+ self.mlp = OpenPanguMoE(
795
+ config=config,
796
+ quant_config=quant_config,
797
+ prefix=f"{prefix}.mlp",
798
+ )
799
+ self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
800
+ and model_config.use_mla and envs.VLLM_USE_V1 and self.tp_size > 1
801
+ else:
802
+ self.mlp = OpenPanguMLP(
803
+ hidden_size=self.hidden_size,
804
+ intermediate_size=config.intermediate_size,
805
+ hidden_act=config.hidden_act,
806
+ quant_config=quant_config,
807
+ bias=getattr(config, "mlp_bias", False),
808
+ prefix=f"{prefix}.mlp",
809
+ )
810
+ self.mla_moe_communication = False
811
+ self.routed_scaling_factor = getattr(config, 'routed_scaling_factor', None)
812
+ self.num_dense_layers = getattr(config, 'num_dense_layers', None)
813
+
814
+ self.input_layernorm = RMSNorm(config.hidden_size,
815
+ eps=config.rms_norm_eps)
816
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
817
+ eps=config.rms_norm_eps)
818
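+ # sandwich_norm adds two extra RMSNorms around the MLP (pre_mlp_layernorm and
+ # post_mlp_layernorm) on top of the usual input / post-attention norms.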
+ if getattr(config, 'sandwich_norm', False):
819
+ self.sandwich_norm = True
820
+ self.pre_mlp_layernorm = RMSNorm(config.hidden_size,
821
+ eps=config.rms_norm_eps)
822
+ self.post_mlp_layernorm = RMSNorm(config.hidden_size,
823
+ eps=config.rms_norm_eps)
824
+ else:
825
+ self.sandwich_norm = False
826
+
827
+ def forward(
828
+ self,
829
+ positions: torch.Tensor,
830
+ hidden_states: torch.Tensor,
831
+ residual: Optional[torch.Tensor],
832
+ kv_cache: Optional[torch.Tensor] = None,
833
+ attn_metadata: Optional[AttentionMetadata] = None,
834
+ replace_allreduce: bool = False,
835
+ ) -> torch.Tensor:
836
+ # Self Attention
837
+ if self.use_mla and attn_metadata is not None and attn_metadata.num_decodes > 0:
838
+ mla_moe_communication = self.mla_moe_communication and replace_allreduce
839
+ else:
840
+ mla_moe_communication = False
841
+ if residual is None:
842
+ residual = hidden_states
843
+ hidden_states = self.input_layernorm(hidden_states)
844
+ else:
845
+ previous_hidden_states, previous_residual = hidden_states, residual
846
+ hidden_states, residual = self.input_layernorm(
847
+ hidden_states, residual)
848
+ # Dispose hidden_states and residual from the previous layer
849
+ # to save npu memory because they're no longer used.
850
+ dispose_tensor(previous_hidden_states)
851
+ dispose_tensor(previous_residual)
852
+ if mla_moe_communication and self.layer_idx > self.num_dense_layers:
853
+ hidden_states = tensor_model_parallel_all_gather(hidden_states,
854
+ dim=0)
855
+
856
+ hidden_states = self.self_attn(
857
+ positions=positions,
858
+ hidden_states=hidden_states,
859
+ kv_cache=kv_cache,
860
+ attn_metadata=attn_metadata,
861
+ )
862
+
863
+ if mla_moe_communication and residual.shape[0] != hidden_states.shape[0]:
864
+ chunk_hidden_states = torch.tensor_split(residual,
865
+ self.tp_size,
866
+ dim=0)
867
+ residual = chunk_hidden_states[self.tp_rank]
868
+
869
+ if self.routed_scaling_factor is not None and hidden_states.dtype == torch.float16:
870
+ # Fix FP16 overflow
871
+ # We scale both hidden_states and residual before
872
+ # rmsnorm, and rmsnorm result would not affect by scale.
873
+ hidden_states *= 1. / self.routed_scaling_factor
874
+ if self.layer_idx == 0:
875
+ # The residual is shared by all layers, we only scale it on
876
+ # first layer.
877
+ residual *= 1. / self.routed_scaling_factor
878
+
879
+ if self.sandwich_norm:
880
+ hidden_states = self.post_attention_layernorm(
881
+ hidden_states)
882
+ hidden_states, residual = self.pre_mlp_layernorm(
883
+ hidden_states, residual)
884
+ else:
885
+ hidden_states, residual = self.post_attention_layernorm(
886
+ hidden_states, residual)
887
+
888
+ # Fully Connected
889
+ if isinstance(self.mlp, OpenPanguMoE):
890
+ hidden_states = self.mlp(hidden_states,
891
+ attn_metadata,
892
+ replace_allreduce=mla_moe_communication)
893
+ else:
894
+ hidden_states = self.mlp(hidden_states)
895
+
896
+ if self.routed_scaling_factor is not None and isinstance(self.mlp, OpenPanguMLP) \
897
+ and hidden_states.dtype == torch.float16:
898
+ hidden_states *= 1. / self.routed_scaling_factor
899
+
900
+ if self.sandwich_norm:
901
+ hidden_states = self.post_mlp_layernorm(hidden_states)
902
+
903
+ if mla_moe_communication and self.layer_idx == self.layers - 1:
904
+ hidden_states = tensor_model_parallel_all_gather(hidden_states,
905
+ dim=0)
906
+ residual = tensor_model_parallel_all_gather(residual, dim=0)
907
+
908
+ return hidden_states, residual
909
+
910
+
911
+ @support_torch_compile
912
+ class OpenPanguModel(nn.Module):
913
+
914
+ fall_back_to_pt_during_load = False
915
+
916
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
917
+ super().__init__()
918
+
919
+ config = vllm_config.model_config.hf_config
920
+ model_config = vllm_config.model_config
921
+ cache_config = vllm_config.cache_config
922
+ quant_config = vllm_config.quant_config
923
+
924
+ self.padding_idx = config.pad_token_id
925
+ self.vocab_size = config.vocab_size
926
+ self.tp_size = get_tensor_model_parallel_world_size()
927
+
928
+ self.embed_tokens = VocabParallelEmbedding(
929
+ config.vocab_size,
930
+ config.hidden_size,
931
+ quant_config=quant_config,
932
+ prefix=f"{prefix}.embed_tokens")
933
+
934
+ self.start_layer, self.end_layer, self.layers = make_layers(
935
+ config.num_hidden_layers,
936
+ lambda prefix: OpenPanguDecoderLayer(
937
+ config,
938
+ prefix,
939
+ model_config=model_config,
940
+ cache_config=cache_config,
941
+ quant_config=quant_config,
942
+ ),
943
+ prefix=f"{prefix}.layers")
944
+
945
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
946
+
947
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
948
+ return self.embed_tokens(input_ids)
949
+
950
+ def forward(
951
+ self,
952
+ input_ids: torch.Tensor,
953
+ positions: torch.Tensor,
954
+ kv_caches: Optional[List[torch.Tensor]] = None,
955
+ attn_metadata: Optional[AttentionMetadata] = None,
956
+ inputs_embeds: Optional[torch.Tensor] = None,
957
+ **kwargs,
958
+ ) -> torch.Tensor:
959
+ if inputs_embeds is not None:
960
+ hidden_states = inputs_embeds
961
+ else:
962
+ hidden_states = self.get_input_embeddings(input_ids)
963
+ residual = None
964
+
965
+ replace_allreduce = hidden_states.shape[0] % self.tp_size == 0
966
+
967
+ for i in range(self.start_layer, self.end_layer):
968
+ layer = self.layers[i]
969
+ hidden_states, residual = layer(
970
+ positions,
971
+ hidden_states,
972
+ residual,
973
+ kv_caches[i -
974
+ self.start_layer] if kv_caches is not None else None,
975
+ attn_metadata,
976
+ replace_allreduce=replace_allreduce)
977
+
978
+ hidden_states, _ = self.norm(hidden_states, residual)
979
+ return hidden_states
980
+
981
+
982
+ class OpenPanguForCausalLM(nn.Module):
983
+ packed_modules_mapping = {
984
+ "gate_up_proj": ["gate_proj", "up_proj"],
985
+ "experts":
986
+ ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
987
+ }
988
+
989
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
990
+ super().__init__()
991
+ config = vllm_config.model_config.hf_config
992
+ quant_config = vllm_config.quant_config
993
+ self.config = config
994
+ self.quant_config = quant_config
995
+ self.model = OpenPanguModel(vllm_config=vllm_config,
996
+ prefix=maybe_prefix(prefix, "model"))
997
+ self.lm_head = ParallelLMHead(config.vocab_size,
998
+ config.hidden_size,
999
+ quant_config=quant_config,
1000
+ prefix=maybe_prefix(prefix, "lm_head"))
1001
+ self.logits_processor = LogitsProcessor(config.vocab_size)
1002
+ self.sampler = get_sampler()
1003
+
1004
+ def load_attn_mlp_weight(self,
1005
+ attn_mlp_replace_mapping: List[Tuple[str, str, int]],
1006
+ params_dict: Dict[str, Any],
1007
+ weight_name: str,
1008
+ loaded_weight: torch.Tensor,
1009
+ loaded_params: set[str]) -> bool:
1010
+ for (param_name, origin_name, shard_id) in attn_mlp_replace_mapping:
1011
+ if origin_name not in weight_name or \
1012
+ (("mlp.experts." in weight_name) and weight_name not in params_dict):
1013
+ continue
1014
+ weight_name = weight_name.replace(origin_name, param_name)
1015
+ if weight_name.endswith(".bias") and weight_name not in params_dict:
1016
+ continue
1017
+ param = params_dict[weight_name]
1018
+ weight_loader = param.weight_loader
1019
+ weight_loader(param, loaded_weight, shard_id)
1020
+ loaded_params.add(weight_name)
1021
+ return True
1022
+ return False
1023
+
1024
+ def load_expert_weight(self,
1025
+ expert_merge_mapping: List[Tuple[str, str, int, str]],
1026
+ params_dict: Dict[str, Any],
1027
+ weight_name: str,
1028
+ loaded_weight: torch.Tensor,
1029
+ loaded_params: set[str]) -> bool:
1030
+ for mapping in expert_merge_mapping:
1031
+ param_name, origin_name, expert_id, shard_id = mapping
1032
+ if origin_name not in weight_name:
1033
+ continue
1034
+ weight_name = weight_name.replace(origin_name, param_name)
1035
+ param = params_dict[weight_name]
1036
+ weight_loader = param.weight_loader
1037
+ weight_loader(param,
1038
+ loaded_weight,
1039
+ weight_name,
1040
+ shard_id=shard_id,
1041
+ expert_id=expert_id,
1042
+ return_success=False)
1043
+ loaded_params.add(weight_name)
1044
+ return True
1045
+ return False
1046
+
1047
+ def load_weights(self, weights: Iterable[tuple[str,
1048
+ torch.Tensor]]) -> set[str]:
1049
+ # (param_name, shard_name, shard_id)
1050
+ attn_mlp_replace_mapping = [
1051
+ (".qkv_proj", ".q_proj", "q"),
1052
+ (".qkv_proj", ".k_proj", "k"),
1053
+ (".qkv_proj", ".v_proj", "v"),
1054
+ (".gate_up_proj", ".gate_proj", 0),
1055
+ (".gate_up_proj", ".up_proj", 1),
1056
+ ]
1057
+ has_experts = hasattr(self.config, 'num_routed_experts')
1058
+ if has_experts:
1059
+ expert_merge_mapping = AscendFusedMoE.make_expert_params_mapping(
1060
+ ckpt_gate_proj_name="gate_proj",
1061
+ ckpt_down_proj_name="down_proj",
1062
+ ckpt_up_proj_name="up_proj",
1063
+ num_experts=self.config.num_routed_experts)
1064
+
1065
+ params_dict = dict(self.named_parameters())
1066
+ loaded_params: set[str] = set()
1067
+ for name, loaded_weight in weights:
1068
+ if "rotary_emb.inv_freq" in name:
1069
+ continue
1070
+ if 'layers' in name: # skip spec decode layers for main model
1071
+ layer_idx = int(name.split('layers.')[-1].split('.')[0])
1072
+ if layer_idx > self.config.num_hidden_layers:
1073
+ continue
1074
+
1075
+ if 'layers' in name and hasattr(self.config, "num_mtp_layers") \
1076
+ and (self.config.num_mtp_layers > 0):
1077
+ layer_idx = int(name.split('layers.')[-1].split('.')[0])
1078
+ mtp_idx = layer_idx - self.config.num_hidden_layers
1079
+ if mtp_idx >= 0 and mtp_idx < self.config.num_mtp_layers:
1080
+ continue # skip spec decode layers for main model
1081
+ if self.load_attn_mlp_weight(attn_mlp_replace_mapping, params_dict, name, loaded_weight, loaded_params):
1082
+ continue
1083
+ elif has_experts and self.load_expert_weight(expert_merge_mapping, params_dict, name, loaded_weight, loaded_params):
1084
+ continue
1085
+ else:
1086
+ if name.endswith(".bias") and name not in params_dict:
1087
+ continue
1088
+ name = maybe_remap_kv_scale_name(name, params_dict)
1089
+ if name is None:
1090
+ continue
1091
+ param = params_dict[name]
1092
+ weight_loader = getattr(param, "weight_loader",
1093
+ default_weight_loader)
1094
+ weight_loader(param, loaded_weight)
1095
+ loaded_params.add(name)
1096
+ if self.config.tie_word_embeddings:
1097
+ self.lm_head.weight = self.model.embed_tokens.weight
1098
+ return loaded_params
1099
+
1100
+ def forward(
1101
+ self,
1102
+ input_ids: torch.Tensor,
1103
+ positions: torch.Tensor,
1104
+ kv_caches: Optional[List[torch.Tensor]] = None,
1105
+ attn_metadata: Optional[AttentionMetadata] = None,
1106
+ inputs_embeds: Optional[torch.Tensor] = None,
1107
+ **kwargs,
1108
+ ) -> torch.Tensor:
1109
+ hidden_states = self.model(input_ids, positions, kv_caches,
1110
+ attn_metadata, inputs_embeds)
1111
+ return hidden_states
1112
+
1113
+ def compute_logits(
1114
+ self,
1115
+ hidden_states: torch.Tensor,
1116
+ sampling_metadata: SamplingMetadata,
1117
+ ) -> Optional[torch.Tensor]:
1118
+ logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
1119
+ return logits
1120
+
1121
+
1122
+ class PanguUltraMoEForCausalLM(OpenPanguForCausalLM):
1123
+ pass
1124
+
1125
+
1126
+ class PanguEmbeddedForCausalLM(OpenPanguForCausalLM):
1127
+ pass
inference/vllm_ascend/ops/fused_moe.py ADDED
@@ -0,0 +1,1530 @@
1
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2
+ # Copyright 2023 The vLLM team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # This file is a part of the vllm-ascend project.
16
+ # Adapted from vllm/tests/kernels/test_moe.py
17
+
18
+ import os
19
+ from typing import Any, Callable, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ import torch_npu
24
+ from torch import nn
25
+ from vllm.config import get_current_vllm_config
26
+ from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
27
+ get_tensor_model_parallel_world_size,
28
+ tensor_model_parallel_all_reduce)
29
+ from vllm.distributed.parallel_state import get_dp_group, get_tp_group
30
+ from vllm.forward_context import get_forward_context
31
+ from vllm.model_executor.layers.fused_moe.config import \
32
+ FusedMoEConfig # isort: skip
33
+ from vllm.model_executor.layers.fused_moe.config import \
34
+ FusedMoEParallelConfig # isort: skip
35
+ from vllm.model_executor.layers.fused_moe.layer import (
36
+ FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
37
+ from vllm.model_executor.layers.quantization.base_config import \
38
+ QuantizationConfig
39
+
40
+ import vllm_ascend.envs as envs_ascend
41
+ from vllm_ascend.ascend_config import get_ascend_config
42
+ from vllm_ascend.distributed.communication_op import \
43
+ data_parallel_reduce_scatter
44
+ from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
45
+ from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
46
+ from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
47
+ get_all_reduce_merge_state, get_fused_moe_state,
48
+ get_rm_router_logits_state, is_310p,
49
+ npu_stream_switch, npu_wait_tensor)
50
+
51
+ MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
52
+ SELECT_GATING_TOPK_SOTFMAX_EXPERTS: bool = envs_ascend.SELECT_GATING_TOPK_SOTFMAX_EXPERTS
53
+
54
+
55
+ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
56
+ max_row_per_ep_rank: int, num_tokens: int,
57
+ top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
58
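+ # Pads the flattened top-k expert ids into a fixed buffer of
+ # ep_size * max_row_per_ep_rank slots, using expert_num as the padding
+ # sentinel, and returns the indices needed to undo the padding afterwards
+ # (used by the all2all-buffer path below).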
+ original_total_elements = num_tokens * top_k
59
+ device = topk_ids.device
60
+ original_dtype = topk_ids.dtype
61
+
62
+ if original_total_elements == 0:
63
+ output_len = ep_size * max_row_per_ep_rank
64
+ topk_ids_pad = torch.full((output_len, ),
65
+ expert_num,
66
+ dtype=original_dtype,
67
+ device=device)
68
+ unpad_indices = torch.full((original_total_elements, ),
69
+ -1,
70
+ dtype=torch.long,
71
+ device=device)
72
+ return topk_ids_pad, unpad_indices
73
+
74
+ experts_per_ep_rank_val = expert_num // ep_size
75
+ if experts_per_ep_rank_val == 0:
76
+ raise ValueError(
77
+ "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
78
+ "Ensure expert_num >= ep_size.")
79
+
80
+ assigned_ep_rank = (topk_ids.float() /
81
+ experts_per_ep_rank_val).to(original_dtype)
82
+ indices_arange = torch.arange(topk_ids.shape[0], device=device)
83
+
84
+ is_new_segment = torch.cat(
85
+ (torch.tensor([True], device=device), assigned_ep_rank[1:]
86
+ != assigned_ep_rank[:-1]))
87
+ temp_start_markers = torch.full_like(indices_arange,
88
+ -1,
89
+ dtype=indices_arange.dtype)
90
+ temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
91
+ start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
92
+ token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
93
+ is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
94
+ cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
95
+ indices_in_rec_cond_list_for_all = cumsum_kept - 1
96
+ unpad_indices = torch.where(
97
+ is_kept_mask, indices_in_rec_cond_list_for_all,
98
+ torch.tensor(-1, device=device, dtype=torch.long))
99
+ output_len = ep_size * max_row_per_ep_rank
100
+ topk_ids_pad = torch.full((output_len, ),
101
+ expert_num,
102
+ dtype=original_dtype,
103
+ device=device)
104
+ if topk_ids.shape[0] > 0:
105
+ all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
106
+ temp_pad_buffer = torch.full((output_len + 1, ),
107
+ expert_num,
108
+ dtype=original_dtype,
109
+ device=device)
110
+ output_len_tensor = torch.tensor(output_len,
111
+ dtype=torch.long,
112
+ device=device)
113
+ scatter_indices = torch.where(is_kept_mask, all_destination_indices,
114
+ output_len_tensor)
115
+ temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
116
+ topk_ids_pad = temp_pad_buffer[:output_len]
117
+ return topk_ids_pad, unpad_indices
118
+
119
+
120
+ def fused_experts_with_mc2(
121
+ hidden_states: torch.Tensor,
122
+ w1: torch.Tensor,
123
+ w2: torch.Tensor,
124
+ topk_weights: torch.Tensor,
125
+ topk_ids: torch.Tensor,
126
+ top_k: int,
127
+ expert_map: torch.Tensor = None,
128
+ moe_all_to_all_group_name: Optional[str] = None,
129
+ shared_experts: Optional[Any] = None
130
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
131
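+ # MC2 path: tokens are dispatched to expert-parallel ranks with
+ # npu_moe_distribute_dispatch, processed by the grouped gate_up/down matmuls,
+ # and gathered back with npu_moe_distribute_combine; shared experts, if given,
+ # are overlapped on a secondary NPU stream.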
+ global_bs = 0
132
+ moe_expert_num = len(expert_map)
133
+ kwargs_mc2 = {
134
+ "x": hidden_states,
135
+ "expert_ids": topk_ids,
136
+ "expert_shard_type": 0,
137
+ "shared_expert_rank_num": 0,
138
+ "moe_expert_num": moe_expert_num,
139
+ "global_bs": global_bs,
140
+ }
141
+
142
+ rank = torch.distributed.get_rank()
143
+
144
+ quant_mode = 0
145
+ ep_group = get_ep_group().device_group
146
+ local_rank = torch.distributed.get_rank(group=ep_group)
147
+ all_to_all_group_size = torch.distributed.get_world_size(ep_group)
148
+
149
+ tp_size = get_etp_group().world_size
150
+ tp_rank = rank % tp_size
151
+
152
+ stage1_kwargs = {
153
+ "scales": None,
154
+ "quant_mode": quant_mode,
155
+ "group_ep": moe_all_to_all_group_name,
156
+ "ep_world_size": all_to_all_group_size,
157
+ "ep_rank_id": local_rank,
158
+ # "group_tp": self.moe_rs_group_name,
159
+ "group_tp": moe_all_to_all_group_name,
160
+ "tp_world_size": tp_size,
161
+ "tp_rank_id": tp_rank,
162
+ }
163
+ kwargs_mc2.update(stage1_kwargs)
164
+
165
+ output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
166
+ expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
167
+ 0:5]
168
+
169
+ if shared_experts is not None:
170
+ with npu_stream_switch("moe_secondary", 0):
171
+ npu_wait_tensor(hidden_states, topk_weights)
172
+ shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
173
+ npu_wait_tensor(shared_gate_up, expand_x)
174
+ shared_act = shared_experts.act_fn(shared_gate_up)
175
+
176
+ w1 = w1.transpose(1, 2)
177
+
178
+ group_list = expert_token_nums.to(torch.int64)
179
+ gate_up_out_list = torch_npu.npu_grouped_matmul(
180
+ x=[expand_x],
181
+ weight=[w1],
182
+ split_item=2,
183
+ # 1 means count mode, to avoid cumulative operation of the group list
184
+ group_list_type=1,
185
+ group_type=0,
186
+ group_list=group_list,
187
+ )
188
+
189
+ # TODO: Remove this in the future.
190
+ gate_up_out = torch.cat(gate_up_out_list, dim=0)
191
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out)
192
+
193
+ w2 = w2.transpose(1, 2)
194
+ down_out_list = torch_npu.npu_grouped_matmul(
195
+ x=[gate_up_out],
196
+ weight=[w2],
197
+ split_item=2,
198
+ group_list_type=1,
199
+ group_type=0,
200
+ group_list=group_list,
201
+ )
202
+
203
+ down_out_list = torch.cat(down_out_list, dim=0)
204
+
205
+ # moeCombine
206
+ kwargs_mc2 = {
207
+ "expand_x": down_out_list,
208
+ "expert_ids": topk_ids,
209
+ "expand_idx": expand_idx,
210
+ "expert_scales": topk_weights.to(torch.float32),
211
+ "expert_shard_type": 0,
212
+ "shared_expert_rank_num": 0,
213
+ "moe_expert_num": moe_expert_num,
214
+ "global_bs": 0,
215
+ }
216
+ tp_recv_counts = output[5]
217
+ stage3_kwargs = {
218
+ "ep_send_counts": ep_recv_counts,
219
+ "group_ep": moe_all_to_all_group_name,
220
+ "ep_world_size": all_to_all_group_size,
221
+ "ep_rank_id": local_rank,
222
+ "tp_send_counts": tp_recv_counts,
223
+ # "group_tp": self.moe_rs_group_name,
224
+ "group_tp": moe_all_to_all_group_name,
225
+ "tp_world_size": tp_size,
226
+ "tp_rank_id": tp_rank,
227
+ }
228
+ kwargs_mc2.update(stage3_kwargs)
229
+
230
+ hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
231
+
232
+ if shared_experts is None:
233
+ return hidden_states
234
+ else:
235
+ with npu_stream_switch("moe_secondary", 0):
236
+ npu_wait_tensor(shared_act, down_out_list)
237
+ shared_hidden_states, _ = shared_experts.down_proj(shared_act)
238
+ return hidden_states, shared_hidden_states
239
+
240
+
241
+ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
242
+ w1: torch.Tensor,
243
+ w2: torch.Tensor,
244
+ group_list: torch.Tensor,
245
+ group_list_type: int = 1) -> torch.Tensor:
246
+ """
247
+ apply MLP: gate_up_proj -> swiglu -> down_proj
248
+
249
+ Args:
250
+ hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
251
+ w1: expert weights1 with shape
252
+ (num_experts, hidden_size, intermediate_size * 2)
253
+ w2: expert weights2 with shape
254
+ (num_experts, intermediate_size, hidden_size)
255
+ group_list: number of tokens for each expert, either cumulative or per-expert
256
+ counts depending on group_list_type, with shape (num_experts,).
257
+ transpose_weight:
258
+ w1: (num_experts, intermediate_size * 2, hidden_size) ->
259
+ (num_experts, hidden_size, intermediate_size * 2)
260
+ w2: (num_experts, hidden_size, intermediate_size) ->
261
+ (num_experts, intermediate_size, hidden_size)
262
+
263
+ Returns:
264
+ hidden_states: output hidden states after MLP.
265
+ """
266
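+ # Note: group_list_type selects how group_list is interpreted by
+ # npu_grouped_matmul: 0 expects cumulative sums, 1 expects per-expert token
+ # counts (the "count mode" used elsewhere in this file).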
+
267
+ assert len(hidden_states_wrapper) == 1
268
+ hidden_states = hidden_states_wrapper.pop()
269
+
270
+ w1 = w1.transpose(1, 2)
271
+ hidden_states = torch_npu.npu_grouped_matmul(
272
+ x=[hidden_states],
273
+ weight=[w1],
274
+ split_item=2,
275
+ group_list_type=group_list_type,
276
+ group_type=0,
277
+ group_list=group_list,
278
+ )
279
+
280
+ hidden_states = torch.cat(hidden_states, dim=0)
281
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
282
+
283
+ w2 = w2.transpose(1, 2)
284
+ hidden_states = torch_npu.npu_grouped_matmul(
285
+ x=[hidden_states],
286
+ weight=[w2],
287
+ split_item=2,
288
+ group_list_type=group_list_type,
289
+ group_type=0,
290
+ group_list=group_list,
291
+ )
292
+
293
+ hidden_states = torch.cat(hidden_states, dim=0)
294
+ return hidden_states
295
+
296
+
297
+ def fused_experts_with_all2all(
298
+ hidden_states: torch.Tensor,
299
+ w1: torch.Tensor,
300
+ w2: torch.Tensor,
301
+ topk_weights: torch.Tensor,
302
+ topk_ids: torch.Tensor,
303
+ top_k: int,
304
+ expert_map: torch.Tensor = None,
305
+ ep_group: GroupCoordinator = None,
306
+ ):
307
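+ # All2all path: tokens are expanded with npu_moe_init_routing, exchanged
+ # across expert-parallel ranks via all_to_all (data plus expert ids), run
+ # through the grouped matmuls, then sent back and recombined with
+ # npu_moe_finalize_routing; without an expert_map the routing stays local.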
+ original_shape = hidden_states.shape
308
+ if len(original_shape) == 3:
309
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
310
+
311
+ num_tokens, _ = hidden_states.shape
312
+ num_experts = w1.shape[0]
313
+ device = hidden_states.device
314
+
315
+ if expert_map is not None:
316
+ global_num_experts = len(expert_map)
317
+ local_num_experts = global_num_experts // ep_group.world_size
318
+ row_idx_len = num_tokens * top_k
319
+ row_idx = (torch.arange(0,
320
+ row_idx_len,
321
+ dtype=torch.int32,
322
+ device=device).view(top_k, -1).permute(
323
+ 1, 0).contiguous())
324
+ hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
325
+ hidden_states,
326
+ row_idx=row_idx,
327
+ expert_idx=topk_ids,
328
+ active_num=num_tokens)
329
+
330
+ global_expert_tokens = torch.bincount(expanded_expert_idx,
331
+ minlength=global_num_experts)
332
+ scatter_sizes = global_expert_tokens.view(ep_group.world_size,
333
+ -1).sum(-1)
334
+
335
+ gather_sizes = torch.empty_like(scatter_sizes)
336
+ dist.all_to_all_single(gather_sizes,
337
+ scatter_sizes,
338
+ group=ep_group.device_group)
339
+ scatter_size_list = scatter_sizes.cpu().tolist()
340
+ gather_size_list = gather_sizes.cpu().tolist()
341
+
342
+ expanded_expert_idx = expanded_expert_idx % local_num_experts
343
+ hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
344
+ scatter_size_list,
345
+ gather_size_list)
346
+ local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
347
+ scatter_size_list,
348
+ gather_size_list)
349
+
350
+ sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
351
+
352
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
353
+ sorted_local_expert_idx, local_num_experts).to(torch.int64)
354
+
355
+ hidden_states = hidden_states[sorted_idx]
356
+ else:
357
+ row_idx_len = num_tokens * top_k
358
+ row_idx = torch.arange(0,
359
+ row_idx_len,
360
+ dtype=torch.int32,
361
+ device=topk_weights.device).view(
362
+ top_k, -1).permute(1, 0).contiguous()
363
+ hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
364
+ hidden_states,
365
+ row_idx=row_idx,
366
+ expert_idx=topk_ids,
367
+ active_num=num_tokens)
368
+
369
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
370
+ expanded_expert_idx, num_experts)
371
+ expert_tokens = expert_tokens.to(torch.int64)
372
+
373
+ w1 = w1.transpose(1, 2)
374
+ gate_up_out_list = torch_npu.npu_grouped_matmul(
375
+ x=[hidden_states],
376
+ weight=[w1],
377
+ split_item=2,
378
+ group_list_type=0,
379
+ group_type=0,
380
+ group_list=expert_tokens,
381
+ )
382
+
383
+ # TODO: Remove this in the future.
384
+ hidden_states = torch.cat(gate_up_out_list, dim=0)
385
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
386
+
387
+ w2 = w2.transpose(1, 2)
388
+ down_out_list = torch_npu.npu_grouped_matmul(
389
+ x=[hidden_states],
390
+ weight=[w2],
391
+ split_item=2,
392
+ group_list_type=0,
393
+ group_type=0,
394
+ group_list=expert_tokens,
395
+ )
396
+
397
+ hidden_states = torch.cat(down_out_list, dim=0)
398
+
399
+ if expert_map is not None:
400
+ resorted_idx = torch.argsort(sorted_idx)
401
+ hidden_states = hidden_states[resorted_idx]
402
+ hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
403
+ gather_size_list,
404
+ scatter_size_list)
405
+
406
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
407
+ hidden_states,
408
+ skip1=None,
409
+ skip2=None,
410
+ bias=None,
411
+ scales=topk_weights,
412
+ expanded_src_to_dst_row=expanded_row_idx,
413
+ export_for_source_row=topk_ids,
414
+ )
415
+ else:
416
+ # TODO: Reorder device memory 2 times here, replace the current
417
+ # implementation here when suitable operators become available.
418
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
419
+ hidden_states,
420
+ skip1=None,
421
+ skip2=None,
422
+ bias=None,
423
+ scales=topk_weights,
424
+ expanded_src_to_dst_row=expanded_row_idx,
425
+ export_for_source_row=topk_ids,
426
+ )
427
+ if len(original_shape) == 3:
428
+ final_hidden_states = final_hidden_states.view(original_shape)
429
+ return final_hidden_states
430
+
431
+
432
+ # currently expert parallelism implemented with all2all
433
+ # is under-optimized.
434
+ def fused_experts_with_all2all_buffer(
435
+ hidden_states: torch.Tensor,
436
+ w1: torch.Tensor,
437
+ w2: torch.Tensor,
438
+ topk_weights: torch.Tensor,
439
+ topk_ids: torch.Tensor,
440
+ top_k: int,
441
+ max_model_len: int,
442
+ global_batch_size: int,
443
+ expert_map: torch.Tensor = None,
444
+ ep_group: GroupCoordinator = None,
445
+ ):
446
+ original_shape = hidden_states.shape
447
+ if len(original_shape) == 3:
448
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
449
+
450
+ num_tokens, _ = hidden_states.shape
451
+ device = hidden_states.device
452
+
453
+ global_num_experts = len(expert_map)
454
+ local_num_experts = global_num_experts // ep_group.world_size
455
+ row_idx_len = num_tokens * top_k
456
+ row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
457
+ device=device).view(top_k,
458
+ -1).permute(1, 0).contiguous())
459
+ hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
460
+ hidden_states,
461
+ row_idx=row_idx,
462
+ expert_idx=topk_ids,
463
+ active_num=num_tokens)
464
+
465
+ max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
466
+ max_model_len // ep_group.world_size +
467
+ 1) * top_k * 2
468
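+ # Rough per-rank capacity bound: ceil(global_batch_size / ep_size) *
+ # max_model_len / ep_size + 1 rows, scaled by top_k with a 2x margin
+ # (-(-a // b) is ceiling division).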
+ expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
469
+ expanded_expert_idx, global_num_experts, ep_group.world_size,
470
+ max_row_per_ep_rank, num_tokens, top_k)
471
+ hidden_states_pad_idx = torch.zeros(
472
+ expert_idx_buffer_scatter.shape,
473
+ dtype=expert_idx_buffer_scatter.dtype,
474
+ device=expert_idx_buffer_scatter.device)
475
+ non_pad_len = torch.sum((expert_idx_buffer_scatter
476
+ != global_num_experts).to(torch.int32))
477
+ hidden_states_pad_idx[expert_idx_buffer_scatter !=
478
+ global_num_experts] = torch.arange(
479
+ non_pad_len,
480
+ dtype=expert_idx_buffer_scatter.dtype,
481
+ device=hidden_states.device)
482
+
483
+ hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
484
+ expert_idx_buffer_gather = torch.empty_like(
485
+ expert_idx_buffer_scatter,
486
+ dtype=expert_idx_buffer_scatter.dtype,
487
+ device=expert_idx_buffer_scatter.device)
488
+ hidden_states_buffer_gather = torch.empty_like(
489
+ hidden_states_buffer_scatter,
490
+ dtype=hidden_states_buffer_scatter.dtype,
491
+ device=hidden_states_buffer_scatter.device)
492
+ dist.all_to_all_single(expert_idx_buffer_gather,
493
+ expert_idx_buffer_scatter,
494
+ group=ep_group.device_group)
495
+ dist.all_to_all_single(hidden_states_buffer_gather,
496
+ hidden_states_buffer_scatter,
497
+ group=ep_group.device_group)
498
+ mask = expert_idx_buffer_gather != global_num_experts
499
+ local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
500
+ global_num_experts // ep_group.world_size)
501
+ hidden_states = hidden_states_buffer_gather[mask]
502
+ idx_type = local_expert_idx.dtype
503
+ sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
504
+ sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)
505
+
506
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
507
+ sorted_local_expert_idx, local_num_experts).to(torch.int64)
508
+ hidden_states = hidden_states[sorted_idx]
509
+ group_list_type = 0
510
+
511
+ hidden_states_wrapper = [hidden_states]
512
+ del hidden_states
513
+
514
+ hidden_states = apply_mlp(hidden_states_wrapper,
515
+ w1,
516
+ w2,
517
+ expert_tokens,
518
+ group_list_type=group_list_type)
519
+
520
+ resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
521
+ hidden_states = hidden_states[resorted_idx]
522
+ hidden_states_scatter = torch.zeros(
523
+ (mask.shape[0], hidden_states.shape[1]),
524
+ dtype=hidden_states.dtype,
525
+ device=hidden_states.device)
526
+ hidden_states_scatter[mask] = hidden_states
527
+ hidden_states_gatter = torch.empty_like(
528
+ hidden_states_scatter,
529
+ dtype=hidden_states_scatter.dtype,
530
+ device=hidden_states_scatter.device)
531
+ dist.all_to_all_single(hidden_states_gatter,
532
+ hidden_states_scatter,
533
+ group=ep_group.device_group)
534
+ hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter !=
535
+ global_num_experts]
536
+ if hidden_states_gatter.shape[0] != row_idx_len:
537
+ hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
538
+ dtype=hidden_states.dtype,
539
+ device=hidden_states.device)
540
+ hidden_states[unpad_indices != -1] = hidden_states_gatter
541
+ else:
542
+ # TODO: Reorder device memory 2 times here, replace the current
+ # implementation here when suitable operators become available.
543
+ hidden_states = hidden_states_gatter
544
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
545
+ hidden_states,
546
+ skip1=None,
547
+ skip2=None,
548
+ bias=None,
549
+ scales=topk_weights,
550
+ expanded_src_to_dst_row=expanded_row_idx,
551
+ export_for_source_row=topk_ids,
552
+ )
553
+
554
+ if len(original_shape) == 3:
555
+ final_hidden_states = final_hidden_states.view(original_shape)
556
+ return final_hidden_states
557
+
558
+
559
+ def fused_experts_moge(
560
+ hidden_states: torch.Tensor,
561
+ w1: torch.Tensor,
562
+ w2: torch.Tensor,
563
+ topk_weights: torch.Tensor,
564
+ topk_ids: torch.Tensor,
565
+ top_k: int,
566
+ global_num_experts: int,
567
+ expert_map: torch.Tensor = None,
568
+ apply_router_weight_on_input: bool = False,
569
+ ) -> torch.Tensor:
570
+ """
571
+
572
+ Args:
573
+ hidden_states: Hidden states of shape (num_tokens, hidden_size).
574
+ w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
575
+ w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
576
+ topk_weights: Routing weights of shape (num_tokens, top_k).
577
+ topk_ids: Selected expert IDs of shape (num_tokens, top_k).
578
+ top_k: Number of experts to select.
579
+ expert_map: Expert mapping of shape (num_experts,).
580
+
581
+ Returns:
582
+ hidden_states: Hidden states after routing.
583
+ """
584
+ ep_size = get_ep_group().world_size
585
+ local_num_experts = global_num_experts // ep_size
586
+ local_num_group = top_k // ep_size
587
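+ # Assumes global_num_experts and top_k are multiples of the EP world size;
+ # each rank only processes its local slice of experts and routing groups.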
+
588
+ if apply_router_weight_on_input:
589
+ assert (topk_weights.dim() == 2
590
+ ), "`topk_weights` should be in shape (num_tokens, topk)"
591
+ _, topk = topk_weights.shape
592
+ assert (
593
+ topk == 1
594
+ ), "Only support topk=1 when `apply_router_weight_on_input` is True"
595
+ hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
596
+
597
+ bsz, _ = hidden_states.shape
598
+ flatten_topk_ids = topk_ids.view(-1)
599
+ sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
600
+ sorted_topk_ids = sorted_topk_ids.to(torch.int32)
601
+ sorted_hidden_states = hidden_states.index_select(
602
+ 0, sorted_topk_ids // local_num_group)
603
+
604
+ experts_id = torch.arange(0,
605
+ local_num_experts,
606
+ dtype=topk_ids.dtype,
607
+ device=topk_ids.device)
608
+ num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
609
+ torch.float32).sum(0)
610
+ topk_scales = topk_weights.view(-1).index_select(
611
+ 0, sorted_topk_ids).unsqueeze(-1)
612
+ group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
613
+
614
+ w1 = w1.transpose(1, 2)
615
+ gate_up_out = torch_npu.npu_grouped_matmul(
616
+ x=[sorted_hidden_states],
617
+ weight=[w1],
618
+ split_item=2,
619
+ group_list_type=0,
620
+ group_type=0,
621
+ group_list=group_list,
622
+ )[0]
623
+
624
+ if is_310p():
625
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
626
+ torch.float16)
627
+ else:
628
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out)
629
+ gate_up_out *= topk_scales
630
+
631
+ w2 = w2.transpose(1, 2)
632
+ down_out_list = torch_npu.npu_grouped_matmul(
633
+ x=[gate_up_out],
634
+ weight=[w2],
635
+ split_item=2,
636
+ group_list_type=0,
637
+ group_type=0,
638
+ group_list=group_list,
639
+ )[0]
640
+
641
+ unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
642
+ unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
643
+ final_hidden_states = unsorted_hidden_states.reshape(
644
+ bsz, top_k // ep_size, -1).sum(1)
645
+
646
+ return final_hidden_states
647
+
648
+
649
+ def fused_experts(
650
+ hidden_states: torch.Tensor,
651
+ w1: torch.Tensor,
652
+ w2: torch.Tensor,
653
+ topk_weights: torch.Tensor,
654
+ topk_ids: torch.Tensor,
655
+ top_k: int,
656
+ expert_map: torch.Tensor = None,
657
+ apply_router_weight_on_input: bool = False,
658
+ max_num_tokens: Optional[int] = None,
659
+ ) -> torch.Tensor:
660
+ """
661
+ Fused experts with top-k routing.
662
+
663
+ Args:
664
+ hidden_states: Hidden states of shape (num_tokens, hidden_size).
665
+ w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
666
+ w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
667
+ topk_weights: Routing weights of shape (num_tokens, top_k).
668
+ topk_ids: Selected expert IDs of shape (num_tokens, top_k).
669
+ top_k: Number of experts to select.
670
+ expert_map: Expert mapping of shape (num_experts,).
671
+
672
+ Returns:
673
+ hidden_states: Hidden states after routing.
674
+ """
675
+ """
676
+ # Check constraints.
677
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
678
+ assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
679
+ assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
680
+ assert w1.is_contiguous(), "Expert weights1 must be contiguous"
681
+ assert w2.is_contiguous(), "Expert weights2 must be contiguous"
682
+ """
683
+ # if torch.distributed.get_rank() == 0:
684
+ # print(w1.shape)
685
+ # print(hidden_states.shape)
686
+
687
+ original_shape = hidden_states.shape
688
+ # assert len(original_shape) == 2
689
+
690
+ num_tokens = hidden_states.shape[:-1].numel()
691
+ num_experts = w1.shape[0]
692
+ dtype = hidden_states.dtype
693
+ device = hidden_states.device
694
+ # assert dtype in [torch.float32, torch.float16, torch.bfloat16
695
+ # ], "Only float32, float16, and bfloat16 are supported"
696
+
697
+ if apply_router_weight_on_input:
698
+ assert (topk_weights.dim() == 2
699
+ ), "`topk_weights` should be in shape (num_tokens, topk)"
700
+ _, topk = topk_weights.shape
701
+ assert (
702
+ topk == 1
703
+ ), "Only support topk=1 when `apply_router_weight_on_input` is True"
704
+ hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
705
+
706
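+ # expert_map maps global expert ids to local ids on this rank (-1 for experts
+ # owned by other ranks); the branch below drops non-local pairs and sorts the
+ # remaining tokens by local expert id before the grouped matmuls.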
+ if expert_map is not None:
707
+ # Generate token indices and flatten
708
+ token_indices = (torch.arange(num_tokens,
709
+ device=device,
710
+ dtype=torch.int64).unsqueeze(1).expand(
711
+ -1, top_k).reshape(-1))
712
+
713
+ # Flatten token-to-expert mappings and map to local experts
714
+ weights_flat = topk_weights.view(-1)
715
+ experts_flat = topk_ids.view(-1)
716
+ local_experts_flat = expert_map[experts_flat]
717
+
718
+ # Filter valid token-expert pairs
719
+ mask = local_experts_flat != -1
720
+ filtered_weights = torch.where(
721
+ mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
722
+ filtered_experts = torch.where(
723
+ mask, local_experts_flat,
724
+ torch.full_like(local_experts_flat,
725
+ num_experts)).to(topk_ids.dtype)
726
+
727
+ # Sort by local expert IDs
728
+ sort_indices = torch.argsort(filtered_experts.view(torch.float32))
729
+ sorted_token_indices = token_indices[sort_indices]
730
+ sorted_weights = filtered_weights[sort_indices]
731
+
732
+ # Compute token counts with minlength of num_experts
733
+ # This is equivalent to but faster than:
734
+ # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
735
+ token_counts = torch.zeros(num_experts + 1,
736
+ device=device,
737
+ dtype=torch.int64)
738
+ ones = torch.ones_like(filtered_experts, dtype=torch.int64)
739
+ token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
740
+ token_counts = token_counts[:num_experts]
741
+ expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)
742
+
743
+ # Rearrange hidden_states
744
+ sorted_hidden_states = hidden_states[sorted_token_indices]
745
+ else:
746
+ row_idx_len = num_tokens * top_k
747
+ row_idx = (torch.arange(0,
748
+ row_idx_len,
749
+ dtype=torch.int32,
750
+ device=device).view(top_k, -1).permute(
751
+ 1, 0).contiguous())
752
+ active_num = max_num_tokens if max_num_tokens is not None else num_tokens
753
+ sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
754
+ hidden_states,
755
+ row_idx=row_idx,
756
+ expert_idx=topk_ids,
757
+ active_num=active_num)
758
+
759
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
760
+ expanded_expert_idx, num_experts)
761
+ expert_tokens = expert_tokens.to(torch.int64)
762
+
763
+ w1 = w1.transpose(1, 2)
764
+ gate_up_out_list = torch_npu.npu_grouped_matmul(
765
+ x=[sorted_hidden_states],
766
+ weight=[w1],
767
+ split_item=2,
768
+ group_list_type=0,
769
+ group_type=0,
770
+ group_list=expert_tokens,
771
+ )
772
+
773
+ # TODO: Remove this in the future.
774
+ gate_up_out = torch.cat(gate_up_out_list, dim=0)
775
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out)
776
+
777
+ w2 = w2.transpose(1, 2)
778
+ down_out_list = torch_npu.npu_grouped_matmul(
779
+ x=[gate_up_out],
780
+ weight=[w2],
781
+ split_item=2,
782
+ group_list_type=0,
783
+ group_type=0,
784
+ group_list=expert_tokens,
785
+ )
786
+
787
+ down_out_list = torch.cat(down_out_list, dim=0)
788
+
789
+ if expert_map is not None:
790
+ weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
791
+
792
+ final_hidden_states = torch.zeros(*original_shape,
793
+ device=hidden_states.device,
794
+ dtype=dtype)
795
+
796
+ # TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
797
+ # This created multiple NaN and index_add_ will mix them up which harms accuracy
798
+ # remove this mask and filter after it being fixed
799
+ num_valid_tokens = mask.sum()
800
+ valid_token_mask = torch.arange(
801
+ 0, sorted_token_indices.shape[0],
802
+ device=device).unsqueeze(1) < num_valid_tokens
803
+ valid_output = torch.where(
804
+ valid_token_mask, weighted_down_out,
805
+ torch.zeros_like(weighted_down_out)).to(dtype)
806
+ final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
807
+ else:
808
+ scales = torch.ones_like(
809
+ topk_weights) if apply_router_weight_on_input else topk_weights
810
+ # TODO: Reorder device memory 2 times here, replace the current
811
+ # implementation here when suitable operators become available.
812
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
813
+ down_out_list,
814
+ skip1=None,
815
+ skip2=None,
816
+ bias=None,
817
+ scales=scales,
818
+ expanded_src_to_dst_row=expanded_row_idx,
819
+ export_for_source_row=topk_ids,
820
+ )
821
+
822
+ return final_hidden_states
823
+
824
+
825
+ def fused_experts_allgather_ep(
826
+ hidden_states: torch.Tensor,
827
+ w1: torch.Tensor,
828
+ w2: torch.Tensor,
829
+ topk_weights: torch.Tensor,
830
+ topk_ids: torch.Tensor,
831
+ is_prefill: bool
832
+ ):
833
+ local_rank = torch.distributed.get_rank(group=get_ep_group().device_group)
834
+ num_experts_per_ep = w1.shape[0]
835
+ local_expert_indices_offset = local_rank * num_experts_per_ep
836
+ global_local_mask = (topk_ids >= local_expert_indices_offset) & \
837
+ (topk_ids <= local_expert_indices_offset + num_experts_per_ep - 1)
838
+ non_global_local_mask = (~global_local_mask).to(torch.int32)
839
+ global_local_mask = global_local_mask.to(torch.int32)
840
+ row_idx = torch.arange(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32).view(
841
+ -1, topk_ids.shape[0]).transpose(0, 1).contiguous()
842
+
843
+ topk_ids -= local_expert_indices_offset
844
+ local_topk_ids_mask_with_max = topk_ids * global_local_mask + non_global_local_mask * num_experts_per_ep
845
+ sorted_tokens, expanded_src_to_dst_row, expanded_expert_idx = torch_npu.npu_moe_init_routing(
846
+ x=hidden_states,
847
+ row_idx=row_idx,
848
+ expert_idx=local_topk_ids_mask_with_max,
849
+ active_num=topk_ids.shape[0]*topk_ids.shape[1]
850
+ )
851
+ if expanded_expert_idx.shape[0] > 8192:
852
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(expanded_expert_idx, num_experts_per_ep + 1)
853
+ expert_tokens = expert_tokens[:-1]
854
+ else:
855
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(expanded_expert_idx, num_experts_per_ep)
856
+ expert_tokens = expert_tokens.to(torch.int64)
857
+
858
+ w1 = w1.transpose(1, 2)
859
+ gate_up_out = torch_npu.npu_grouped_matmul(
860
+ x=[sorted_tokens],
861
+ weight=[w1],
862
+ group_list=expert_tokens,
863
+ split_item=3,
864
+ group_type=0
865
+ )[0]
866
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out)
867
+
868
+ w2 = w2.transpose(1, 2)
869
+ down_out = torch_npu.npu_grouped_matmul(
870
+ x=[gate_up_out],
871
+ weight=[w2],
872
+ group_list=expert_tokens,
873
+ split_item=3,
874
+ group_type=0
875
+ )[0]
876
+
877
+ if is_prefill:
878
+ down_out[expert_tokens[-1]:] = 0
879
+ else:
880
+ sorted_tokens_mask = expanded_expert_idx != num_experts_per_ep
881
+ down_out *= sorted_tokens_mask.unsqueeze(1)
882
+
883
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
884
+ expanded_permuted_rows=down_out,
885
+ skip1=None,
886
+ skip2=None,
887
+ bias=None,
888
+ scales=topk_weights.to(down_out.dtype),
889
+ expanded_src_to_dst_row=expanded_src_to_dst_row,
890
+ export_for_source_row=topk_ids
891
+ )
892
+ return final_hidden_states
893
+
894
+
895
+ def select_gating_top_k_softmax_experts(
896
+ hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int,
897
+ renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
898
+ """
899
+ Select top-k experts based on router logits.
900
+ Only supports float16, bfloat16, and float32.
901
+
902
+ Args:
903
+ hidden_states: Hidden states of shape (num_tokens, hidden_size).
904
+ router_logits: Router logits of shape (num_tokens, num_experts).
905
+ top_k: Number of experts to select.
906
+ renormalize: Whether to renormalize the routing weights.
907
+
908
+ Returns:
909
+ topk_weights: Routing weights of shape (num_tokens, top_k).
910
+ topk_ids: Selected expert IDs of shape (num_tokens, top_k).
911
+
912
+ Raises:
913
+ ValueError: If an unsupported scoring function is provided.
914
+ """
915
+ topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
916
+ router_logits, None, k=top_k)
917
+
918
+ # # Required by npu_moe_init_routing
919
+ # topk_weights = topk_weights.to(hidden_states.dtype)
920
+ # topk_ids = topk_ids.to(torch.int32)
921
+
922
+ if renormalize:
923
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
924
+
925
+ return topk_weights, topk_ids
926
+
927
+
928
+ def native_grouped_topk(
929
+ topk_weights: torch.Tensor,
930
+ num_expert_group: Optional[int],
931
+ topk_group: Optional[int],
932
+ ):
933
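+ # Pure-PyTorch grouped top-k: each expert group is scored by its best expert,
+ # the topk_group best groups are kept, and the scores of experts in the
+ # remaining groups are masked to zero.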
+ topk_group = 0 if topk_group is None else topk_group
934
+ num_expert_group = 0 if num_expert_group is None else num_expert_group
935
+
936
+ num_token = topk_weights.shape[0]
937
+ grouped_weights = topk_weights.view(num_token, num_expert_group,
938
+ -1).max(dim=-1).values
939
+ topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
940
+ k=topk_group,
941
+ dim=-1,
942
+ sorted=False)[1]
943
+ topk_group_mask = torch.zeros_like(grouped_weights)
944
+ topk_group_mask.scatter_(1, topk_group_indices, 1)
945
+ topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
946
+ num_token, num_expert_group,
947
+ topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
948
+ topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
949
+
950
+ return topk_weights
951
+
952
+
953
+ def select_experts(
954
+ hidden_states: torch.Tensor,
955
+ router_logits: torch.Tensor,
956
+ top_k: int,
957
+ use_grouped_topk: bool,
958
+ renormalize: bool,
959
+ topk_group: Optional[int] = None,
960
+ num_expert_group: Optional[int] = None,
961
+ custom_routing_function: Optional[Callable] = None,
962
+ scoring_func: str = "softmax",
963
+ e_score_correction_bias: Optional[torch.Tensor] = None,
964
+ global_num_experts: Optional[torch.Tensor] = None
965
+ ) -> tuple[torch.Tensor, torch.Tensor]:
966
+ """
967
+ Select top-k experts based on router logits.
968
+
969
+ Args:
970
+ hidden_states: Hidden states of shape (num_tokens, hidden_size).
971
+ router_logits: Router logits of shape (num_tokens, num_experts).
972
+ top_k: Number of experts to select.
973
+ use_grouped_topk: Whether to group experts before selecting top-k.
974
+ renormalize: Whether to renormalize the routing weights.
975
+ topk_group: Number of expert groups to select from.
976
+ num_expert_group: Number of experts in each group.
977
+ custom_routing_function: Custom routing function.
978
+ scoring_func: Scoring function to use.
979
+ e_score_correction_bias: Correction bias to apply to expert scores.
980
+
981
+ Returns:
982
+ topk_weights: Routing weights of shape (num_tokens, top_k).
983
+ topk_ids: Selected expert IDs of shape (num_tokens, top_k).
984
+
985
+ Raises:
986
+ ValueError: If an unsupported scoring function is provided.
987
+ """
988
+
989
+ if scoring_func == "softmax":
990
+ # NOTE: vLLM use dtype=torch.float here
991
+ topk_weights = router_logits.softmax(dim=-1)
992
+ elif scoring_func == "sigmoid":
993
+ topk_weights = router_logits.sigmoid()
994
+ else:
995
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
996
+
997
+ if use_grouped_topk:
998
+ assert topk_group is not None
999
+ assert num_expert_group is not None
1000
+
1001
+ if e_score_correction_bias is not None:
1002
+ # Store original scores before applying correction bias. We use biased
1003
+ # scores for expert selection but original scores for routing weights
1004
+ original_weights = topk_weights
1005
+ topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
1006
+
1007
+ # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
1008
+ # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
1009
+ topk_weights = native_grouped_topk(topk_weights, num_expert_group,
1010
+ topk_group)
1011
+ # TODO bfloat16 is not supported in torch.topk with ge graph.
1012
+ if e_score_correction_bias is not None:
1013
+ topk_ids = torch.topk(topk_weights.to(torch.float32),
1014
+ k=top_k,
1015
+ dim=-1,
1016
+ sorted=False)[1]
1017
+ # Use original unbiased scores for the routing weights
1018
+ topk_weights = original_weights.gather(1, topk_ids)
1019
+ else:
1020
+ topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
1021
+ k=top_k,
1022
+ dim=-1,
1023
+ sorted=False)
1024
+ elif custom_routing_function is None:
1025
+ topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
1026
+ else:
1027
+ topk_weights, topk_ids = custom_routing_function(
1028
+ hidden_states=hidden_states,
1029
+ gating_output=router_logits,
1030
+ topk=top_k,
1031
+ renormalize=renormalize,
1032
+ global_num_experts=global_num_experts)
1033
+ # Required by npu_moe_init_routing
1034
+ topk_ids = topk_ids.to(torch.int32)
1035
+ return topk_weights, topk_ids
1036
+
1037
+ # Required by npu_moe_init_routing
1038
+ topk_ids = topk_ids.to(torch.int32)
1039
+
1040
+ if renormalize:
1041
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
1042
+
1043
+ return topk_weights, topk_ids
1044
+
1045
+
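
As a sanity check of what the non-grouped path of select_experts computes, the softmax scoring, top-k selection, and renormalization can be reproduced with stock PyTorch on CPU. This is a hedged, illustrative sketch and not part of the committed file.

import torch

# 4 tokens routed over 8 experts, top-2 per token, softmax scoring, renormalized.
router_logits = torch.randn(4, 8)
scores = router_logits.softmax(dim=-1)
topk_weights, topk_ids = scores.topk(2, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_ids = topk_ids.to(torch.int32)          # int32 ids are what npu_moe_init_routing expects
print(topk_weights.shape, topk_ids.shape)    # torch.Size([4, 2]) torch.Size([4, 2])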
1046
+ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
1047
+
1048
+ def __init__(self, moe: FusedMoEConfig = None):
1049
+
1050
+ super().__init__(moe=moe)
1051
+ vllm_config = get_current_vllm_config()
1052
+
1053
+ self.ep_group = get_ep_group()
1054
+ self.ep_size = self.ep_group.world_size
1055
+ self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
1056
+ self.local_batch_size = self.global_batch_size // self.ep_size
1057
+ self.max_model_len = vllm_config.model_config.max_model_len
1058
+
1059
+ ascend_config = get_ascend_config()
1060
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
1061
+
1062
+ try:
1063
+ device_group = self.ep_group.device_group
1064
+ # TODO: Try local_rank = ep_group.rank_in_group
1065
+ local_rank = torch.distributed.get_rank(group=device_group)
1066
+ backend = device_group._get_backend(torch.device("npu"))
1067
+ self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
1068
+ local_rank)
1069
+ except AttributeError:
1070
+ self.moe_all_to_all_group_name = None
1071
+
1072
+ def process_weights_after_loading(self, layer):
1073
+ super(UnquantizedFusedMoEMethod,
1074
+ self).process_weights_after_loading(layer)
1075
+ layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
1076
+ layer.w13_weight.data),
1077
+ requires_grad=False)
1078
+ layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
1079
+ layer.w2_weight.data),
1080
+ requires_grad=False)
1081
+
1082
+ def apply(
1083
+ self,
1084
+ layer: torch.nn.Module,
1085
+ x: torch.Tensor,
1086
+ router_logits: torch.Tensor,
1087
+ top_k: int,
1088
+ renormalize: bool,
1089
+ use_grouped_topk: bool = False,
1090
+ global_num_experts: int = -1,
1091
+ expert_map: Optional[torch.Tensor] = None,
1092
+ topk_group: Optional[int] = None,
1093
+ num_expert_group: Optional[int] = None,
1094
+ custom_routing_function: Optional[Callable] = None,
1095
+ scoring_func: str = "softmax",
1096
+ e_score_correction_bias: Optional[torch.Tensor] = None,
1097
+ is_prefill: bool = False,
1098
+ enable_force_load_balance: bool = False,
1099
+ shared_experts: Optional[Any] = None,
1100
+ **kwargs,
1101
+ ) -> torch.Tensor:
1102
+ use_grouped_topk = (topk_group > 1 or num_expert_group > 1)
1103
+ is_deepseek_v3_r1 = global_num_experts == 256
1104
+ # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
1105
+ if use_grouped_topk and is_deepseek_v3_r1:
1106
+ topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
1107
+ router_logits,
1108
+ k=top_k, # top_k is currently fixed to 8
1109
+ bias=e_score_correction_bias,
1110
+ k_group=topk_group, # fixed: 4
1111
+ group_count=num_expert_group, # fixed: 8
1112
+ group_select_mode=1, # 0: max within each group; 1: sum of the group's top-2 (fixed)
1113
+ renorm=0, # 0: softmax->topk (fixed); 1: topk->softmax
1114
+ norm_type=1, # 0: softmax; 1: sigmoid (fixed)
1115
+ # out_flag=False, # todo new api; whether to emit the third output
1116
+ # y2_flag=False, # old api; whether to emit the third output
1117
+ routed_scaling_factor=1,
1118
+ eps=float(1e-20))
1119
+ elif use_grouped_topk and SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
1120
+ topk_weights, topk_ids = select_gating_top_k_softmax_experts(
1121
+ hidden_states=x,
1122
+ router_logits=router_logits,
1123
+ top_k=top_k,
1124
+ renormalize=renormalize)
1125
+ else:
1126
+ topk_weights, topk_ids = select_experts(
1127
+ hidden_states=x,
1128
+ router_logits=router_logits,
1129
+ top_k=top_k,
1130
+ use_grouped_topk=use_grouped_topk,
1131
+ renormalize=renormalize,
1132
+ topk_group=topk_group,
1133
+ num_expert_group=num_expert_group,
1134
+ custom_routing_function=custom_routing_function,
1135
+ scoring_func=scoring_func,
1136
+ e_score_correction_bias=e_score_correction_bias,
1137
+ )
1138
+
1139
+ topk_weights = topk_weights.to(x.dtype)
1140
+ # This is a naive implementation of expert load balancing, used to
1141
+ # avoid accumulating too many tokens on a single rank.
1142
+ # Currently it is only activated during profile runs.
1143
+ if enable_force_load_balance:
1144
+ topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
1145
+
1146
+ fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
1147
+ is_prefill, is_deepseek_v3_r1)
1148
+ if fused_moe_state == FusedMoEState.MC2:
1149
+ return fused_experts_with_mc2(
1150
+ hidden_states=x,
1151
+ w1=layer.w13_weight,
1152
+ w2=layer.w2_weight,
1153
+ topk_weights=topk_weights,
1154
+ topk_ids=topk_ids,
1155
+ top_k=top_k,
1156
+ expert_map=expert_map,
1157
+ moe_all_to_all_group_name=self.moe_all_to_all_group_name,
1158
+ shared_experts=shared_experts)
1159
+ elif fused_moe_state == FusedMoEState.AllGatherEP:
1160
+ return fused_experts_allgather_ep(
1161
+ hidden_states=x,
1162
+ w1=layer.w13_weight,
1163
+ w2=layer.w2_weight,
1164
+ topk_weights=topk_weights,
1165
+ topk_ids=topk_ids,
1166
+ is_prefill=is_prefill)
1167
+ elif fused_moe_state in [
1168
+ FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
1169
+ ]:
1170
+ return fused_experts(hidden_states=x,
1171
+ w1=layer.w13_weight,
1172
+ w2=layer.w2_weight,
1173
+ topk_weights=topk_weights,
1174
+ topk_ids=topk_ids,
1175
+ top_k=top_k,
1176
+ expert_map=expert_map)
1177
+ elif MOE_ALL2ALL_BUFFER:
1178
+ return fused_experts_with_all2all_buffer(
1179
+ hidden_states=x,
1180
+ w1=layer.w13_weight,
1181
+ w2=layer.w2_weight,
1182
+ topk_weights=topk_weights,
1183
+ topk_ids=topk_ids,
1184
+ top_k=top_k,
1185
+ max_model_len=self.max_model_len,
1186
+ global_batch_size=self.global_batch_size,
1187
+ expert_map=expert_map,
1188
+ ep_group=get_ep_group())
1189
+ else:
1190
+ return fused_experts_with_all2all(hidden_states=x,
1191
+ w1=layer.w13_weight,
1192
+ w2=layer.w2_weight,
1193
+ topk_weights=topk_weights,
1194
+ topk_ids=topk_ids,
1195
+ top_k=top_k,
1196
+ expert_map=expert_map,
1197
+ ep_group=get_ep_group())
1198
+
1199
+
1200
+ class AscendFusedMoE(FusedMoE):
1201
+
1202
+ # The moe_counter parameter is required during the initialization of EPLB
1203
+ # to identify the current layer index within the MOE model.
1204
+ moe_counter = -1
1205
+
1206
+ def __init__(
1207
+ self,
1208
+ num_experts: int, # Global number of experts
1209
+ top_k: int,
1210
+ hidden_size: int,
1211
+ intermediate_size: int,
1212
+ params_dtype: Optional[torch.dtype] = None,
1213
+ reduce_results: bool = False,
1214
+ renormalize: bool = True,
1215
+ use_grouped_topk: bool = False,
1216
+ num_expert_group: Optional[int] = None,
1217
+ topk_group: Optional[int] = None,
1218
+ quant_config: Optional[QuantizationConfig] = None,
1219
+ tp_size: Optional[int] = None,
1220
+ ep_size: Optional[int] = None,
1221
+ dp_size: Optional[int] = None,
1222
+ prefix: str = "",
1223
+ custom_routing_function: Optional[Callable] = None,
1224
+ scoring_func: str = "softmax",
1225
+ e_score_correction_bias: Optional[torch.Tensor] = None,
1226
+ activation: str = "silu",
1227
+ apply_router_weight_on_input: bool = False,
1228
+ ):
1229
+ # TODO: This does not initialize the FusedMoE base class;
1230
+ # fix it and make __init__() of AscendFusedMoE clearer.
1231
+ super(FusedMoE, self).__init__()
1232
+
1233
+ AscendFusedMoE.moe_counter += 1
1234
+ self.moe_instance_id = AscendFusedMoE.moe_counter
1235
+
1236
+ if params_dtype is None:
1237
+ params_dtype = torch.get_default_dtype()
1238
+
1239
+ vllm_config = get_current_vllm_config()
1240
+
1241
+ self.moe_parallel_config = FusedMoEParallelConfig.make(
1242
+ tp_size_=(tp_size if tp_size is not None else
1243
+ get_tensor_model_parallel_world_size()),
1244
+ dp_size_=(dp_size
1245
+ if dp_size is not None else get_dp_group().world_size),
1246
+ vllm_parallel_config=vllm_config.parallel_config)
1247
+
1248
+ self.top_k = top_k
1249
+ self.num_experts = num_experts
1250
+ self.global_num_experts = num_experts
1251
+ assert intermediate_size % self.tp_size == 0
1252
+ self.intermediate_size_per_partition = intermediate_size // self.tp_size
1253
+ self.reduce_results = reduce_results
1254
+ self.renormalize = renormalize
1255
+ self.use_grouped_topk = use_grouped_topk
1256
+ if self.use_grouped_topk:
1257
+ assert num_expert_group is not None and topk_group is not None
1258
+ self.num_expert_group = num_expert_group
1259
+ self.topk_group = topk_group
1260
+ self.custom_routing_function = custom_routing_function
1261
+ self.scoring_func = scoring_func
1262
+ self.e_score_correction_bias = e_score_correction_bias
1263
+ self.expert_map = None
1264
+ self.activation = activation
1265
+ self.log2phy = None
1266
+ self.global_redundant_expert_num = 0
1267
+
1268
+ is_deepseek_v3_r1 = self.global_num_experts == 256
1269
+ self.rm_router_logits = get_rm_router_logits_state(
1270
+ self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
1271
+ self.all_reduce_merge = get_all_reduce_merge_state(
1272
+ self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
1273
+
1274
+ ascend_config = get_ascend_config()
1275
+ expert_map_path = ascend_config.expert_map_path
1276
+ if expert_map_path and os.path.exists(expert_map_path):
1277
+ # moe expert load balance
1278
+ expert_load_balancer = ExpertLoadBalancer(expert_map_path,
1279
+ self.global_num_experts)
1280
+ self.local_num_experts, self.expert_map = \
1281
+ expert_load_balancer.get_rank_placement_map(
1282
+ self.moe_instance_id,
1283
+ get_ep_group().rank_in_group)
1284
+ self.log2phy = expert_load_balancer.get_rank_log2phy_map(
1285
+ self.moe_instance_id,
1286
+ get_ep_group().rank_in_group)
1287
+ self.global_redundant_expert_num = \
1288
+ expert_load_balancer.get_global_redundant_expert_num()
1289
+ else:
1290
+ # Create a tensor of size num_experts filled with -1
1291
+ self.local_num_experts, self.expert_map = determine_expert_map(
1292
+ self.ep_size,
1293
+ get_ep_group().rank_in_group, self.global_num_experts)
1294
+
1295
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
1296
+ self.enable_multistream_moe = \
1297
+ ascend_config.torchair_graph_config.enable_multistream_moe
1298
+
1299
+ if self.scoring_func != "softmax" and not self.use_grouped_topk:
1300
+ raise ValueError("Only softmax scoring function is supported for "
1301
+ "non-grouped topk.")
1302
+ moe = FusedMoEConfig.make(
1303
+ num_experts=self.global_num_experts,
1304
+ experts_per_token=top_k,
1305
+ hidden_dim=hidden_size,
1306
+ num_local_experts=self.local_num_experts,
1307
+ moe_parallel_config=self.moe_parallel_config,
1308
+ # TODO (bnell): this needs to be fixed for quantized types.
1309
+ in_dtype=params_dtype,
1310
+ quant_config=quant_config)
1311
+
1312
+ if quant_config is None:
1313
+ self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
1314
+ else:
1315
+ self.quant_method = quant_config.get_quant_method(self, prefix)
1316
+
1317
+ assert self.quant_method is not None
1318
+
1319
+ local_num_experts = torch.sum(self.expert_map != -1) \
1320
+ if self.expert_map is not None else num_experts
1321
+
1322
+ moe_quant_params = {
1323
+ "num_experts": local_num_experts,
1324
+ "hidden_size": hidden_size,
1325
+ "intermediate_size_per_partition":
1326
+ self.intermediate_size_per_partition,
1327
+ "params_dtype": params_dtype,
1328
+ "weight_loader": self.weight_loader,
1329
+ }
1330
+ # need full intermediate size pre-sharding for WNA16 act order
1331
+ if (self.quant_method.__class__.__name__
1332
+ in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
1333
+ moe_quant_params["intermediate_size_full"] = intermediate_size
1334
+
1335
+ self.ep_group = get_ep_group()
1336
+ # NOTE: self.tp_group is not expert_tp_group
1337
+ self.tp_group = get_tp_group().device_group
1338
+ self.quant_method.create_weights(layer=self, **moe_quant_params)
1339
+
1340
+ def naive_multicast(self, x: torch.Tensor,
1341
+ cu_tokens_across_dp_cpu: torch.Tensor):
1342
+ assert (len(x.shape) == 2)
1343
+ buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
1344
+ device=x.device,
1345
+ dtype=x.dtype)
1346
+ start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
1347
+ self.dp_rank - 1]
1348
+ end = cu_tokens_across_dp_cpu[self.dp_rank]
1349
+ buffer[start:end, :].copy_(x)
1350
+ for idx in range(self.dp_size):
1351
+ start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
1352
+ end = cu_tokens_across_dp_cpu[idx]
1353
+ get_dp_group().broadcast(buffer[start:end, :], idx)
1354
+ return buffer
1355
+
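
The start/end offsets used by naive_multicast come directly from the cumulative token counts across data-parallel ranks. A small standalone illustration follows; the token counts are assumed values, not taken from the diff.

import torch

# Hypothetical cumulative token counts for 3 DP ranks holding 4, 6 and 2 tokens.
cu_tokens_across_dp_cpu = torch.tensor([4, 10, 12])
for dp_rank in range(3):
    start = 0 if dp_rank == 0 else int(cu_tokens_across_dp_cpu[dp_rank - 1])
    end = int(cu_tokens_across_dp_cpu[dp_rank])
    print(dp_rank, (start, end))   # (0, 4), (4, 10), (10, 12)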
1356
+ def forward(self,
1357
+ hidden_states: torch.Tensor,
1358
+ router_logits: torch.Tensor,
1359
+ is_prefill: bool,
1360
+ enable_force_load_balance: bool = False,
1361
+ top_k: Optional[int] = None,
1362
+ shared_experts: Optional[Any] = None,
1363
+ gate=None,
1364
+ replace_allreduce: bool = False):
1365
+
1366
+ assert self.quant_method is not None
1367
+
1368
+ if top_k:
1369
+ real_top_k = top_k
1370
+ else:
1371
+ real_top_k = self.top_k
1372
+
1373
+ num_tokens, hidden_size = hidden_states.shape
1374
+ is_deepseek_v3_r1 = self.global_num_experts == 256
1375
+
1376
+ fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
1377
+ is_prefill, is_deepseek_v3_r1)
1378
+ if shared_experts:
1379
+ if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
1380
+ # When all_reduce_merge is enabled, shared_experts skips the all_reduce inside its MLP and defers it until both shared_experts and router_experts have finished.
1381
+ shared_hidden_states = shared_experts(hidden_states)
1382
+
1383
+ tp_size = get_tensor_model_parallel_world_size()
1384
+ if (tp_size > 1 and fused_moe_state not in [
1385
+ FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
1386
+ FusedMoEState.NaiveMulticast
1387
+ ] and not replace_allreduce):
1388
+ if num_tokens < tp_size:
1389
+ hidden_states = nn.functional.pad(
1390
+ hidden_states, (0, 0, 0, tp_size - num_tokens))
1391
+ router_logits = nn.functional.pad(
1392
+ router_logits, (0, 0, 0, tp_size - num_tokens))
1393
+ chunk_hidden_states = torch.tensor_split(hidden_states,
1394
+ tp_size,
1395
+ dim=0)
1396
+ chunk_router_logits = torch.tensor_split(router_logits,
1397
+ tp_size,
1398
+ dim=0)
1399
+ tp_rank = get_tensor_model_parallel_rank()
1400
+ hidden_states = chunk_hidden_states[tp_rank]
1401
+ router_logits = chunk_router_logits[tp_rank]
1402
+
1403
+ if self.dp_size > 1:
1404
+ if fused_moe_state in (FusedMoEState.AllGather, FusedMoEState.AllGatherEP):
1405
+ # NOTE: In torchair graph mode, padding has already been applied in model_runner_v1
1406
+ if not self.torchair_graph_enabled or is_prefill:
1407
+ attn_metadata = get_forward_context().attn_metadata
1408
+ if attn_metadata is not None:
1409
+ max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
1410
+ if num_tokens < max_num_tokens_across_dp:
1411
+ hidden_states = nn.functional.pad(
1412
+ hidden_states,
1413
+ (0, 0, 0,
1414
+ max_num_tokens_across_dp - num_tokens))
1415
+ if not self.rm_router_logits:
1416
+ router_logits = nn.functional.pad(
1417
+ router_logits,
1418
+ (0, 0, 0,
1419
+ max_num_tokens_across_dp - num_tokens))
1420
+ hidden_states = get_dp_group().all_gather(hidden_states, 0)
1421
+ if self.rm_router_logits:
1422
+ router_logits, _ = gate(hidden_states.float())
1423
+ else:
1424
+ router_logits = get_dp_group().all_gather(router_logits, 0)
1425
+
1426
+ elif fused_moe_state == FusedMoEState.NaiveMulticast:
1427
+ cu_tokens_across_dp_cpu = get_forward_context(
1428
+ ).dp_metadata.cu_tokens_across_dp_cpu
1429
+ hidden_states = self.naive_multicast(hidden_states,
1430
+ cu_tokens_across_dp_cpu)
1431
+ if self.rm_router_logits:
1432
+ router_logits, _ = gate(hidden_states.float())
1433
+ else:
1434
+ router_logits = self.naive_multicast(
1435
+ router_logits, cu_tokens_across_dp_cpu)
1436
+
1437
+ # Matrix multiply.
1438
+ e_hidden_states = self.quant_method.apply(
1439
+ layer=self,
1440
+ x=hidden_states,
1441
+ router_logits=router_logits,
1442
+ top_k=real_top_k,
1443
+ renormalize=self.renormalize,
1444
+ use_grouped_topk=self.use_grouped_topk,
1445
+ global_num_experts=self.global_num_experts,
1446
+ expert_map=self.expert_map,
1447
+ topk_group=self.topk_group,
1448
+ num_expert_group=self.num_expert_group,
1449
+ custom_routing_function=self.custom_routing_function,
1450
+ scoring_func=self.scoring_func,
1451
+ e_score_correction_bias=self.e_score_correction_bias,
1452
+ is_prefill=is_prefill,
1453
+ enable_force_load_balance=enable_force_load_balance,
1454
+ log2phy=self.log2phy,
1455
+ global_redundant_expert_num=self.global_redundant_expert_num,
1456
+ shared_experts=shared_experts if self.torchair_graph_enabled
1457
+ and self.enable_multistream_moe and not is_prefill else None,
1458
+ )
1459
+
1460
+ if shared_experts:
1461
+ if isinstance(e_hidden_states, tuple):
1462
+ e_hidden_states, shared_hidden_states = e_hidden_states
1463
+
1464
+ if (tp_size > 1 and fused_moe_state not in [
1465
+ FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
1466
+ FusedMoEState.NaiveMulticast
1467
+ ] and not replace_allreduce):
1468
+ dist.all_gather(list(chunk_hidden_states), e_hidden_states,
1469
+ self.tp_group)
1470
+ final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
1471
+ if num_tokens < tp_size:
1472
+ final_hidden_states = final_hidden_states[:num_tokens]
1473
+ dispose_tensor(e_hidden_states)
1474
+ elif self.dp_size > 1:
1475
+ if fused_moe_state == FusedMoEState.NaiveMulticast:
1476
+ start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
1477
+ self.dp_rank - 1]
1478
+ end = cu_tokens_across_dp_cpu[self.dp_rank]
1479
+ final_hidden_states = get_dp_group().all_reduce(
1480
+ e_hidden_states)
1481
+ final_hidden_states = final_hidden_states[start:end, :]
1482
+ dispose_tensor(e_hidden_states)
1483
+ elif fused_moe_state in (FusedMoEState.AllGather, FusedMoEState.AllGatherEP):
1484
+ final_hidden_states = data_parallel_reduce_scatter(
1485
+ e_hidden_states, dim=0)
1486
+ final_hidden_states = final_hidden_states[:num_tokens]
1487
+ dispose_tensor(e_hidden_states)
1488
+ else:
1489
+ final_hidden_states = e_hidden_states
1490
+
1491
+ if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [
1492
+ FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
1493
+ FusedMoEState.NaiveMulticast
1494
+ ]:
1495
+ final_hidden_states = tensor_model_parallel_all_reduce(
1496
+ final_hidden_states)
1497
+
1498
+ if shared_experts:
1499
+ return final_hidden_states, shared_hidden_states
1500
+ else:
1501
+ return final_hidden_states
1502
+
1503
+ # ----------------------------------------- TBO-related --------------------------------------------
1504
+
1505
+ def _forward_ms_fused_moe_comp(
1506
+ self,
1507
+ hidden_states: torch.Tensor,
1508
+ router_logits: torch.Tensor,
1509
+ is_prefill: bool,
1510
+ real_top_k,
1511
+ enable_force_load_balance: bool = False,
1512
+ ):
1513
+ hidden_states = self.quant_method.apply(
1514
+ layer=self,
1515
+ x=hidden_states,
1516
+ router_logits=router_logits,
1517
+ top_k=real_top_k,
1518
+ renormalize=self.renormalize,
1519
+ use_grouped_topk=self.use_grouped_topk,
1520
+ global_num_experts=self.global_num_experts,
1521
+ expert_map=self.expert_map,
1522
+ topk_group=self.topk_group,
1523
+ num_expert_group=self.num_expert_group,
1524
+ custom_routing_function=self.custom_routing_function,
1525
+ scoring_func=self.scoring_func,
1526
+ e_score_correction_bias=self.e_score_correction_bias,
1527
+ is_prefill=is_prefill,
1528
+ enable_force_load_balance=enable_force_load_balance)
1529
+
1530
+ return hidden_states
inference/vllm_ascend/patch/worker/patch_common/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # This file is a part of the vllm-ascend project.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ # patch_utils should be the first import, because it will be used by other
19
+ # patch files.
20
+ import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
21
+ import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
22
+ import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
23
+ import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
24
+ import vllm_ascend.patch.worker.patch_common.patch_sampler # noqa
25
+ import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker # noqa
26
+ import vllm_ascend.patch.worker.patch_common.patch_config # noqa
27
+ import vllm_ascend.patch.worker.patch_common.patch_parsers # noqa
inference/vllm_ascend/patch/worker/patch_common/patch_config.py ADDED
@@ -0,0 +1,97 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # This file is a part of the vllm-ascend project.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ from vllm.config import ModelConfig
18
+
19
+
20
+ def get_attr_by_names(src_config, attrs, default_value):
21
+ for attr in attrs:
22
+ value = getattr(src_config, attr, 0)
23
+ if value > 0:
24
+ return value
25
+ return default_value
26
+
27
+
28
+ def _verify_with_expert_parallelism(self) -> None:
29
+ num_expert_names = [
30
+ "moe_num_experts", # Dbrx
31
+ "num_experts", # Jamba
32
+ "n_routed_experts", # DeepSeek
33
+ "num_local_experts", # Mixtral
34
+ "num_routed_experts", # Pangu
35
+ ]
36
+ num_experts = 0
37
+ for name in num_expert_names:
38
+ num_experts = getattr(self.hf_text_config, name, 0)
39
+ if num_experts > 0:
40
+ break
41
+ if num_experts < 1:
42
+ raise ValueError(
43
+ "Number of experts in the model must be greater than 0 "
44
+ "when expert parallelism is enabled.")
45
+
46
+
47
+ @property
48
+ def is_deepseek_mla(self) -> bool:
49
+ kv_lora_dim_names = ['attention_kv_lora_dim', 'kv_lora_rank']
50
+ kv_lora_dim = get_attr_by_names(self.hf_text_config, kv_lora_dim_names, None)
51
+ if not hasattr(self.hf_text_config, "model_type"):
52
+ return False
53
+ elif self.hf_text_config.model_type in \
54
+ ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'pangu_ultra_moe'):
55
+ return kv_lora_dim is not None
56
+ elif self.hf_text_config.model_type == 'eagle':
57
+ # if the model is an EAGLE module, check for the
58
+ # underlying architecture
59
+ return self.hf_text_config.model.model_type in \
60
+ ('deepseek_v2', 'deepseek_v3', 'pangu_ultra_moe') \
61
+ and kv_lora_dim is not None
62
+ return False
63
+
64
+
65
+ def get_head_size(self) -> int:
66
+ if self.is_deepseek_mla:
67
+ qk_rope_dim_names = ['attention_qk_rope_dim', 'qk_rope_head_dim']
68
+ kv_lora_dim_names = ['attention_kv_lora_dim', 'kv_lora_rank']
69
+ qk_rope_dim = get_attr_by_names(self.hf_text_config, qk_rope_dim_names, 0)
70
+ kv_lora_dim = get_attr_by_names(self.hf_text_config, kv_lora_dim_names, 0)
71
+ if self.use_mla:
72
+ return kv_lora_dim + qk_rope_dim
73
+ else:
74
+ qk_dim_names = ['attention_qk_dim', 'qk_nope_head_dim']
75
+ qk_dim = get_attr_by_names(self.hf_text_config, qk_dim_names, 0)
76
+ if qk_rope_dim and qk_dim:
77
+ return qk_rope_dim + qk_dim
78
+ if hasattr(self.hf_text_config,
79
+ "model_type") and (self.hf_text_config.model_type
80
+ == "zamba2"):
81
+ return self.hf_text_config.attention_head_dim
82
+
83
+ if self.is_attention_free:
84
+ return 0
85
+
86
+ # NOTE: Some configs may set head_dim=None in the config
87
+ if getattr(self.hf_text_config, "head_dim", None) is not None:
88
+ return self.hf_text_config.head_dim
89
+
90
+ # FIXME(woosuk): This may not be true for all models.
91
+ return (self.hf_text_config.hidden_size //
92
+ self.hf_text_config.num_attention_heads)
93
+
94
+
95
+ ModelConfig._verify_with_expert_parallelism = _verify_with_expert_parallelism
96
+ ModelConfig.is_deepseek_mla = is_deepseek_mla
97
+ ModelConfig.get_head_size = get_head_size
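
The patch above resolves config attributes under either the DeepSeek-style or the Pangu-style names and, for MLA models, reports the KV-cache head size as kv_lora_dim + qk_rope_dim. A self-contained sketch with a dummy config is shown below; the attribute values are illustrative assumptions, not taken from the released model.

from types import SimpleNamespace

def get_attr_by_names(src_config, attrs, default_value):
    # Same lookup as the patch: the first attribute with a positive value wins.
    for attr in attrs:
        value = getattr(src_config, attr, 0)
        if value > 0:
            return value
    return default_value

cfg = SimpleNamespace(attention_kv_lora_dim=512, attention_qk_rope_dim=64)
kv_lora = get_attr_by_names(cfg, ['attention_kv_lora_dim', 'kv_lora_rank'], 0)
qk_rope = get_attr_by_names(cfg, ['attention_qk_rope_dim', 'qk_rope_head_dim'], 0)
print(kv_lora + qk_rope)   # 576 -> head size used for the MLA KV cache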
inference/vllm_ascend/patch/worker/patch_common/patch_parsers.py ADDED
@@ -0,0 +1,26 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # This file is a part of the vllm-ascend project.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+
19
+ from vllm.entrypoints.openai import tool_parsers
20
+ from vllm_ascend.entrypoints.openai.tool_parsers import PanguToolParser
21
+ tool_parsers.__all__.append("PanguToolParser")
22
+
23
+
24
+ from vllm import reasoning
25
+ from vllm_ascend.entrypoints.openai.reasoning_parsers import PanguReasoningParser
26
+ reasoning.__all__.append("PanguReasoningParser")
inference/vllm_ascend/patch/worker/patch_common/patch_sampler.py ADDED
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ # This file is a part of the vllm-ascend project.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ #
18
+
19
+ from typing import Optional
20
+
21
+ import torch
22
+ import torch_npu
23
+ from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
24
+ from vllm.v1.sample.sampler import Sampler, _SAMPLING_EPS
25
+ from vllm.v1.sample.metadata import SamplingMetadata
26
+ from vllm_ascend import envs
27
+
28
+
29
+ def apply_top_k_top_p(
30
+ logits: torch.Tensor,
31
+ k: torch.Tensor,
32
+ p: torch.Tensor,
33
+ ) -> torch.Tensor:
34
+ if p is not None and k is not None:
35
+ # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
36
+ return torch_npu.npu_top_k_top_p(logits, p, k)
37
+
38
+ probs = logits.softmax(dim=-1)
39
+ probs_sort, _ = probs.sort(dim=-1, descending=False)
40
+
41
+ if k is not None:
42
+ top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
43
+ top_k_count = top_k_count.unsqueeze(dim=1)
44
+ top_k_cutoff = probs_sort.gather(-1, top_k_count)
45
+
46
+ # Make sure the no top-k rows are no-op.
47
+ no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
48
+ top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
49
+
50
+ elements_to_discard = probs < top_k_cutoff
51
+ logits.masked_fill_(elements_to_discard, -float("inf"))
52
+
53
+ if p is not None:
54
+ cumprob = torch.cumsum(probs_sort, dim=-1)
55
+ top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
56
+ top_p_mask[:, -1] = False # at least one
57
+
58
+ top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
59
+ top_p_cutoff = probs_sort.gather(-1, top_p_count)
60
+ elements_to_discard = probs < top_p_cutoff
61
+ logits.masked_fill_(elements_to_discard, -float("inf"))
62
+
63
+ return logits
64
+
65
+
66
+ def topk_topp_forward_native(
67
+ self,
68
+ logits: torch.Tensor,
69
+ generators: dict[int, torch.Generator],
70
+ k: Optional[torch.Tensor],
71
+ p: Optional[torch.Tensor],
72
+ ) -> torch.Tensor:
73
+ """
74
+ PyTorch-native implementation of top-k and top-p sampling.
75
+
76
+ The logits tensor may be updated in-place.
77
+ """
78
+ logits = apply_top_k_top_p(logits, k, p)
79
+ probs = logits.softmax(dim=-1, dtype=torch.float32)
80
+ return random_sample(probs, generators)
81
+
82
+
83
+ def apply_top_n_sigma(
84
+ logits: torch.Tensor,
85
+ sampling_metadata: SamplingMetadata,
86
+ ):
87
+ if sampling_metadata.no_top_n_sigma:
88
+ return logits
89
+
90
+ top_n_sigma = sampling_metadata.top_n_sigma[:, None]
91
+ top_n_sigma_mask = (top_n_sigma != -1)
92
+ filter_value = -3.4028e+38
93
+ max_vals, _ = logits.max(dim=-1, keepdim=True)
94
+ std_vals = logits.std(dim=-1, keepdim=True)
95
+ threshold = max_vals - top_n_sigma * std_vals
96
+ threshold[~top_n_sigma_mask] = filter_value
97
+ mask = (logits < threshold)
98
+ logits = torch.where(mask, filter_value, logits)
99
+ return logits
100
+
101
+
102
+ def sample(
103
+ self,
104
+ logits: torch.Tensor,
105
+ sampling_metadata: SamplingMetadata,
106
+ ) -> torch.Tensor:
107
+ """Sample logits based on sampling metadata.
108
+
109
+ The various logits processing functions called in this method
110
+ may update the logits tensor in-place.
111
+ """
112
+
113
+ assert not (sampling_metadata.all_greedy
114
+ and sampling_metadata.all_random)
115
+ if sampling_metadata.all_random:
116
+ greedy_sampled = None
117
+ else:
118
+ greedy_sampled = self.greedy_sample(logits)
119
+ if sampling_metadata.all_greedy:
120
+ return greedy_sampled
121
+
122
+ assert sampling_metadata.temperature is not None
123
+
124
+ # Apply temperature.
125
+ logits = self.apply_temperature(logits, sampling_metadata.temperature)
126
+
127
+ # Apply logits processors that only apply to random sampling
128
+ # (argmax invariant)
129
+ for processor in sampling_metadata.logitsprocs.argmax_invariant:
130
+ logits = processor.apply(logits)
131
+
132
+ # Apply top_n_sigma
133
+ logits = apply_top_n_sigma(logits, sampling_metadata)
134
+
135
+ # Apply top_k and/or top_p.
136
+ random_sampled = self.topk_topp_sampler(
137
+ logits,
138
+ sampling_metadata.generators,
139
+ sampling_metadata.top_k,
140
+ sampling_metadata.top_p,
141
+ )
142
+
143
+ if greedy_sampled is None:
144
+ return random_sampled
145
+
146
+ sampled = torch.where(
147
+ sampling_metadata.temperature < _SAMPLING_EPS,
148
+ greedy_sampled,
149
+ random_sampled,
150
+ out=greedy_sampled, # Reuse tensor
151
+ )
152
+ return sampled
153
+
154
+
155
+ if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
156
+ TopKTopPSampler.forward_native = topk_topp_forward_native
157
+
158
+ if envs.VLLM_ASCEND_ENABLE_TOP_N_SIGMA:
159
+ Sampler.sample = sample
inference/vllm_ascend/quantization/w8a8.py ADDED
@@ -0,0 +1,757 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # This file is a part of the vllm-ascend project.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ from typing import Any, Callable, Dict, Optional
19
+
20
+ import torch
21
+ import torch_npu
22
+ from vllm.attention.backends.abstract import AttentionType
23
+
24
+ from vllm_ascend.attention.attention_v1 import AscendAttentionState
25
+ from vllm_ascend.distributed.parallel_state import get_ep_group
26
+ from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
27
+
28
+
29
+ def quant_per_tensor(in_tensor: torch.Tensor,
30
+ input_scale: torch.Tensor,
31
+ input_offset: torch.Tensor,
32
+ function=False):
33
+ return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
34
+ torch.qint8, -1, function)
35
+
36
+
37
+ class AscendW8A8LinearMethod:
38
+ """Linear method for Ascend W8A8.
39
+
40
+ Args:
41
+ w_sym: whether the linear weight is symmetrically quantized.
42
+ """
43
+
44
+ def __init__(self) -> None:
45
+ # aclnn quant matmul requires to transpose matrix B, set to true by default.
46
+ self.transpose_weight = not is_310p()
47
+
48
+ @staticmethod
49
+ def get_weight(
50
+ input_size: int,
51
+ output_size: int,
52
+ params_dtype: torch.dtype = torch.bfloat16,
53
+ ) -> Dict[str, Any]:
54
+ params_dict = {
55
+ "weight": torch.empty(output_size, input_size, dtype=torch.int8)
56
+ }
57
+ return params_dict
58
+
59
+ @staticmethod
60
+ def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
61
+ params_dict = {}
62
+ params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
63
+ params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
64
+ return params_dict
65
+
66
+ @staticmethod
67
+ def get_perchannel_param(
68
+ output_size: int,
69
+ params_dtype: torch.dtype,
70
+ ) -> Dict[str, Any]:
71
+ params_dict = {}
72
+ params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
73
+ if params_dtype == torch.bfloat16:
74
+ params_dict["deq_scale"] = torch.empty(output_size,
75
+ dtype=torch.float32)
76
+ elif params_dtype == torch.float16:
77
+ params_dict["deq_scale"] = torch.empty(output_size,
78
+ dtype=torch.int64)
79
+ params_dict["weight_scale"] = torch.empty(output_size,
80
+ 1,
81
+ dtype=params_dtype)
82
+ params_dict["weight_offset"] = torch.empty(output_size,
83
+ 1,
84
+ dtype=params_dtype)
85
+ return params_dict
86
+
87
+ @staticmethod
88
+ def apply(
89
+ layer: torch.nn.Module,
90
+ x: torch.Tensor,
91
+ bias: Optional[torch.Tensor] = None,
92
+ tp_rank: Optional[int] = 0,
93
+ ) -> torch.Tensor:
94
+ original_dtype = x.dtype
95
+ if original_dtype != torch.int8:
96
+ x = quant_per_tensor(x, layer.aclnn_input_scale,
97
+ layer.aclnn_input_offset)
98
+ quant_bias = layer.quant_bias if tp_rank == 0 else None
99
+ if is_310p():
100
+ # On 300I Duo platform, we need transpose again if
101
+ # using nz. This transpose can be skipped in torchair.
102
+ output = torch_npu.npu_quant_matmul(
103
+ x,
104
+ layer.weight.data.transpose(1, 0),
105
+ layer.deq_scale,
106
+ bias=quant_bias,
107
+ output_dtype=original_dtype,
108
+ )
109
+ else:
110
+ output = torch_npu.npu_quant_matmul(
111
+ x,
112
+ layer.weight,
113
+ layer.deq_scale,
114
+ bias=quant_bias,
115
+ output_dtype=original_dtype,
116
+ )
117
+ return output
118
+
119
+ def process_weights_after_loading(self, layer):
120
+ expanding_factor = layer.weight.data.shape[1]
121
+ layer.aclnn_input_scale = 1 / torch.nn.Parameter(
122
+ layer.input_scale.data.repeat(expanding_factor),
123
+ requires_grad=False)
124
+ layer.aclnn_input_offset = torch.nn.Parameter(
125
+ layer.input_offset.data.repeat(expanding_factor),
126
+ requires_grad=False).to(layer.aclnn_input_scale.dtype)
127
+ if self.transpose_weight:
128
+ layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
129
+ layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
130
+ ACL_FORMAT_FRACTAL_NZ)
131
+ layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
132
+ layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
133
+
134
+
135
+ class AscendW8A8FusedMoEMethod:
136
+ """FusedMoe method for Ascend W8A8.
137
+ """
138
+
139
+ def __init__(self):
140
+ self.transpose_weight = True
141
+
142
+ @staticmethod
143
+ def get_weight(num_experts: int, intermediate_size_per_partition: int,
144
+ hidden_sizes: int,
145
+ params_dtype: torch.dtype) -> Dict[str, Any]:
146
+ param_dict = {}
147
+ param_dict["w13_weight"] = torch.empty(num_experts,
148
+ 2 *
149
+ intermediate_size_per_partition,
150
+ hidden_sizes,
151
+ dtype=torch.int8,
152
+ requires_grad=False)
153
+ param_dict["w2_weight"] = torch.empty(num_experts,
154
+ hidden_sizes,
155
+ intermediate_size_per_partition,
156
+ dtype=torch.int8,
157
+ requires_grad=False)
158
+ return param_dict
159
+
160
+ @staticmethod
161
+ def get_dynamic_quant_param(num_experts: int,
162
+ intermediate_size_per_partition: int,
163
+ hidden_sizes: int,
164
+ params_dtype: torch.dtype) -> Dict[str, Any]:
165
+ param_dict = {}
166
+ param_dict["w13_weight_scale"] = torch.empty(
167
+ num_experts,
168
+ 2 * intermediate_size_per_partition,
169
+ 1,
170
+ dtype=torch.float32)
171
+ param_dict["w13_weight_offset"] = torch.empty(
172
+ num_experts,
173
+ 2 * intermediate_size_per_partition,
174
+ 1,
175
+ dtype=torch.float16)
176
+ param_dict["w2_weight_scale"] = torch.empty(num_experts,
177
+ hidden_sizes,
178
+ 1,
179
+ dtype=torch.float32)
180
+ param_dict["w2_weight_offset"] = torch.empty(num_experts,
181
+ hidden_sizes,
182
+ 1,
183
+ dtype=torch.float16)
184
+ param_dict["w2_deq_scale"] = torch.empty(num_experts,
185
+ hidden_sizes,
186
+ dtype=torch.float32)
187
+ param_dict["w13_deq_scale"] = torch.empty(
188
+ num_experts,
189
+ 2 * intermediate_size_per_partition,
190
+ dtype=torch.float32)
191
+ param_dict["w2_input_scale"] = torch.empty(num_experts,
192
+ 1,
193
+ dtype=torch.float32)
194
+ param_dict["w13_input_scale"] = torch.empty(num_experts,
195
+ 1,
196
+ dtype=torch.float32)
197
+ param_dict["w2_input_offset"] = torch.empty(num_experts,
198
+ 1,
199
+ dtype=torch.int8)
200
+ param_dict["w13_input_offset"] = torch.empty(num_experts,
201
+ 1,
202
+ dtype=torch.int8)
203
+ param_dict["quant_bias"] = torch.empty(num_experts,
204
+ hidden_sizes,
205
+ dtype=torch.int32)
206
+
207
+ return param_dict
208
+
209
+ def apply(
210
+ self,
211
+ layer: torch.nn.Module,
212
+ x: torch.Tensor,
213
+ router_logits: torch.Tensor,
214
+ top_k: int,
215
+ renormalize: bool,
216
+ use_grouped_topk: bool = False,
217
+ global_num_experts: int = -1,
218
+ expert_map: Optional[torch.Tensor] = None,
219
+ topk_group: Optional[int] = None,
220
+ num_expert_group: Optional[int] = None,
221
+ custom_routing_function: Optional[Callable] = None,
222
+ scoring_func: str = "softmax",
223
+ e_score_correction_bias: Optional[torch.Tensor] = None,
224
+ is_prefill: bool = True,
225
+ enable_force_load_balance: bool = False,
226
+ log2phy: torch.Tensor = None,
227
+ global_redundant_expert_num: int = 0,
228
+ shared_experts: Optional[Any] = None,
229
+ **kwargs,
230
+ ) -> torch.Tensor:
231
+ assert router_logits.shape[
232
+ 1] == global_num_experts, "Number of global experts mismatch"
233
+
234
+ topk_weights, topk_ids = select_experts(
235
+ hidden_states=x,
236
+ router_logits=router_logits,
237
+ top_k=top_k,
238
+ use_grouped_topk=use_grouped_topk,
239
+ renormalize=renormalize,
240
+ topk_group=topk_group,
241
+ num_expert_group=num_expert_group,
242
+ custom_routing_function=custom_routing_function,
243
+ scoring_func=scoring_func,
244
+ e_score_correction_bias=e_score_correction_bias,
245
+ global_num_experts=global_num_experts,
246
+ )
247
+
248
+ if is_310p():
249
+ return fused_experts_310p(hidden_states=x,
250
+ w1=layer.w13_weight,
251
+ w1_scale=layer.w13_weight_scale,
252
+ w1_input_scale=layer.w13_input_scale,
253
+ w2=layer.w2_weight,
254
+ w2_scale=layer.w2_weight_scale,
255
+ w2_input_scale=layer.w2_input_scale,
256
+ topk_weights=topk_weights,
257
+ topk_ids=topk_ids,
258
+ top_k=top_k,
259
+ global_num_experts=global_num_experts,
260
+ expert_map=expert_map)
261
+ return fused_experts(hidden_states=x,
262
+ w1=layer.w13_weight,
263
+ w1_scale=layer.w13_weight_scale,
264
+ w1_input_scale=layer.w13_input_scale,
265
+ w1_input_offset=layer.w13_input_offset,
266
+ w2=layer.w2_weight,
267
+ w2_scale=layer.w2_weight_scale,
268
+ w2_input_scale=layer.w2_input_scale,
269
+ w2_input_offset=layer.w2_input_offset,
270
+ topk_weights=topk_weights,
271
+ topk_ids=topk_ids,
272
+ top_k=top_k,
273
+ global_num_experts=global_num_experts,
274
+ expert_map=expert_map)
275
+
276
+ def process_weights_after_loading(self, layer):
277
+ if not is_310p():
278
+ layer.w13_weight.data = layer.w13_weight.data.transpose(
279
+ 1, 2).contiguous()
280
+ layer.w2_weight.data = layer.w2_weight.data.transpose(
281
+ 1, 2).contiguous()
282
+ layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
283
+ layer.w13_weight_scale.data.shape[0], -1)
284
+
285
+ layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
286
+ layer.w13_weight_offset.data.shape[0], -1)
287
+ layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
288
+ layer.w2_weight_scale.data.shape[0], -1)
289
+ layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
290
+ layer.w2_weight_offset.data.shape[0], -1)
291
+ expanding_factor_w13 = layer.w13_weight.data.shape[1]
292
+ expanding_factor_w2 = layer.w2_weight.data.shape[1]
293
+
294
+ if is_310p():
295
+ layer.w13_input_scale.data = torch.nn.Parameter(
296
+ layer.w13_input_scale.data.max())
297
+ layer.w2_input_scale.data = torch.nn.Parameter(
298
+ layer.w2_input_scale.data.max())
299
+ else:
300
+ layer.w13_input_scale.data = torch.nn.Parameter(
301
+ layer.w13_input_scale.data.repeat(1,
302
+ expanding_factor_w13)[0:1])
303
+ layer.w2_input_scale.data = torch.nn.Parameter(
304
+ layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
305
+
306
+ layer.w13_input_offset.data = torch.nn.Parameter(
307
+ layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1])
308
+ layer.w2_input_offset.data = torch.nn.Parameter(
309
+ layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
310
+
311
+ # converting ACL_FORMAT_FRACTAL_NZ.
312
+ # npu_quant_grouped_matmul_dequant in eager mode does not accept
313
+ # ACL_FORMAT_FRACTAL_NZ.
314
+ if not is_310p():
315
+ layer.w13_weight.data = torch_npu.npu_format_cast(
316
+ layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
317
+ layer.w2_weight.data = torch_npu.npu_format_cast(
318
+ layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
319
+
320
+
321
+ class AscendC8KVCacheMethod:
322
+
323
+ def __init__(self) -> None:
324
+ self.antiquant_scale_comb = None
325
+
326
+ @staticmethod
327
+ def create_weights(layer) -> None:
328
+ param_dict = {} # num_kv_heads * head_size
329
+ param_dict["key_antiquant_scale"] = torch.empty(layer.num_kv_heads *
330
+ layer.head_size,
331
+ dtype=torch.float16,
332
+ requires_grad=False)
333
+ param_dict["value_antiquant_scale"] = torch.empty(layer.num_kv_heads *
334
+ layer.head_size,
335
+ dtype=torch.float16,
336
+ requires_grad=False)
337
+ for weight_name, weight_param in param_dict.items():
338
+ param = torch.nn.Parameter(weight_param, requires_grad=False)
339
+ layer.register_parameter(weight_name, param)
340
+
341
+ def process_weights_after_loading(self, layer):
342
+ self.antiquant_scale_comb = torch.cat(
343
+ (layer.key_antiquant_scale.data.unsqueeze(0),
344
+ layer.value_antiquant_scale.data.unsqueeze(0)),
345
+ dim=0).to(torch.float16).contiguous()
346
+
347
+ def apply(self, layer, query, key, value, kv_cache, attn_metadata,
348
+ attn_type, scale, output) -> torch.Tensor:
349
+ num_tokens = query.shape[0]
350
+ if attn_metadata is None:
351
+ return output.view(num_tokens, layer.num_heads * layer.head_size)
352
+ assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
353
+ if attn_type != AttentionType.DECODER:
354
+ raise NotImplementedError("Encoder self-attention and "
355
+ "encoder/decoder cross-attention "
356
+ "are not implemented for "
357
+ "PallasAttentionBackendImpl")
358
+
359
+ # C8
360
+ quant_key = quant_per_tensor(
361
+ key.view(-1, layer.num_kv_heads * layer.head_size),
362
+ layer.key_antiquant_scale.data.view(-1), None, True)
363
+ quant_value = quant_per_tensor(
364
+ value.view(-1, layer.num_kv_heads * layer.head_size),
365
+ layer.value_antiquant_scale.data.view(-1), None, True)
366
+
367
+ # View q k v to BSH.
368
+ query = query.view(-1, layer.num_heads, layer.head_size)
369
+ key = key.view(-1, layer.num_kv_heads, layer.head_size)
370
+ value = value.view(-1, layer.num_kv_heads, layer.head_size)
371
+ # TODO: Remove this contiguous in the future.
372
+ value = value.contiguous()
373
+
374
+ if kv_cache[0].numel() > 0:
375
+ # if key_cache is None:
376
+ key_cache, value_cache = kv_cache[0], kv_cache[1]
377
+ slots = attn_metadata.slot_mapping
378
+
379
+ block_size = key_cache.shape[1]
380
+ slots_indices = slots.reshape(-1, 1)
381
+ block_indices = slots_indices // block_size
382
+ slots_indices = slots_indices % block_size
383
+ indices = torch.cat((block_indices, slots_indices), dim=1)
384
+
385
+ # C8
386
+ torch_npu.npu_scatter_nd_update_(key_cache, indices, quant_key)
387
+ torch_npu.npu_scatter_nd_update_(value_cache, indices, quant_value)
388
+
389
+ # V0-Style scheduler situation.
390
+ if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
391
+ assert attn_metadata is not None
392
+ assert attn_metadata.attn_mask is not None
393
+ mask = attn_metadata.attn_mask
394
+ torch_npu._npu_flash_attention(query=query,
395
+ key=key,
396
+ value=value,
397
+ mask=mask,
398
+ seq_len=attn_metadata.seq_lens,
399
+ scale_value=scale,
400
+ num_heads=layer.num_heads,
401
+ num_kv_heads=layer.num_kv_heads,
402
+ out=output.reshape(query.shape))
403
+
404
+ elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
405
+ raise NotImplementedError("kv cache int8 are not "
406
+ "implemented for "
407
+ "PrefillCacheHit")
408
+ elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # changed attn_metadata.attn_state == AscendAttentionState.DecodeOnly
409
+ if hasattr(attn_metadata, "decode"):
410
+ # torch_air
411
+ decode_meta = attn_metadata.decode
412
+ seq_lens = decode_meta.seq_lens_list
413
+ else:
414
+ seq_lens = attn_metadata.seq_lens
415
+ block_size = key_cache.shape[1]
416
+ query = query.view(num_tokens, 1, layer.num_heads *
417
+ layer.head_size).contiguous() # changed
418
+
419
+ # [num_blocks, block_size, N, D] --> [num_blocks, N, block_size, D]
420
+ key = key_cache
421
+ value = value_cache
422
+
423
+ output = torch_npu.npu_incre_flash_attention(
424
+ query,
425
+ key,
426
+ value,
427
+ num_key_value_heads=layer.num_kv_heads,
428
+ num_heads=layer.num_heads,
429
+ actual_seq_lengths=seq_lens,
430
+ scale_value=scale,
431
+ input_layout='BSH',
432
+ block_size=block_size,
433
+ block_table=attn_metadata.block_tables,
434
+ antiquant_scale=self.antiquant_scale_comb,
435
+ )
436
+
437
+ # Normal V1 situation.
438
+ else:
439
+ raise NotImplementedError("kv cache int8 are not "
440
+ "implemented for "
441
+ "other case")
442
+ return output
443
+
444
+
445
+ def fused_experts_310p(
446
+ hidden_states: torch.Tensor,
447
+ w1: torch.Tensor,
448
+ w1_scale: torch.Tensor,
449
+ w1_input_scale: torch.Tensor,
450
+ w2: torch.Tensor,
451
+ w2_scale: torch.Tensor,
452
+ w2_input_scale: torch.Tensor,
453
+ topk_weights: torch.Tensor,
454
+ topk_ids: torch.Tensor,
455
+ top_k: int,
456
+ global_num_experts: int,
457
+ expert_map: torch.Tensor = None,
458
+ ) -> torch.Tensor:
459
+ ep_size = get_ep_group().world_size
460
+ local_num_experts = global_num_experts // ep_size
461
+ local_num_group = top_k // ep_size
462
+
463
+ bsz, _ = hidden_states.shape
464
+ flatten_topk_ids = topk_ids.view(-1)
465
+ sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
466
+ sorted_topk_ids = sorted_topk_ids.to(torch.int32)
467
+ sorted_hidden_states = hidden_states.index_select(
468
+ 0, sorted_topk_ids // local_num_group)
469
+
470
+ experts_id = torch.arange(0,
471
+ local_num_experts,
472
+ dtype=topk_ids.dtype,
473
+ device=topk_ids.device)
474
+ num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
475
+ torch.float32).sum(0)
476
+ topk_scales = topk_weights.view(-1).index_select(
477
+ 0, sorted_topk_ids).unsqueeze(-1)
478
+ group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
479
+
480
+ gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant(
481
+ x=sorted_hidden_states,
482
+ quantized_weight=w1,
483
+ weight_scale=w1_scale,
484
+ group_list=group_list,
485
+ x_scale=w1_input_scale,
486
+ quant_mode="pertensor")
487
+
488
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
489
+ torch.float16)
490
+ gate_up_out *= topk_scales
491
+
492
+ down_out = torch_npu.npu_quant_grouped_matmul_dequant(
493
+ x=gate_up_out,
494
+ quantized_weight=w2,
495
+ weight_scale=w2_scale,
496
+ group_list=group_list,
497
+ x_scale=w2_input_scale,
498
+ quant_mode="pertensor")
499
+
500
+ unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
501
+ unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids)
502
+ final_hidden_states = unsorted_hidden_states.reshape(
503
+ bsz, top_k // ep_size, -1).sum(1)
504
+
505
+ return final_hidden_states
506
+
507
+
508
+ def fused_experts(
509
+ hidden_states: torch.Tensor,
510
+ w1: torch.Tensor,
511
+ w1_scale: torch.Tensor,
512
+ w1_input_scale: torch.Tensor,
513
+ w1_input_offset: torch.Tensor,
514
+ w2: torch.Tensor,
515
+ w2_scale: torch.Tensor,
516
+ w2_input_scale: torch.Tensor,
517
+ w2_input_offset: torch.Tensor,
518
+ topk_weights: torch.Tensor,
519
+ topk_ids: torch.Tensor,
520
+ top_k: int,
521
+ global_num_experts: int,
522
+ expert_map: torch.Tensor = None,
523
+ ) -> torch.Tensor:
524
+ """
525
+ Fused experts with top-k routing.
526
+
527
+ Args:
528
+ hidden_states: Hidden states of shape (num_tokens, hidden_size).
529
+ w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
530
+ w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
531
+ topk_weights: Routing weights of shape (num_tokens, top_k).
532
+ topk_ids: Selected expert IDs of shape (num_tokens, top_k).
533
+ top_k: Number of experts to select.
534
+ expert_map: Expert mapping of shape (num_experts,).
535
+
536
+ Returns:
537
+ hidden_states: Hidden states after routing.
538
+ """
539
+ """
540
+ # Check constraints.
541
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
542
+ assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
543
+ assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
544
+ assert w1.is_contiguous(), "Expert weights1 must be contiguous"
545
+ assert w2.is_contiguous(), "Expert weights2 must be contiguous"
546
+ """
547
+
548
+ original_dtype = hidden_states.dtype
549
+ ep_size = get_ep_group().world_size
550
+ local_num_experts = global_num_experts // ep_size
551
+ w1_input_scale, _ = w1_input_scale.max(0)
552
+ quant_sorted_hidden_states = quant_per_tensor(
553
+ hidden_states,
554
+ w1_input_scale,
555
+ None,
556
+ True,
557
+ )
558
+ if expert_map is not None:
559
+ expanded_x, expanded_row_idx, expert_token_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
560
+ quant_sorted_hidden_states,
561
+ topk_ids,
562
+ scale=None,
563
+ active_num=topk_ids.numel(),
564
+ expert_capacity=-1,
565
+ expert_num=local_num_experts,
566
+ drop_pad_mode=0,
567
+ expert_tokens_num_type=1,
568
+ expert_tokens_num_flag=True,
569
+ quant_mode=-1,
570
+ active_expert_range=[0, local_num_experts],
571
+ row_idx_type=0,
572
+ )
573
+
574
+ else:
575
+ raise NotImplementedError(
576
+ "The quantified version of MOE class models "
577
+ "currently does not support tensor parallelism")
578
+ if expanded_x.dtype != w1.dtype:
579
+ w1_input_scale, _ = w1_input_scale.max(0)
580
+ quant_sorted_hidden_states = quant_per_tensor(
581
+ expanded_x,
582
+ w1_input_scale,
583
+ None,
584
+ True,
585
+ )
586
+ else:
587
+ quant_sorted_hidden_states = expanded_x
588
+ gate_up_out = torch_npu.npu_grouped_matmul(
589
+ x=[quant_sorted_hidden_states],
590
+ weight=[w1],
591
+ scale=[w1_scale * w1_input_scale[0]],
592
+ split_item=2,
593
+ group_list_type=1,
594
+ group_type=0,
595
+ group_list=expert_token_count,
596
+ output_dtype=original_dtype,
597
+ )[0]
598
+ gate_up_out = torch_npu.npu_swiglu(gate_up_out)
599
+
600
+ if gate_up_out.dtype != w2.dtype:
601
+ w2_input_scale, _ = w2_input_scale.max(0)
602
+ quant_gate_up_out = quant_per_tensor(
603
+ gate_up_out,
604
+ w2_input_scale,
605
+ None,
606
+ True,
607
+ )
608
+ else:
609
+ quant_gate_up_out = gate_up_out
610
+
611
+ down_out = torch_npu.npu_grouped_matmul(
612
+ x=[quant_gate_up_out],
613
+ weight=[w2],
614
+ scale=[w2_scale * w2_input_scale[0]],
615
+ split_item=2,
616
+ group_list_type=1,
617
+ group_type=0,
618
+ group_list=expert_token_count,
619
+ output_dtype=original_dtype,
620
+ )[0]
621
+
622
+ if expert_map is not None:
623
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
624
+ down_out,
625
+ skip1=None,
626
+ skip2=None,
627
+ bias=None,
628
+ scales=topk_weights.to(down_out.dtype),
629
+ expanded_src_to_dst_row=expanded_row_idx,
630
+ export_for_source_row=topk_ids,
631
+ drop_pad_mode=2,
632
+ )
633
+ else:
634
+ raise NotImplementedError(
635
+ "The quantified version of MOE class models "
636
+ "currently does not support tensor parallelism")
637
+
638
+ return final_hidden_states
639
+
640
+
641
+ def select_experts(
642
+ hidden_states: torch.Tensor,
643
+ router_logits: torch.Tensor,
644
+ top_k: int,
645
+ use_grouped_topk: bool,
646
+ renormalize: bool,
647
+ topk_group: Optional[int] = None,
648
+ num_expert_group: Optional[int] = None,
649
+ custom_routing_function: Optional[Callable] = None,
650
+ scoring_func: str = "softmax",
651
+ e_score_correction_bias: Optional[torch.Tensor] = None,
652
+ global_num_experts=-1,
653
+ ) -> tuple[torch.Tensor, torch.Tensor]:
654
+ """
655
+ Select top-k experts based on router logits.
656
+
657
+ Args:
658
+ hidden_states: Hidden states of shape (num_tokens, hidden_size).
659
+ router_logits: Router logits of shape (num_tokens, num_experts).
660
+ top_k: Number of experts to select.
661
+ use_grouped_topk: Whether to group experts before selecting top-k.
662
+ renormalize: Whether to renormalize the routing weights.
663
+ topk_group: Number of expert groups to select from.
664
+ num_expert_group: Number of experts in each group.
665
+ custom_routing_function: Custom routing function.
666
+ scoring_func: Scoring function to use.
667
+ e_score_correction_bias: Correction bias to apply to expert scores.
668
+
669
+ Returns:
670
+ topk_weights: Routing weights of shape (num_tokens, top_k).
671
+ topk_ids: Selected expert IDs of shape (num_tokens, top_k).
672
+
673
+ Raises:
674
+ ValueError: If an unsupported scoring function is provided.
675
+ """
676
+
677
+ if scoring_func == "softmax":
678
+ # NOTE: vLLM use dtype=torch.float here
679
+ topk_weights = router_logits.softmax(dim=-1)
680
+ elif scoring_func == "sigmoid":
681
+ topk_weights = router_logits.sigmoid()
682
+ else:
683
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
684
+
685
+ if use_grouped_topk:
686
+ assert topk_group is not None
687
+ assert num_expert_group is not None
688
+
689
+ if e_score_correction_bias is not None:
690
+ # Store original scores before applying correction bias. We use biased
691
+ # scores for expert selection but original scores for routing weights
692
+ original_weights = topk_weights
693
+ topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
694
+
695
+ # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
696
+ # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
697
+ topk_weights = native_grouped_topk(topk_weights, num_expert_group,
698
+ topk_group)
699
+ # TODO bfloat16 is not supported in torch.topk with ge graph.
700
+ if e_score_correction_bias is not None:
701
+ topk_ids = torch.topk(topk_weights.to(torch.float32),
702
+ k=top_k,
703
+ dim=-1,
704
+ sorted=False)[1]
705
+ # Use original unbiased scores for the routing weights
706
+ topk_weights = original_weights.gather(1, topk_ids)
707
+ else:
708
+ topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
709
+ k=top_k,
710
+ dim=-1,
711
+ sorted=False)
712
+ elif custom_routing_function is None:
713
+ topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
714
+ else:
715
+ topk_weights, topk_ids = custom_routing_function(
716
+ hidden_states=hidden_states,
717
+ gating_output=router_logits,
718
+ topk=top_k,
719
+ renormalize=renormalize,
720
+ global_num_experts=global_num_experts,
721
+ )
722
+ # Required by npu_moe_init_routing
723
+ topk_ids = topk_ids.to(torch.int32)
724
+ return topk_weights, topk_ids
725
+
726
+ # Required by npu_moe_init_routing
727
+ topk_ids = topk_ids.to(torch.int32)
728
+
729
+ if renormalize:
730
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
731
+
732
+ return topk_weights, topk_ids
733
+
734
+
735
+ def native_grouped_topk(
736
+ topk_weights: torch.Tensor,
737
+ num_expert_group: Optional[int],
738
+ topk_group: Optional[int],
739
+ ):
740
+ topk_group = 0 if topk_group is None else topk_group
741
+ num_expert_group = 0 if num_expert_group is None else num_expert_group
742
+
743
+ num_token = topk_weights.shape[0]
744
+ grouped_weights = topk_weights.view(num_token, num_expert_group,
745
+ -1).max(dim=-1).values
746
+ topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
747
+ k=topk_group,
748
+ dim=-1,
749
+ sorted=False)[1]
750
+ topk_group_mask = torch.zeros_like(grouped_weights)
751
+ topk_group_mask.scatter_(1, topk_group_indices, 1)
752
+ topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
753
+ num_token, num_expert_group,
754
+ topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
755
+ topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
756
+
757
+ return topk_weights
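Illustrative sketch (not part of the uploaded file): grouped top-k routing on a tiny random tensor, mirroring the masking done by native_grouped_topk above; all sizes are made up.

    import torch

    num_tokens, num_expert_group, experts_per_group, topk_group, top_k = 2, 4, 4, 2, 2
    scores = torch.rand(num_tokens, num_expert_group * experts_per_group)

    # 1) score each group by its best expert, keep the topk_group best groups
    group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = group_scores.topk(topk_group, dim=-1, sorted=False).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)

    # 2) zero out experts of the discarded groups, then take the per-token top-k experts
    expert_mask = group_mask.unsqueeze(-1).expand(-1, -1, experts_per_group).reshape(num_tokens, -1)
    masked_scores = scores.masked_fill(~expert_mask.bool(), 0.0)
    topk_weights, topk_ids = masked_scores.topk(top_k, dim=-1)
    print(topk_ids.shape)  # torch.Size([2, 2])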
inference/vllm_ascend/quantization/w8a8_dynamic.py ADDED
@@ -0,0 +1,831 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # This file is a part of the vllm-ascend project.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+ import torch_npu
23
+ from vllm.distributed import GroupCoordinator
24
+
25
+ import vllm_ascend.envs as envs
26
+ from vllm_ascend.ascend_config import get_ascend_config
27
+ from vllm_ascend.distributed.parallel_state import get_ep_group
28
+ from vllm_ascend.ops.fused_moe import select_experts
29
+ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState,
30
+ dispose_tensor, get_fused_moe_state,
31
+ npu_stream_switch, npu_wait_tensor)
32
+
33
+
34
+ def apply_mlp(hidden_states: torch.Tensor,
35
+ w1: torch.Tensor,
36
+ w1_scale: torch.Tensor,
37
+ w2: torch.Tensor,
38
+ w2_scale: torch.Tensor,
39
+ group_list: torch.Tensor,
40
+ dynamic_scale: torch.Tensor = None,
41
+ group_list_type: int = 1) -> torch.Tensor:
42
+ """
43
+ apply MLP: gate_up_proj -> swiglu -> down_proj
44
+
45
+ Args:
46
+ hidden_states: input hidden states with shape (num_tokens, hidden_size).
47
+ w1: expert weights1 with shape
48
+ (num_experts, hidden_size, intermediate_size * 2)
49
+ w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
50
+ w2: expert weights2 with shape
51
+ (num_experts, intermediate_size, hidden_size)
52
+ w2_scale: weights2 scale with shape (num_experts, hidden_size)
53
+ group_list: number of tokens for each expert, follow cumsum mode, and
54
+ with shape (num_experts).
55
+ transpose_weight:
56
+ w1: (num_experts, intermediate_size * 2, hidden_size) ->
57
+ (num_experts, hidden_size, intermediate_size * 2)
58
+ w2: (num_experts, hidden_size, intermediate_size) ->
59
+ (num_experts, intermediate_size, hidden_size)
60
+
61
+ Returns:
62
+ hidden_states: output hidden states after MLP.
63
+ """
64
+
65
+ if dynamic_scale is None:
66
+ unquantized_hidden_states = hidden_states
67
+ hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
68
+ hidden_states)
69
+ # Dispose the original unquantized hidden states
70
+ # to save npu memory because they're no longer used.
71
+ dispose_tensor(unquantized_hidden_states)
72
+ else:
73
+ pertoken_scale = dynamic_scale
74
+
75
+ # gmm1: gate_up_proj
76
+ hidden_states = torch_npu.npu_grouped_matmul(
77
+ x=[hidden_states],
78
+ weight=[w1],
79
+ scale=[w1_scale],
80
+ per_token_scale=[pertoken_scale],
81
+ split_item=2,
82
+ group_list_type=group_list_type,
83
+ group_type=0,
84
+ group_list=group_list,
85
+ output_dtype=w2_scale.dtype)[0]
86
+
87
+ # act_fn: swiglu
88
+ hidden_states = torch_npu.npu_swiglu(hidden_states)
89
+ hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
90
+ hidden_states)
91
+
92
+ # gmm2: down_proj
93
+ hidden_states = torch_npu.npu_grouped_matmul(
94
+ x=[hidden_states],
95
+ weight=[w2],
96
+ scale=[w2_scale],
97
+ per_token_scale=[swiglu_out_scale],
98
+ split_item=2,
99
+ group_list_type=group_list_type,
100
+ group_type=0,
101
+ group_list=group_list,
102
+ output_dtype=w2_scale.dtype)[0]
103
+
104
+ return hidden_states
105
+
106
+
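Illustrative sketch (not part of the uploaded file): a float (unquantized) CPU reference of the gate_up_proj -> SwiGLU -> down_proj pipeline that apply_mlp runs per expert via grouped matmuls. Treating the first half of the gate_up output as the gate mirrors activate_left=True and should be read as an assumption.

    import torch

    tokens, hidden, inter = 4, 16, 32
    x = torch.randn(tokens, hidden)
    w_gate_up = torch.randn(hidden, inter * 2)   # per-expert weight, i.e. w1[e] above
    w_down = torch.randn(inter, hidden)          # per-expert weight, i.e. w2[e] above

    gate_up = x @ w_gate_up
    gate, up = gate_up.chunk(2, dim=-1)
    activated = torch.nn.functional.silu(gate) * up   # SwiGLU
    out = activated @ w_down
    assert out.shape == (tokens, hidden)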
107
+ def fused_experts_with_mc2(
108
+ hidden_states: torch.Tensor,
109
+ w1: torch.Tensor,
110
+ w2: torch.Tensor,
111
+ w1_scale: torch.Tensor,
112
+ w2_scale: torch.Tensor,
113
+ topk_weights: torch.Tensor,
114
+ topk_ids: torch.Tensor,
115
+ top_k: int,
116
+ expert_map: torch.Tensor = None,
117
+ moe_all_to_all_group_name: str = "",
118
+ log2phy: torch.Tensor = None,
119
+ global_redundant_expert_num: int = 0,
120
+ shared_experts: Optional[Any] = None,
121
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
122
+ if log2phy is not None:
123
+ topk_ids = log2phy[topk_ids]
124
+ global_bs = 0
125
+ moe_expert_num = len(expert_map) + global_redundant_expert_num
126
+ # hidden_states = hidden_states.bfloat16()
127
+ kwargs_mc2 = {
128
+ "x": hidden_states,
129
+ "expert_ids": topk_ids,
130
+ "expert_shard_type": 0,
131
+ "shared_expert_rank_num": 0,
132
+ "moe_expert_num": moe_expert_num,
133
+ "global_bs": global_bs,
134
+ "expert_scales": topk_weights.to(torch.float32),
135
+ }
136
+
137
+ rank = torch.distributed.get_rank()
138
+
139
+ quant_mode = 2
140
+ ep_group = get_ep_group().device_group
141
+ local_rank = torch.distributed.get_rank(group=ep_group)
142
+ all_to_all_group_size = torch.distributed.get_world_size(ep_group)
143
+
144
+ world_size = torch.distributed.get_world_size()
145
+ tp_size = world_size // all_to_all_group_size
146
+ tp_rank = rank % tp_size
147
+
148
+ stage1_kwargs = {
149
+ "scales": None,
150
+ "quant_mode": quant_mode,
151
+ "group_ep": moe_all_to_all_group_name,
152
+ "ep_world_size": all_to_all_group_size,
153
+ "ep_rank_id": local_rank,
154
+ # "group_tp": self.moe_rs_group_name,
155
+ "group_tp": moe_all_to_all_group_name,
156
+ "tp_world_size": tp_size,
157
+ "tp_rank_id": tp_rank,
158
+ }
159
+ kwargs_mc2.update(stage1_kwargs)
160
+
161
+ output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
162
+ # comm_stream.wait_stream(torch.npu.current_stream())
163
+ expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[
164
+ 0:7]
165
+
166
+ if shared_experts is not None:
167
+ with npu_stream_switch("moe_secondary", 0):
168
+ npu_wait_tensor(hidden_states, topk_weights)
169
+ shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
170
+ npu_wait_tensor(shared_gate_up[0], expand_x)
171
+ shared_act = shared_experts.act_fn(shared_gate_up)
172
+
173
+ # `expand_x` will be disposed in the `apply_mlp` function
174
+ down_out_list = apply_mlp(expand_x,
175
+ w1,
176
+ w1_scale,
177
+ w2,
178
+ w2_scale,
179
+ expert_token_nums,
180
+ dynamic_scale=dynamic_scale)
181
+
182
+ # moeCombine
183
+ kwargs_mc2 = {
184
+ "expand_x": down_out_list,
185
+ "expert_ids": topk_ids,
186
+ "expand_idx": expand_idx,
187
+ "expert_scales": topk_weights.to(torch.float32),
188
+ "expert_shard_type": 0,
189
+ "shared_expert_rank_num": 0,
190
+ "moe_expert_num": moe_expert_num,
191
+ "global_bs": 0,
192
+ "expand_scales": expand_scales,
193
+ }
194
+ tp_recv_counts = torch.empty(1,
195
+ dtype=torch.int32,
196
+ device=hidden_states.device)
197
+ stage3_kwargs = {
198
+ "ep_send_counts": ep_recv_counts,
199
+ "group_ep": moe_all_to_all_group_name,
200
+ "ep_world_size": all_to_all_group_size,
201
+ "ep_rank_id": local_rank,
202
+ "tp_send_counts": tp_recv_counts,
203
+ # "group_tp": self.moe_rs_group_name,
204
+ "group_tp": moe_all_to_all_group_name,
205
+ "tp_world_size": tp_size,
206
+ "tp_rank_id": tp_rank,
207
+ }
208
+ kwargs_mc2.update(stage3_kwargs)
209
+
210
+ hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
211
+
212
+ if shared_experts is None:
213
+ return hidden_states
214
+ else:
215
+ with npu_stream_switch("moe_secondary", 0):
216
+ npu_wait_tensor(shared_act[0], down_out_list)
217
+ shared_output, _ = shared_experts.down_proj(shared_act)
218
+ return hidden_states, shared_output
219
+
220
+
221
+ # currently expert parallelism implemented with all2all
222
+ # is under-optimized.
223
+ def fused_experts_with_all2all(
224
+ hidden_states: torch.Tensor,
225
+ w1: torch.Tensor,
226
+ w1_scale: torch.Tensor,
227
+ w2: torch.Tensor,
228
+ w2_scale: torch.Tensor,
229
+ topk_weights: torch.Tensor,
230
+ topk_ids: torch.Tensor,
231
+ top_k: int,
232
+ expert_map: torch.Tensor = None,
233
+ ep_group: GroupCoordinator = None,
234
+ log2phy: torch.Tensor = None,
235
+ global_redundant_expert_num: int = 0,
236
+ ):
237
+ if log2phy is not None:
238
+ topk_ids = log2phy[topk_ids]
239
+ original_shape = hidden_states.shape
240
+ if len(original_shape) == 3:
241
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
242
+
243
+ num_tokens, _ = hidden_states.shape
244
+ num_experts = w1.shape[0]
245
+ device = hidden_states.device
246
+
247
+ if expert_map is not None:
248
+ global_num_experts = len(expert_map) + global_redundant_expert_num
249
+ local_num_experts = global_num_experts // ep_group.world_size
250
+ row_idx_len = num_tokens * top_k
251
+ row_idx = (torch.arange(0,
252
+ row_idx_len,
253
+ dtype=torch.int32,
254
+ device=device).view(top_k, -1).permute(
255
+ 1, 0).contiguous())
256
+ hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
257
+ hidden_states,
258
+ row_idx=row_idx,
259
+ expert_idx=topk_ids,
260
+ active_num=num_tokens)
261
+
262
+ global_expert_tokens = torch.bincount(expanded_expert_idx,
263
+ minlength=global_num_experts)
264
+ scatter_sizes = global_expert_tokens.view(ep_group.world_size,
265
+ -1).sum(-1)
266
+
267
+ gather_sizes = torch.empty_like(scatter_sizes)
268
+ dist.all_to_all_single(gather_sizes,
269
+ scatter_sizes,
270
+ group=ep_group.device_group)
271
+ scatter_size_list = scatter_sizes.cpu().tolist()
272
+ gather_size_list = gather_sizes.cpu().tolist()
273
+
274
+ expanded_expert_idx = expanded_expert_idx % local_num_experts
275
+ hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
276
+ scatter_size_list,
277
+ gather_size_list)
278
+ local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
279
+ scatter_size_list,
280
+ gather_size_list)
281
+
282
+ sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
283
+
284
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
285
+ sorted_local_expert_idx, local_num_experts).to(torch.int64)
286
+
287
+ hidden_states = hidden_states[sorted_idx]
288
+ group_list_type = 0
289
+ else:
290
+ row_idx_len = num_tokens * top_k
291
+ row_idx = torch.arange(0,
292
+ row_idx_len,
293
+ dtype=torch.int32,
294
+ device=topk_weights.device).view(
295
+ top_k, -1).permute(1, 0).contiguous()
296
+ hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
297
+ hidden_states,
298
+ row_idx=row_idx,
299
+ expert_idx=topk_ids,
300
+ active_num=num_tokens)
301
+
302
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
303
+ expanded_expert_idx, num_experts)
304
+ expert_tokens = expert_tokens.to(torch.int64)
305
+ group_list_type = 0
306
+
307
+ # `hidden_states` will be disposed in the `apply_mlp` function
308
+ hidden_states = apply_mlp(
309
+ hidden_states,
310
+ w1,
311
+ w1_scale, #17
312
+ w2,
313
+ w2_scale,
314
+ expert_tokens, #16
315
+ group_list_type=group_list_type)
316
+
317
+ if expert_map is not None:
318
+ resorted_idx = torch.argsort(sorted_idx)
319
+ hidden_states = hidden_states[resorted_idx]
320
+ hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
321
+ gather_size_list,
322
+ scatter_size_list)
323
+
324
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
325
+ hidden_states,
326
+ skip1=None,
327
+ skip2=None,
328
+ bias=None,
329
+ scales=topk_weights,
330
+ expanded_src_to_dst_row=expanded_row_idx,
331
+ export_for_source_row=topk_ids,
332
+ )
333
+ else:
334
+ # TODO: Reorder device memory 2 times here, replace the current
335
+ # implementation here when suitable operators become available.
336
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
337
+ hidden_states,
338
+ skip1=None,
339
+ skip2=None,
340
+ bias=None,
341
+ scales=topk_weights,
342
+ expanded_src_to_dst_row=expanded_row_idx,
343
+ export_for_source_row=topk_ids,
344
+ )
345
+ if len(original_shape) == 3:
346
+ final_hidden_states = final_hidden_states.view(original_shape)
347
+ return final_hidden_states
348
+
349
+
350
+ def fused_experts_with_allgather(hidden_states: torch.Tensor,
351
+ w1: torch.Tensor,
352
+ w1_scale: torch.Tensor,
353
+ w2: torch.Tensor,
354
+ w2_scale: torch.Tensor,
355
+ topk_weights: torch.Tensor,
356
+ topk_ids: torch.Tensor,
357
+ top_k: int,
358
+ expert_map: torch.Tensor = None):
359
+ original_shape = hidden_states.shape
360
+ if len(original_shape) == 3:
361
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
362
+ num_tokens = hidden_states.shape[0]
363
+ batch_size, hidden_size = hidden_states.shape
364
+
365
+ ep_group = get_ep_group().device_group
366
+ ep_rank = torch.distributed.get_rank(group=ep_group)
367
+ ep_size = torch.distributed.get_world_size(ep_group)
368
+
369
+ global_num_experts = len(expert_map)
370
+ local_num_experts = global_num_experts // ep_size
371
+
372
+ hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
373
+
374
+ hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
375
+ hidden_states,
376
+ topk_ids,
377
+ scale=pertoken_scale,
378
+ offset=None,
379
+ active_num=num_tokens * top_k,
380
+ expert_num=global_num_experts,
381
+ expert_tokens_num_type=1,
382
+ expert_tokens_num_flag=True,
383
+ active_expert_range=[
384
+ ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts
385
+ ],
386
+ quant_mode=-1,
387
+ row_idx_type=0)
388
+ group_list_type = 1
389
+
390
+
391
+ hidden_states = torch_npu.npu_grouped_matmul(
392
+ x=[hidden_states],
393
+ weight=[w1],
394
+ split_item=3,
395
+ group_list_type=group_list_type,
396
+ group_type=0,
397
+ group_list=expert_tokens,
398
+ output_dtype=torch.int32)[0]
399
+
400
+ # act_fn: swiglu
401
+ hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
402
+ x=hidden_states,
403
+ weight_scale=w1_scale.to(torch.float32),
404
+ activation_scale=pertoken_scale,
405
+ bias=None,
406
+ quant_scale=None,
407
+ quant_offset=None,
408
+ group_index=expert_tokens,
409
+ activate_left=True,
410
+ quant_mode=1,
411
+ )
412
+
413
+ hidden_states = torch_npu.npu_grouped_matmul(
414
+ x=[hidden_states],
415
+ weight=[w2],
416
+ scale=[w2_scale.to(torch.bfloat16)],
417
+ per_token_scale=[pertoken_scale.view(-1)],
418
+ split_item=3,
419
+ group_list_type=group_list_type,
420
+ group_type=0,
421
+ group_list=expert_tokens,
422
+ output_dtype=torch.bfloat16)[0]
423
+
424
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
425
+ expanded_permuted_rows=hidden_states.unsqueeze(1),
426
+ skip1=None,
427
+ skip2=None,
428
+ bias=None,
429
+ scales=topk_weights.to(torch.bfloat16),
430
+ expanded_src_to_dst_row=expanded_x_idx.to(torch.int32),
431
+ export_for_source_row=topk_ids,
432
+ drop_pad_mode=3
433
+ ).to(torch.bfloat16)
434
+
435
+ if len(original_shape) == 3:
436
+ final_hidden_states = final_hidden_states.view(original_shape)
437
+
438
+ return final_hidden_states
439
+
440
+
441
+ def fused_experts(hidden_states: torch.Tensor,
442
+ w1: torch.Tensor,
443
+ w1_scale: torch.Tensor,
444
+ w2: torch.Tensor,
445
+ w2_scale: torch.Tensor,
446
+ topk_weights: torch.Tensor,
447
+ topk_ids: torch.Tensor,
448
+ top_k: int,
449
+ expert_map: torch.Tensor = None):
450
+ original_shape = hidden_states.shape
451
+ if len(original_shape) == 3:
452
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
453
+
454
+ num_tokens, _ = hidden_states.shape
455
+ num_experts = w1.shape[0]
456
+ dtype = hidden_states.dtype
457
+ device = hidden_states.device
458
+
459
+ if expert_map is not None:
460
+ # Generate token indices and flatten
461
+ token_indices = (torch.arange(num_tokens,
462
+ device=device,
463
+ dtype=torch.int64).unsqueeze(1).expand(
464
+ -1, top_k).reshape(-1))
465
+
466
+ # Flatten token-to-expert mappings and map to local experts
467
+ weights_flat = topk_weights.view(-1)
468
+ experts_flat = topk_ids.view(-1)
469
+ local_experts_flat = expert_map[experts_flat]
470
+
471
+ # Filter valid token-expert pairs
472
+ mask = local_experts_flat != -1
473
+ filtered_weights = torch.where(
474
+ mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
475
+ filtered_experts = torch.where(
476
+ mask, local_experts_flat,
477
+ torch.full_like(local_experts_flat,
478
+ num_experts)).to(topk_ids.dtype)
479
+
480
+ # Sort by local expert IDs
481
+ sort_indices = torch.argsort(filtered_experts)
482
+ sorted_token_indices = token_indices[sort_indices]
483
+ sorted_weights = filtered_weights[sort_indices]
484
+
485
+ # Compute token counts with minlength of num_experts
486
+ # This is equivalent to but faster than:
487
+ # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
488
+ token_counts = torch.zeros(num_experts + 1,
489
+ device=device,
490
+ dtype=torch.int64)
491
+ ones = torch.ones_like(filtered_experts, dtype=torch.int64)
492
+ token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
493
+ expert_tokens = token_counts[:num_experts]
494
+ # Rearrange hidden_states
495
+ hidden_states = hidden_states[sorted_token_indices]
496
+ group_list_type = 1
497
+ else:
498
+ row_idx_len = num_tokens * top_k
499
+ row_idx = torch.arange(0,
500
+ row_idx_len,
501
+ dtype=torch.int32,
502
+ device=topk_weights.device).view(
503
+ top_k, -1).permute(1, 0).contiguous()
504
+ hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
505
+ hidden_states,
506
+ row_idx=row_idx,
507
+ expert_idx=topk_ids,
508
+ active_num=num_tokens)
509
+
510
+ expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
511
+ expanded_expert_idx, num_experts)
512
+ expert_tokens = expert_tokens.to(torch.int64)
513
+ group_list_type = 0
514
+
515
+ # `hidden_states` will be disposed in the `apply_mlp` function
516
+ hidden_states = apply_mlp(hidden_states,
517
+ w1,
518
+ w1_scale,
519
+ w2,
520
+ w2_scale,
521
+ expert_tokens,
522
+ group_list_type=group_list_type)
523
+
524
+ if expert_map is not None:
525
+ hidden_states.mul_(sorted_weights.unsqueeze(1))
526
+ final_hidden_states = torch.zeros(*original_shape,
527
+ device=device,
528
+ dtype=dtype)
529
+
530
+ num_valid_tokens = mask.sum()
531
+ valid_token_mask = torch.arange(
532
+ 0, sorted_token_indices.shape[0],
533
+ device=device).unsqueeze(1) < num_valid_tokens
534
+ hidden_states = hidden_states.masked_fill_(~valid_token_mask,
535
+ 0).to(dtype)
536
+ final_hidden_states.index_add_(0, sorted_token_indices, hidden_states)
537
+ else:
538
+ # TODO: Reorder device memory 2 times here, replace the current
539
+ # implementation here when suitable operators become available.
540
+ final_hidden_states = torch_npu.npu_moe_finalize_routing(
541
+ hidden_states,
542
+ skip1=None,
543
+ skip2=None,
544
+ bias=None,
545
+ scales=topk_weights,
546
+ expanded_src_to_dst_row=expanded_row_idx,
547
+ export_for_source_row=topk_ids,
548
+ )
549
+
550
+ if len(original_shape) == 3:
551
+ final_hidden_states = final_hidden_states.view(original_shape)
552
+ return final_hidden_states
553
+
554
+
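Illustrative sketch (not part of the uploaded file): the scatter_add_-based per-expert token counting used in the expert_map branch above is equivalent to bincount with a sentinel bucket for tokens routed off-rank; the values below are made up.

    import torch

    num_experts = 4
    filtered_experts = torch.tensor([0, 2, 2, 4, 1, 4])  # 4 == sentinel for "not on this rank"

    counts = torch.zeros(num_experts + 1, dtype=torch.int64)
    counts.scatter_add_(0, filtered_experts, torch.ones_like(filtered_experts))
    expert_tokens = counts[:num_experts]

    assert torch.equal(
        expert_tokens,
        torch.bincount(filtered_experts, minlength=num_experts + 1)[:num_experts])
    print(expert_tokens)  # tensor([1, 1, 2, 0])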
555
+ class AscendW8A8DynamicLinearMethod:
556
+ """Linear method for Ascend W8A8_DYNAMIC.
557
+ """
558
+
559
+ def __init__(self):
560
+ self.transpose_weight = True
561
+
562
+ @staticmethod
563
+ def get_weight(input_size: int, output_size: int,
564
+ params_dtype: torch.dtype) -> Dict[str, Any]:
565
+ params_dict = {
566
+ "weight": torch.empty(output_size, input_size, dtype=torch.int8)
567
+ }
568
+ return params_dict
569
+
570
+ @staticmethod
571
+ def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
572
+ return {}
573
+
574
+ @staticmethod
575
+ def get_perchannel_param(
576
+ output_size: int,
577
+ params_dtype: torch.dtype,
578
+ ) -> Dict[str, Any]:
579
+ params_dict = {}
580
+ params_dict["weight_scale"] = torch.empty(output_size,
581
+ 1,
582
+ dtype=params_dtype)
583
+ params_dict["weight_offset"] = torch.empty(output_size,
584
+ 1,
585
+ dtype=params_dtype)
586
+ return params_dict
587
+
588
+ @staticmethod
589
+ def apply(
590
+ layer: torch.nn.Module,
591
+ x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
592
+ bias: Optional[torch.Tensor] = None,
593
+ tp_rank: Optional[int] = 0,
594
+ ) -> torch.Tensor:
595
+ config = getattr(layer, "_ascend_quant_config", {})
596
+ if not isinstance(x, tuple):
597
+ output_dtype = config.get("output_dtype", x.dtype)
598
+ quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
599
+ else:
600
+ assert "output_dtype" in config.keys(), (
601
+ f"DynamicLinearMethod needs explicitly specified `output_dtype`"
602
+ f"for pre-quantized input, got config [{config}]")
603
+ output_dtype = config["output_dtype"]
604
+ quantized_x, dynamic_scale = x
605
+ pertoken_scale = (dynamic_scale
606
+ if config.get("pertoken_scale", True) else None)
607
+
608
+ output = torch_npu.npu_quant_matmul(
609
+ quantized_x,
610
+ layer.weight,
611
+ layer.weight_scale,
612
+ pertoken_scale=pertoken_scale,
613
+ bias=bias,
614
+ output_dtype=output_dtype,
615
+ )
616
+ return ((output, dynamic_scale)
617
+ if config.get("return_scale", False) else output)
618
+
619
+ def process_weights_after_loading(self, layer):
620
+ if self.transpose_weight:
621
+ layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
622
+ # cast quantized weight tensors in NZ format (29) for higher inference speed
623
+ layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
624
+ layer.weight_scale.data = layer.weight_scale.data.flatten()
625
+ layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
626
+ layer.weight_offset.data = layer.weight_offset.data.flatten()
627
+
628
+
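Illustrative sketch (not part of the uploaded file): a CPU reference of what W8A8 dynamic quantization computes, i.e. the roles of npu_dynamic_quant (per-token activation scales) and npu_quant_matmul (int8 matmul dequantized by per-token times per-channel scales). The rounding scheme here is a simplification, not the NPU kernel.

    import torch

    x = torch.randn(4, 16)                      # activations (num_tokens, in_features)
    w = torch.randn(8, 16)                      # weight (out_features, in_features)

    x_scale = x.abs().amax(dim=1, keepdim=True) / 127.0   # per-token scale
    w_scale = w.abs().amax(dim=1, keepdim=True) / 127.0   # per-channel scale
    x_q = torch.clamp((x / x_scale).round(), -128, 127)
    w_q = torch.clamp((w / w_scale).round(), -128, 127)

    y = (x_q @ w_q.t()) * x_scale * w_scale.t()  # dequantized output, approximates x @ w.t()
    print((y - x @ w.t()).abs().max())           # small quantization error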
629
+ class AscendW8A8DynamicFusedMoEMethod:
630
+ """FusedMoe method for Ascend W8A8_DYNAMIC.
631
+ """
632
+
633
+ def __init__(self):
634
+ self.transpose_weight = True
635
+
636
+ self.ep_group = get_ep_group()
637
+
638
+ ascend_config = get_ascend_config()
639
+ self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
640
+
641
+ try:
642
+ device_group = self.ep_group.device_group
643
+ # TODO: Try local_rank = ep_group.rank_in_group
644
+ local_rank = torch.distributed.get_rank(group=device_group)
645
+ backend = device_group._get_backend(torch.device("npu"))
646
+ self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
647
+ local_rank)
648
+ except AttributeError:
649
+ self.moe_all_to_all_group_name = ""
650
+
651
+ @staticmethod
652
+ def get_weight(num_experts: int, intermediate_size_per_partition: int,
653
+ hidden_sizes: int,
654
+ params_dtype: torch.dtype) -> Dict[str, Any]:
655
+ param_dict = {}
656
+ param_dict["w13_weight"] = torch.empty(num_experts,
657
+ 2 *
658
+ intermediate_size_per_partition,
659
+ hidden_sizes,
660
+ dtype=torch.int8)
661
+ param_dict["w2_weight"] = torch.empty(num_experts,
662
+ hidden_sizes,
663
+ intermediate_size_per_partition,
664
+ dtype=torch.int8)
665
+ return param_dict
666
+
667
+ @staticmethod
668
+ def get_dynamic_quant_param(num_experts: int,
669
+ intermediate_size_per_partition: int,
670
+ hidden_sizes: int,
671
+ params_dtype: torch.dtype) -> Dict[str, Any]:
672
+ param_dict = {}
673
+ param_dict["w13_weight_scale"] = torch.empty(
674
+ num_experts,
675
+ 2 * intermediate_size_per_partition,
676
+ 1,
677
+ dtype=params_dtype)
678
+ param_dict["w13_weight_offset"] = torch.empty(
679
+ num_experts,
680
+ 2 * intermediate_size_per_partition,
681
+ 1,
682
+ dtype=params_dtype)
683
+ param_dict["w2_weight_scale"] = torch.empty(num_experts,
684
+ hidden_sizes,
685
+ 1,
686
+ dtype=params_dtype)
687
+ param_dict["w2_weight_offset"] = torch.empty(num_experts,
688
+ hidden_sizes,
689
+ 1,
690
+ dtype=params_dtype)
691
+ return param_dict
692
+
693
+ def apply(
694
+ self,
695
+ layer: torch.nn.Module,
696
+ x: torch.Tensor,
697
+ router_logits: torch.Tensor,
698
+ top_k: int,
699
+ renormalize: bool,
700
+ use_grouped_topk: bool = False,
701
+ global_num_experts: int = -1,
702
+ expert_map: Optional[torch.Tensor] = None,
703
+ topk_group: Optional[int] = None,
704
+ num_expert_group: Optional[int] = None,
705
+ custom_routing_function: Optional[Callable] = None,
706
+ scoring_func: str = "softmax",
707
+ e_score_correction_bias: Optional[torch.Tensor] = None,
708
+ is_prefill: bool = True,
709
+ enable_force_load_balance: bool = True,
710
+ log2phy: torch.Tensor = None,
711
+ global_redundant_expert_num: int = 0,
712
+ shared_experts: Optional[Any] = None,
713
+ **kwargs,
714
+ ) -> torch.Tensor:
715
+ assert router_logits.shape[
716
+ 1] == global_num_experts, "Number of global experts mismatch"
717
+
718
+ is_deepseek_v3_r1 = global_num_experts == 256
719
+ use_grouped_topk = (topk_group > 1 or num_expert_group > 1)
720
+
721
+ # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
722
+ if use_grouped_topk and is_deepseek_v3_r1:
723
+ topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
724
+ router_logits,
725
+ k=top_k, # top_k is currently set to 8 for this model
726
+ bias=e_score_correction_bias,
727
+ k_group=topk_group, # fix: 4
728
+ group_count=num_expert_group, # fix 8
729
+ group_select_mode=1, # 0: take the max within each group; 1: sum of the top-2 in each group (fixed)
730
+ renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
731
+ norm_type=1, # 0: softmax; 1: sigmoid(fix)
732
+ # out_flag=False, # todo new api; whether to return the third output
733
+ # y2_flag=False, # old api; whether to return the third output
734
+ routed_scaling_factor=1,
735
+ eps=float(1e-20))
736
+ else:
737
+ topk_weights, topk_ids = select_experts(
738
+ hidden_states=x,
739
+ router_logits=router_logits,
740
+ top_k=top_k,
741
+ use_grouped_topk=use_grouped_topk,
742
+ renormalize=renormalize,
743
+ topk_group=topk_group,
744
+ num_expert_group=num_expert_group,
745
+ custom_routing_function=custom_routing_function,
746
+ scoring_func=scoring_func,
747
+ e_score_correction_bias=e_score_correction_bias,
748
+ )
749
+
750
+ # this is a naive implementation of expert load balancing, used to
+ # avoid accumulating too many tokens on a single rank.
752
+ # currently it is only activated when doing profile runs.
753
+ if enable_force_load_balance:
754
+ topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
755
+
756
+ topk_weights = topk_weights.to(x.dtype)
757
+
758
+ fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
759
+ is_prefill, is_deepseek_v3_r1)
760
+ if fused_moe_state == FusedMoEState.AllGatherEP:
761
+ return fused_experts_with_allgather(
762
+ hidden_states=x,
763
+ w1=layer.w13_weight,
764
+ w1_scale=layer.w13_weight_scale,
765
+ w2=layer.w2_weight,
766
+ w2_scale=layer.w2_weight_scale,
767
+ topk_weights=topk_weights,
768
+ topk_ids=topk_ids,
769
+ top_k=top_k,
770
+ expert_map=expert_map)
771
+ elif fused_moe_state == FusedMoEState.MC2:
772
+ return fused_experts_with_mc2(
773
+ hidden_states=x,
774
+ w1=layer.w13_weight,
775
+ w2=layer.w2_weight,
776
+ w1_scale=layer.w13_weight_scale,
777
+ w2_scale=layer.w2_weight_scale,
778
+ topk_weights=topk_weights,
779
+ topk_ids=topk_ids,
780
+ top_k=top_k,
781
+ expert_map=expert_map,
782
+ moe_all_to_all_group_name=self.moe_all_to_all_group_name,
783
+ log2phy=log2phy,
784
+ global_redundant_expert_num=global_redundant_expert_num,
785
+ shared_experts=shared_experts)
786
+ elif fused_moe_state in [
787
+ FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
788
+ ]:
789
+ return fused_experts(hidden_states=x,
790
+ w1=layer.w13_weight,
791
+ w1_scale=layer.w13_weight_scale,
792
+ w2=layer.w2_weight,
793
+ w2_scale=layer.w2_weight_scale,
794
+ topk_weights=topk_weights,
795
+ topk_ids=topk_ids,
796
+ top_k=top_k,
797
+ expert_map=expert_map)
798
+ else:
799
+ # The current implementation of deepseek moe splits hidden_states
800
+ # according to tp_size before they are feed into fused_moe module.
801
+ # Therefore, all2all is needed no matter how dp/tp is set so as to
802
+ # dispatch/combine tokens.
803
+ return fused_experts_with_all2all(
804
+ hidden_states=x,
805
+ w1=layer.w13_weight,
806
+ w1_scale=layer.w13_weight_scale,
807
+ w2=layer.w2_weight,
808
+ w2_scale=layer.w2_weight_scale,
809
+ topk_weights=topk_weights,
810
+ topk_ids=topk_ids,
811
+ top_k=top_k,
812
+ expert_map=expert_map,
813
+ ep_group=self.ep_group,
814
+ log2phy=log2phy,
815
+ global_redundant_expert_num=global_redundant_expert_num,
816
+ )
817
+
818
+ def process_weights_after_loading(self, layer):
819
+ if self.transpose_weight:
820
+ layer.w13_weight.data = layer.w13_weight.data.transpose(
821
+ 1, 2).contiguous()
822
+ layer.w2_weight.data = layer.w2_weight.data.transpose(
823
+ 1, 2).contiguous()
824
+ layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
825
+ layer.w13_weight_scale.data.shape[0], -1)
826
+ layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
827
+ layer.w13_weight_offset.data.shape[0], -1)
828
+ layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
829
+ layer.w2_weight_scale.data.shape[0], -1)
830
+ layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
831
+ layer.w2_weight_offset.data.shape[0], -1)
inference/vllm_ascend/utils.py ADDED
@@ -0,0 +1,563 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # Copyright 2023 The vLLM team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # This file is a part of the vllm-ascend project.
17
+ # Adapted from vllm-project/vllm/vllm/worker/worker.py
18
+ #
19
+
20
+ import atexit
21
+ import fcntl
22
+ import math
23
+ import os
24
+ import shutil
25
+ from contextlib import contextmanager, nullcontext
26
+ from enum import Enum
27
+ from threading import Lock
28
+ from typing import TYPE_CHECKING, List, Tuple
29
+
30
+ import torch
31
+ import torch_npu # noqa: F401
32
+ from packaging.version import InvalidVersion, Version
33
+ from torch_npu.npu.streams import Event
34
+ from vllm.logger import logger
35
+
36
+ import vllm_ascend.envs as envs
37
+ from vllm_ascend.ascend_config import get_ascend_config
38
+
39
+ try:
40
+ # Recent release of torchair has moved these ops to `.scope`.
41
+ from torchair.scope import npu_stream_switch as _npu_stream_switch
42
+ from torchair.scope import npu_wait_tensor as _npu_wait_tensor
43
+ except ImportError:
44
+ from torchair.ops import NpuStreamSwitch as _npu_stream_switch
45
+ from torchair.ops import npu_wait_tensor as _npu_wait_tensor
46
+
47
+ if TYPE_CHECKING:
48
+ from vllm.config import VllmConfig
49
+ else:
50
+ VllmConfig = None
51
+
52
+ # NOTE: Currently, we can only capture 1920 graphs at most,
53
+ # due to the limitation of ACL graph. This number is bounded by
54
+ # the number of streams, which is 2048; we reserve 128 streams
55
+ # as a buffer.
56
+ # Maximum number of graphs that can be captured by ACL Graph
57
+ MAX_CAPTURE_SIZE = 1920
58
+
59
+ ASCEND_QUATIZATION_METHOD = "ascend"
60
+ SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
61
+
62
+ ACL_FORMAT_FRACTAL_ND = 2
63
+ ACL_FORMAT_FRACTAL_NZ = 29
64
+
65
+ _CUSTOM_OP_ENABLED = None
66
+ _IS_310P = None
67
+ _SLEEP_MODE_ENABLED = None
68
+ _CURRENT_STREAM = None
69
+
70
+
71
+ def is_310p():
72
+ global _IS_310P
73
+ if _IS_310P is None:
74
+ from vllm_ascend import _build_info # type: ignore
75
+ _IS_310P = _build_info.__soc_version__.lower().startswith("ascend310p")
76
+ return _IS_310P
77
+
78
+
79
+ def sleep_mode_enabled():
80
+ global _SLEEP_MODE_ENABLED
81
+ if _SLEEP_MODE_ENABLED is None:
82
+ from vllm_ascend import _build_info # type: ignore
83
+ _SLEEP_MODE_ENABLED = _build_info.__sleep_mode_enabled__
84
+ return _SLEEP_MODE_ENABLED
85
+
86
+
87
+ def _round_up(x: int, align: int):
88
+ # round up x to align, for example, if align is 16, x will be rounded up to 16, 32, 48, etc.
89
+ # input: 15, 16 -> output: 16
90
+ # input: 17, 16 -> output: 32
91
+ # input: 30, 16 -> output: 32
92
+ # input: 33, 16 -> output: 48
93
+ # ...
94
+ return (x + align - 1) // align * align
95
+
96
+
97
+ def _custom_pad(x, pad_dims):
98
+ # pad the input tensor to the shape of pad_dims
99
+ # input: (13, 30), pad_dims: [0, 2, 0, 3]
100
+ # output: (16, 32)
101
+ return torch.nn.functional.pad(x, pad_dims)
102
+
103
+
104
+ def _custom_reshape(x, target_shape):
105
+ # reshape the input tensor to the shape of target_shape
106
+ # input: (16, 32), target_shape: [1, 16, 2, 16]
107
+ # output: (1, 16, 2, 16)
108
+ return x.reshape(target_shape)
109
+
110
+
111
+ def _custom_transpose(x, dim1, dim2):
112
+ # transpose the input tensor
113
+ # input: (1, 16, 2, 16), dim1: 1, dim2: 2
114
+ # output: (1, 2, 16, 16)
115
+ return x.transpose(dim1, dim2)
116
+
117
+
118
+ def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
119
+ # in_tensor: (13, 30)
120
+ aux_dims = [1, 0, 0, 16]
121
+ # aux_dims[1]: 16
122
+ aux_dims[1] = _round_up(in_tensor.size(0), 16)
123
+ # aux_dims[2]: 2
124
+ aux_dims[2] = _round_up(in_tensor.size(1), 16) // 16
125
+
126
+ # after: aux_dims: [1, 16, 2, 16]
127
+
128
+ pad_dims = [0, 0, 0, 0]
129
+ # pad_dims[1]: 2
130
+ pad_dims[1] = _round_up(in_tensor.size(1), 16) - in_tensor.size(1)
131
+ # pad_dims[3]: 3
132
+ pad_dims[3] = _round_up(in_tensor.size(0), 16) - in_tensor.size(0)
133
+
134
+ # after: pad_dims: [0, 2, 0, 3]
135
+
136
+ # return: (1, 2, 16, 16)
137
+ return _custom_transpose(
138
+ _custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1,
139
+ 2).contiguous()
140
+
141
+
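Illustrative sketch (not part of the uploaded file): the ND -> NZ transformation performed by nd_to_nz_2d above, traced on the (13, 30) example from its comments.

    import torch

    x = torch.randn(13, 30)
    rows_pad = (x.size(0) + 15) // 16 * 16       # 16
    cols_pad = (x.size(1) + 15) // 16 * 16       # 32

    # pad to multiples of 16, then split the column dimension into 16-wide blocks
    padded = torch.nn.functional.pad(x, [0, cols_pad - x.size(1), 0, rows_pad - x.size(0)])  # (16, 32)
    nz = padded.reshape(1, rows_pad, cols_pad // 16, 16).transpose(1, 2).contiguous()
    print(nz.shape)  # torch.Size([1, 2, 16, 16])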
142
+ def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
143
+ num_tokens = mask_tensor.shape[0]
144
+ max_seq_len = mask_tensor.shape[1]
145
+
146
+ tokens_pad = (num_tokens + 15) // 16 * 16
147
+ max_seq_len_pad = (max_seq_len + 15) // 16 * 16
148
+
149
+ mask_tensor_pad = \
150
+ torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
151
+ mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
152
+ mask = mask_tensor_pad.reshape(
153
+ (1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
154
+ return mask
155
+
156
+
157
+ def aligned_16(tensor: torch.Tensor):
158
+ """Aligned tensor for 310P"""
159
+
160
+ # Get the size of the current 0th dimension
161
+ n = tensor.size(0)
162
+
163
+ # Calculate the aligned size
164
+ n_aligned = ((n + 15) // 16) * 16
165
+
166
+ # If already aligned, return the original tensor
167
+ if n == n_aligned:
168
+ return tensor
169
+
170
+ # Create a new tensor with shape (n_aligned, H, W) and fill it with zeros
171
+ new_tensor = torch.zeros(n_aligned,
172
+ *tensor.shape[1:],
173
+ dtype=tensor.dtype,
174
+ device=tensor.device)
175
+
176
+ # Copy the original tensor to the first N positions of the new tensor
177
+ new_tensor[:n] = tensor
178
+
179
+ return new_tensor
180
+
181
+
182
+ def maybe_converting_weight_acl_format(model, format=ACL_FORMAT_FRACTAL_NZ):
183
+ # currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ
184
+ # in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ
185
+ # is much more preferred than ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
186
+ # conversion when using torchair graph mode on 300I Duo platform.
187
+ # TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant
188
+ # accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode.
189
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
190
+
191
+ use_torchair = get_ascend_config().torchair_graph_config.enabled
192
+ if not is_310p() or not use_torchair:
193
+ return
194
+ for module in model.modules():
195
+ if isinstance(module, FusedMoE):
196
+ if torch_npu.get_npu_format(module.w13_weight.data) == format:
197
+ return
198
+ module.w13_weight.data = torch_npu.npu_format_cast(
199
+ module.w13_weight.data, format)
200
+ module.w2_weight.data = torch_npu.npu_format_cast(
201
+ module.w2_weight.data, format)
202
+
203
+
204
+ def try_register_lib(lib_name: str, lib_info: str = ""):
205
+ import importlib
206
+ import importlib.util
207
+ try:
208
+ module_spec = importlib.util.find_spec(lib_name)
209
+ if module_spec is not None:
210
+ importlib.import_module(lib_name)
211
+ if lib_info:
212
+ logger.info(lib_info)
213
+ except Exception:
214
+ pass
215
+
216
+
217
+ def enable_custom_op():
218
+ """
219
+ Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component.
220
+ Ensure that ASCEND_RT_VISIBLE_DEVICES can be dynamically modified before torch.npu.set_device().
221
+ """
222
+ global _CUSTOM_OP_ENABLED
223
+ if _CUSTOM_OP_ENABLED is not None:
224
+ return _CUSTOM_OP_ENABLED
225
+ try:
226
+ # register custom ops into torch_library here
227
+ import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
228
+ _CUSTOM_OP_ENABLED = True
229
+ except ImportError:
230
+ _CUSTOM_OP_ENABLED = False
231
+ logger.warning(
232
+ "Warning: Failed to register custom ops, all custom ops will be disabled"
233
+ )
234
+ return _CUSTOM_OP_ENABLED
235
+
236
+
237
+ def find_hccl_library() -> str:
238
+ """
239
+ We either use the library file specified by the `HCCL_SO_PATH`
240
+ environment variable, or we find the library file brought by PyTorch.
241
+ After importing `torch`, `libhccl.so` can be
242
+ found by `ctypes` automatically.
243
+ """
244
+ so_file = envs.HCCL_SO_PATH
245
+
246
+ # manually load the hccl library
247
+ if so_file:
248
+ logger.info("Found hccl from environment variable HCCL_SO_PATH=%s",
249
+ so_file)
250
+ else:
251
+ if torch.version.cann is not None:
252
+ so_file = "libhccl.so"
253
+ else:
254
+ raise ValueError("HCCL only supports Ascend NPU backends.")
255
+ logger.info("Found hccl from library %s", so_file)
256
+ return so_file
257
+
258
+
259
+ def current_stream() -> torch.npu.Stream:
260
+ """
261
+ replace `torch.npu.current_stream()` with `vllm.utils.current_stream()`.
262
+ it turns out that `torch.npu.current_stream()` is quite expensive,
263
+ as it will construct a new stream object at each call.
264
+ here we patch `torch.npu.set_stream` to keep track of the current stream
265
+ directly, so that we can avoid calling `torch.npu.current_stream()`.
266
+
267
+ """
268
+ global _CURRENT_STREAM
269
+ if _CURRENT_STREAM is None:
270
+ # when this function is called before any stream is set,
271
+ # we return the default stream.
272
+ _CURRENT_STREAM = torch.npu.current_stream()
273
+ return _CURRENT_STREAM
274
+
275
+
276
+ def adapt_patch(is_global_patch: bool = False):
277
+ if is_global_patch:
278
+ from vllm_ascend.patch import platform # noqa: F401
279
+ else:
280
+ from vllm_ascend.patch import worker # noqa: F401
281
+
282
+
283
+ def vllm_version_is(target_vllm_version: str):
284
+ if envs.VLLM_VERSION is not None:
285
+ vllm_version = envs.VLLM_VERSION
286
+ else:
287
+ import vllm
288
+ vllm_version = vllm.__version__
289
+ try:
290
+ return Version(vllm_version) == Version(target_vllm_version)
291
+ except InvalidVersion:
292
+ raise ValueError(
293
+ f"Invalid vllm version {vllm_version} found. A dev version of vllm "
294
+ "is installed probably. Set the environment variable VLLM_VERSION "
295
+ "to control it by hand. And please make sure the value follows the "
296
+ "format of x.y.z.")
297
+
298
+
299
+ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
300
+ """Update ACL graph capture sizes based on hardware limitations"""
301
+ # Store original configuration and temporarily clear it
302
+ compilation_config = vllm_config.compilation_config
303
+ original_sizes, compilation_config.cudagraph_capture_sizes = \
304
+ compilation_config.cudagraph_capture_sizes, None
305
+
306
+ # Calculate parallel configuration factor
307
+ num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
308
+ parallel_config = vllm_config.parallel_config
309
+
310
+ # TODO: Find out whether we need to take into account the pp_size
311
+ parallel_factor = 1 + sum(size > 1 for size in [
312
+ parallel_config.data_parallel_size_local,
313
+ parallel_config.tensor_parallel_size,
314
+ parallel_config.expert_parallel_size,
315
+ parallel_config.expert_tensor_parallel_size,
316
+ ])
317
+
318
+ # Calculate maximum supported batch sizes considering model architecture
319
+ max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
320
+ (num_hidden_layers + 1) / parallel_factor)
321
+ logger.info("Calculated maximum supported batch sizes for ACL graph: %s",
322
+ max_num_batch_sizes)
323
+
324
+ # If original sizes exceed maximum, sample a representative subset
325
+ if max_num_batch_sizes < len(original_sizes):
326
+ # Sample uniformly from original sizes
327
+ step = (len(original_sizes) - 1) / (max_num_batch_sizes - 1)
328
+ indices = [round(i * step) for i in range(max_num_batch_sizes)]
329
+
330
+ # Ensure first and last elements are preserved
331
+ indices[0], indices[-1] = 0, len(original_sizes) - 1
332
+
333
+ sampled_sizes = [original_sizes[i] for i in indices]
334
+ compilation_config.init_with_cudagraph_sizes(sampled_sizes)
335
+
336
+ logger.info(
337
+ "Adjusted ACL graph batch sizes for %s model (layers: %d): %d → %d sizes",
338
+ vllm_config.model_config.architectures[0],
339
+ num_hidden_layers,
340
+ len(original_sizes),
341
+ len(compilation_config.
342
+ cudagraph_capture_sizes # type: ignore[arg-type]
343
+ ))
344
+ else:
345
+ # No adjustment needed
346
+ compilation_config.cudagraph_capture_sizes = original_sizes
347
+ logger.info(
348
+ "No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
349
+ vllm_config.model_config.architectures[0], num_hidden_layers,
350
+ len(original_sizes))
351
+
352
+
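Illustrative sketch (not part of the uploaded file): the uniform subsampling of capture sizes performed above when the original list exceeds the computed maximum; the numbers are made up.

    original_sizes = list(range(1, 41))          # 40 candidate batch sizes
    max_num_batch_sizes = 8

    step = (len(original_sizes) - 1) / (max_num_batch_sizes - 1)
    indices = [round(i * step) for i in range(max_num_batch_sizes)]
    indices[0], indices[-1] = 0, len(original_sizes) - 1   # always keep the smallest and largest size

    sampled_sizes = [original_sizes[i] for i in indices]
    print(sampled_sizes)  # [1, 7, 12, 18, 23, 29, 34, 40]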
353
+ # TODO(wxy): Move to ops module
354
+ def dispose_tensor(x: torch.Tensor):
355
+ x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
356
+
357
+
358
+ class ProfileExecuteDuration:
359
+ _instance = None
360
+ _observations: List[Tuple[str, Event, Event]] = []
361
+ _lock = Lock()
362
+
363
+ def __new__(cls):
364
+ with cls._lock:
365
+ if cls._instance is None:
366
+ cls._instance = super().__new__(cls)
367
+ atexit.register(cls._instance.destroy)
368
+ return cls._instance
369
+
370
+ def destroy(self):
371
+ with self._lock:
372
+ self._observations.clear()
373
+
374
+ @contextmanager
375
+ def capture_async(self, duration_tag: str):
376
+ if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
377
+ yield
378
+ return
379
+
380
+ observe_start = Event(enable_timing=True)
381
+ observe_start.record()
382
+ try:
383
+ yield
384
+ finally:
385
+ observe_end = Event(enable_timing=True)
386
+ observe_end.record()
387
+ with self._lock:
388
+ self._observations.append(
389
+ (duration_tag, observe_start, observe_end))
390
+
391
+ def pop_captured_sync(self) -> dict:
392
+ """Pop and synchronize all events in the observation list"""
393
+ durations: dict[str, float] = {}
394
+ if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
395
+ return durations
396
+
397
+ while self._observations:
398
+ with self._lock:
399
+ tag, observe_start, observe_end = self._observations.pop()
400
+ observe_end.synchronize()
401
+ durations[tag] = observe_start.elapsed_time(observe_end)
402
+
403
+ return durations
404
+
405
+
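Illustrative usage sketch (not part of the uploaded file) for ProfileExecuteDuration; it only records timings on an Ascend NPU with VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE enabled, otherwise capture_async is a no-op.

    from vllm_ascend.utils import ProfileExecuteDuration

    profiler = ProfileExecuteDuration()          # singleton
    with profiler.capture_async("prepare_inputs"):
        pass                                     # ... build model inputs ...
    with profiler.capture_async("forward"):
        pass                                     # ... run the model ...

    durations = profiler.pop_captured_sync()     # {tag: elapsed time per Event.elapsed_time}
    print(durations)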
406
+ # TODO(wxy): Move to ops module
407
+ def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
408
+ return _npu_stream_switch(tag, priority) if enabled else nullcontext()
409
+
410
+
411
+ # TODO(wxy): Move to ops module
412
+ def npu_wait_tensor(self: torch.Tensor,
413
+ dependency: torch.Tensor,
414
+ *,
415
+ enabled: bool = True):
416
+ return _npu_wait_tensor(self, dependency) if enabled else self
417
+
418
+
419
+ # TODO(wxy): Move to ops module
420
+ def npu_prefetch(input: torch.Tensor,
421
+ dependency: torch.Tensor,
422
+ max_size: int = 0,
423
+ *,
424
+ enabled: bool = True):
425
+ if not enabled:
426
+ return
427
+ input_size = input.element_size() * input.numel()
428
+ if max_size <= 0 or max_size > input_size:
429
+ max_size = input_size
430
+ torch_npu.npu_prefetch(input, dependency, max_size)
431
+
432
+
433
+ # TODO(zzzzwwjj): move this into forward_context
434
+ class FusedMoEState(Enum):
435
+ AllGather = 0
436
+ All2All = 1
437
+ MC2 = 2
438
+ AllGatherEP = 3
439
+ NaiveMulticast = 4
440
+
441
+
442
+ # TODO(ttanzhiqiang): rm_router_logits
443
+ # Triggered only when dp > 1.
+ # In theory this optimization only applies to AllGather and AllGatherEP: in the dp scenario the
+ # previous flow was gate + two communications, and it is now one communication + gate, which saves
+ # some communication time. All MoE AllGather/AllGatherEP paths could follow this logic, but the dp
+ # paths of other MoE models (e.g. qwen3-235b) have not been adapted yet, so a switch is used to
+ # control it and avoid breaking them.
445
+ def get_rm_router_logits_state(ep_size: int, dp_size: int,
446
+ is_deepseek_v3_r1: bool):
447
+ # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
448
+ # only supports deepseek v3/r1
449
+ if dp_size > 1:
450
+ if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
451
+ and is_deepseek_v3_r1):
452
+ return True
453
+ elif ep_size == 1 and is_deepseek_v3_r1:
454
+ return True
455
+ return False
456
+
457
+
458
+ # TODO(ttanzhiqiang): all_reduce merge
459
+ # When all_reduce_merge is enabled, shared_experts does not perform all_reduce inside the MLP;
+ # the all_reduce is deferred until both shared_experts and router_experts have finished.
460
+ # Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
461
+ def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
462
+ # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
463
+ # only supports deepseek v3/r1
464
+ if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
465
+ and is_deepseek_v3_r1):
466
+ return True
467
+ elif ep_size == 1 and is_deepseek_v3_r1:
468
+ return True
469
+ return False
470
+
471
+
472
+ # TODO(zzzzwwjj): add soc_version to choose branch
473
+ def get_fused_moe_state(ep_size: int, with_prefill: bool,
474
+ is_deepseek_v3_r1: bool):
475
+ # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
476
+ # only supports deepseek v3/r1
477
+ if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
478
+ and is_deepseek_v3_r1 and not with_prefill):
479
+ return FusedMoEState.AllGatherEP
480
+ elif ep_size == 1:
481
+ if with_prefill:
482
+ return FusedMoEState.NaiveMulticast
483
+ else:
484
+ return FusedMoEState.AllGather
485
+ # NOTE: MC2 needs ep_size >= 16, and all2all cannot be used in torchair graph mode.
486
+ elif ep_size < 16 or with_prefill:
487
+ return FusedMoEState.All2All
488
+ else:
489
+ return FusedMoEState.MC2
490
+
491
+
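Illustrative sketch (not part of the uploaded file): the dispatch-strategy choice made by get_fused_moe_state above, mirrored with plain strings and with the AllGatherEP env-flag branch left out for brevity.

    def moe_state(ep_size: int, with_prefill: bool) -> str:
        if ep_size == 1:
            return "NaiveMulticast" if with_prefill else "AllGather"
        if ep_size < 16 or with_prefill:   # MC2 needs ep_size >= 16 and no prefill
            return "All2All"
        return "MC2"

    for ep, prefill in [(1, True), (1, False), (8, False), (16, True), (16, False)]:
        print(ep, prefill, moe_state(ep, prefill))
    # 1 True NaiveMulticast / 1 False AllGather / 8 False All2All / 16 True All2All / 16 False MC2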
492
+ KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
493
+ KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
494
+ TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
495
+ TORCHAIR_CACHE_DIR = os.getenv(
496
+ 'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME))
497
+
498
+
499
+ def get_torchair_current_work_dir(file_name=None):
500
+ if file_name is None:
501
+ return TORCHAIR_CACHE_DIR
502
+ return os.path.join(TORCHAIR_CACHE_DIR, file_name)
503
+
504
+
505
+ def check_torchair_cache_exist():
506
+ res = False
507
+ torch_air_abs_path = get_torchair_current_work_dir()
508
+ if os.path.exists(torch_air_abs_path):
509
+ file_list = os.listdir(torch_air_abs_path)
510
+ if len(file_list) != 0:
511
+ res = True
512
+ return res
513
+
514
+
515
+ def check_kv_cache_bytes_cache_exist():
516
+ res = False
517
+ kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
518
+ KV_CACHE_BYTES_CACHE_PATH_NAME)
519
+ if os.path.exists(kv_cache_bytes_cache_abs_path):
520
+ file_list = os.listdir(kv_cache_bytes_cache_abs_path)
521
+ if len(file_list) != 0:
522
+ res = True
523
+ return res
524
+
525
+
526
+ def read_kv_cache_bytes_from_file(rank) -> int:
527
+ kv_cache_bytes = -1
528
+ kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
529
+ KV_CACHE_BYTES_CACHE_PATH_NAME)
530
+ kv_cache_bytes_file = os.path.join(
531
+ kv_cache_bytes_cache_abs_path,
532
+ f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
533
+ with open(kv_cache_bytes_file, "r", encoding="utf-8") as f:
534
+ with file_lock(f, fcntl.LOCK_SH):
535
+ kv_cache_bytes = int(f.readline())
536
+ return kv_cache_bytes
537
+
538
+
539
+ @contextmanager
540
+ def file_lock(file_descriptor, lock_type):
541
+ fcntl.flock(file_descriptor, lock_type)
542
+ try:
543
+ yield
544
+ finally:
545
+ fcntl.flock(file_descriptor, fcntl.LOCK_UN)
546
+
547
+
548
+ def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
549
+ kv_cache_bytes_cache_abs_path = get_torchair_current_work_dir(
550
+ KV_CACHE_BYTES_CACHE_PATH_NAME)
551
+ os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
552
+ kv_cache_bytes_file = os.path.join(
553
+ kv_cache_bytes_cache_abs_path,
554
+ f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
555
+ with open(kv_cache_bytes_file, "w", encoding="utf-8") as f:
556
+ with file_lock(f, fcntl.LOCK_EX):
557
+ f.write(f"{kv_cache_bytes}")
558
+
559
+
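Illustrative sketch (not part of the uploaded file): how the file_lock helper above serializes readers and writers of the cached kv_cache_bytes value; the path is made up and fcntl is POSIX-only.

    import fcntl

    from vllm_ascend.utils import file_lock

    path = "/tmp/0_kv_cache_bytes"            # hypothetical cache file
    with open(path, "w", encoding="utf-8") as f:
        with file_lock(f, fcntl.LOCK_EX):     # exclusive lock while writing
            f.write("1073741824")

    with open(path, "r", encoding="utf-8") as f:
        with file_lock(f, fcntl.LOCK_SH):     # shared lock while reading
            print(int(f.readline()))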
560
+ def delete_torchair_cache_file():
561
+ torch_air_abs_path = get_torchair_current_work_dir()
562
+ if os.path.exists(torch_air_abs_path):
563
+ shutil.rmtree(torch_air_abs_path)
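As a rough orientation for the dispatch logic in `get_fused_moe_state` above, the sketch below exercises its branches with made-up values (`FusedMoEState` and `envs` are defined earlier in this file; the import path assumes the `vllm_ascend` package from this repo is installed, and `VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP` is left unset):

```python
# Illustrative sketch only: exercises the MoE communication-strategy dispatch above.
# Assumes the vllm_ascend package from this repository is importable and that
# VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP is unset; the sizes are made up.
from vllm_ascend.utils import FusedMoEState, get_fused_moe_state

# Large expert-parallel group in pure decode -> MC2 path.
assert (get_fused_moe_state(ep_size=32, with_prefill=False, is_deepseek_v3_r1=False)
        == FusedMoEState.MC2)

# Small EP group, or prefill on a multi-rank group -> All2All.
assert (get_fused_moe_state(ep_size=8, with_prefill=True, is_deepseek_v3_r1=False)
        == FusedMoEState.All2All)

# Single-rank EP: NaiveMulticast for prefill, AllGather for decode.
assert get_fused_moe_state(1, True, False) == FusedMoEState.NaiveMulticast
assert get_fused_moe_state(1, False, False) == FusedMoEState.AllGather
```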
inference/vllm_ascend/worker/model_runner_v1.py ADDED
The diff for this file is too large to render. See raw diff
 
inference/vllm_ascend/worker/npu_input_batch.py ADDED
@@ -0,0 +1,796 @@
1
+ #
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3
+ # Copyright 2023 The vLLM team.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # This file is a part of the vllm-ascend project.
17
+ # Adapted from vllm-project/vllm/vllm/worker/gpu_input_batch.py
18
+ #
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Optional, cast, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ from vllm.lora.request import LoRARequest
26
+ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
27
+ from vllm.pooling_params import PoolingParams
28
+ from vllm.sampling_params import SamplingParams, SamplingType
29
+ from vllm.utils import swap_dict_values
30
+ from vllm.v1.outputs import LogprobsTensors
31
+ from vllm.v1.sample.logits_processor import init_builtin_logitsprocs
32
+ from vllm.v1.sample.metadata import SamplingMetadata
33
+ from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
34
+ from vllm.v1.utils import copy_slice
35
+ from vllm.v1.worker.block_table import MultiGroupBlockTable
36
+
37
+ from vllm_ascend.pool.metadata import PoolingMetadata
38
+
39
+ _SAMPLING_EPS = 1e-5
40
+
41
+
42
+ @dataclass
43
+ class CachedRequestState:
44
+
45
+ req_id: str
46
+ prompt_token_ids: list[int]
47
+ mm_inputs: list[MultiModalKwargs]
48
+ mm_positions: list[PlaceholderRange]
49
+ sampling_params: Optional[SamplingParams]
50
+ pooling_params: Optional[PoolingParams]
51
+ generator: Optional[torch.Generator]
52
+
53
+ block_ids: tuple[list[int], ...]
54
+ num_computed_tokens: int
55
+ output_token_ids: list[int]
56
+
57
+ mrope_positions: Optional[torch.Tensor] = None
58
+ mrope_position_delta: Optional[int] = None
59
+
60
+ lora_request: Optional[LoRARequest] = None
61
+
62
+ def __post_init__(self):
63
+ self.num_prompt_tokens = len(self.prompt_token_ids)
64
+
65
+ @property
66
+ def num_tokens(self) -> int:
67
+ return self.num_prompt_tokens + len(self.output_token_ids)
68
+
69
+ def get_token_id(self, idx: int) -> int:
70
+ if idx < self.num_prompt_tokens:
71
+ return self.prompt_token_ids[idx]
72
+ else:
73
+ return self.output_token_ids[idx - self.num_prompt_tokens]
74
+
75
+ @dataclass
76
+ class SamplingMetadataTopNSigma(SamplingMetadata):
77
+ top_n_sigma: torch.Tensor
78
+ no_top_n_sigma: bool
79
+
80
+ class InputBatch:
81
+
82
+ def __init__(
83
+ self,
84
+ max_num_reqs: int,
85
+ max_model_len: int,
86
+ max_num_batched_tokens: int,
87
+ device: torch.device,
88
+ pin_memory: bool,
89
+ vocab_size: int,
90
+ block_sizes: list[int], # The block_size of each kv cache group
91
+ logits_processing_needs_token_ids: bool = False,
92
+ is_spec_decode: bool = False,
93
+ ):
94
+ self.is_spec_decode = is_spec_decode
95
+ self.max_num_reqs = max_num_reqs
96
+ self.max_model_len = max_model_len
97
+ self.max_num_batched_tokens = max_num_batched_tokens
98
+ self.device = device
99
+ self.pin_memory = pin_memory
100
+ self.vocab_size = vocab_size
101
+ self.logits_processing_needs_token_ids = (
102
+ logits_processing_needs_token_ids)
103
+
104
+ self._req_ids: list[Optional[str]] = []
105
+ self.req_id_to_index: dict[str, int] = {}
106
+
107
+ # TODO(woosuk): This buffer could be too large if max_model_len is big.
108
+ # Find a way to reduce the CPU memory usage.
109
+ # This buffer is not directly transferred to the NPU, so it does not
110
+ # need to be pinned.
111
+ self.token_ids_cpu_tensor = torch.zeros(
112
+ (max_num_reqs, max_model_len),
113
+ device="cpu",
114
+ dtype=torch.int32,
115
+ pin_memory=False,
116
+ )
117
+ self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
118
+ self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
119
+ self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
120
+ self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
121
+ self.num_computed_tokens_cpu_tensor = torch.zeros(
122
+ (max_num_reqs, ),
123
+ device="cpu",
124
+ dtype=torch.int32,
125
+ pin_memory=pin_memory,
126
+ )
127
+ self.num_computed_tokens_cpu = \
128
+ self.num_computed_tokens_cpu_tensor.numpy()
129
+
130
+ # Block table.
131
+ self.block_table = MultiGroupBlockTable(
132
+ max_num_reqs=max_num_reqs,
133
+ max_model_len=max_model_len,
134
+ max_num_batched_tokens=max_num_batched_tokens,
135
+ pin_memory=pin_memory,
136
+ device=device,
137
+ block_sizes=block_sizes,
138
+ )
139
+
140
+ # Sampling-related.
141
+ self.temperature = torch.empty((max_num_reqs, ),
142
+ dtype=torch.float32,
143
+ device=device)
144
+ self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
145
+ dtype=torch.float32,
146
+ device="cpu",
147
+ pin_memory=pin_memory)
148
+ self.temperature_cpu = self.temperature_cpu_tensor.numpy()
149
+ self.greedy_reqs: set[str] = set()
150
+ self.random_reqs: set[str] = set()
151
+
152
+ self.top_p = torch.empty((max_num_reqs, ),
153
+ dtype=torch.float32,
154
+ device=device)
155
+ self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
156
+ dtype=torch.float32,
157
+ device="cpu",
158
+ pin_memory=pin_memory)
159
+ self.top_p_cpu = self.top_p_cpu_tensor.numpy()
160
+ self.top_p_reqs: set[str] = set()
161
+
162
+ self.top_k = torch.empty((max_num_reqs, ),
163
+ dtype=torch.int32,
164
+ device=device)
165
+ self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
166
+ dtype=torch.int32,
167
+ device="cpu",
168
+ pin_memory=pin_memory)
169
+ self.top_k_cpu = self.top_k_cpu_tensor.numpy()
170
+ self.top_k_reqs: set[str] = set()
171
+
172
+ # IDs of requests which do not support spec decoding
173
+ self.spec_decode_unsupported_reqs: set[str] = set()
174
+
175
+ self.min_p = torch.empty((max_num_reqs, ),
176
+ dtype=torch.float32,
177
+ device=device)
178
+ self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
179
+ dtype=torch.float32,
180
+ device="cpu",
181
+ pin_memory=pin_memory)
182
+ self.min_p_cpu = self.min_p_cpu_tensor.numpy()
183
+ self.min_p_reqs: set[str] = set()
184
+
185
+ # top_n_sigma sampling related data structures
186
+ self.top_n_sigma = torch.empty((max_num_reqs, ),
187
+ dtype=torch.float,
188
+ device=device)
189
+ self.top_n_sigma_cpu_tensor = torch.empty(
190
+ (max_num_reqs, ),
191
+ dtype=torch.float,
192
+ device="cpu",
193
+ pin_memory=pin_memory)
194
+ self.top_n_sigma_cpu = \
195
+ self.top_n_sigma_cpu_tensor.numpy()
196
+ self.top_n_sigma_reqs: set[str] = set()
197
+
198
+ # Frequency penalty related data structures
199
+ self.frequency_penalties = torch.empty((max_num_reqs, ),
200
+ dtype=torch.float,
201
+ device=device)
202
+ self.frequency_penalties_cpu_tensor = torch.empty(
203
+ (max_num_reqs, ),
204
+ dtype=torch.float,
205
+ device="cpu",
206
+ pin_memory=pin_memory)
207
+ self.frequency_penalties_cpu = \
208
+ self.frequency_penalties_cpu_tensor.numpy()
209
+ self.frequency_penalties_reqs: set[str] = set()
210
+
211
+ # Presence penalty related data structures
212
+ self.presence_penalties = torch.empty((max_num_reqs, ),
213
+ dtype=torch.float,
214
+ device=device)
215
+ self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
216
+ dtype=torch.float,
217
+ device="cpu",
218
+ pin_memory=pin_memory)
219
+ self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
220
+ )
221
+ self.presence_penalties_reqs: set[str] = set()
222
+
223
+ # Repetition penalty related data structures
224
+ self.repetition_penalties = torch.empty((max_num_reqs, ),
225
+ dtype=torch.float,
226
+ device=device)
227
+ self.repetition_penalties_cpu_tensor = torch.empty(
228
+ (max_num_reqs, ),
229
+ dtype=torch.float,
230
+ device="cpu",
231
+ pin_memory=pin_memory)
232
+ self.repetition_penalties_cpu = \
233
+ self.repetition_penalties_cpu_tensor.numpy()
234
+ self.repetition_penalties_reqs: set[str] = set()
235
+
236
+ # req_index -> (min_tokens, stop_token_ids)
237
+ self.min_tokens: dict[int, tuple[int, set[int]]] = {}
238
+
239
+ # lora related
240
+ self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
241
+ dtype=np.int32)
242
+ self.lora_id_to_request_ids: dict[int, set[str]] = {}
243
+ self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
244
+
245
+ # req_index -> generator
246
+ # NOTE(woosuk): The indices of the requests that do not have their own
247
+ # generator should not be included in the dictionary.
248
+ self.generators: dict[int, torch.Generator] = {}
249
+
250
+ self.num_logprobs: dict[str, int] = {}
251
+ # NOTE(rob): num_prompt_logprobs only includes reqs
252
+ # that are currently in the prefill phase.
253
+ self.num_prompt_logprobs: dict[str, int] = {}
254
+
255
+ # To accumulate prompt logprobs tensor chunks across prefill steps.
256
+ self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
257
+
258
+ self.logit_bias: list[Optional[dict[int,
259
+ float]]] = [None] * max_num_reqs
260
+ self.has_allowed_token_ids: set[str] = set()
261
+ # NOTE(lufang): In the mask tensor, if the corresponding token is allowed,
262
+ # the value is False, since we use masked_fill_ to set -inf.
263
+ self.allowed_token_ids_mask: Optional[torch.Tensor] = None
264
+ self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
265
+
266
+ # req_index -> bad_words_token_ids
267
+ self.bad_words_token_ids: dict[int, list[list[int]]] = {}
268
+
269
+ self.req_output_token_ids: list[Optional[list[int]]] = []
270
+
271
+ # Define logits processors.
272
+ # TODO(andy): logits processor list should be extensible via engine
273
+ # constructor argument; for now the list is fixed.
274
+ self.logitsprocs = init_builtin_logitsprocs(
275
+ pin_memory_available=pin_memory,
276
+ max_num_reqs=max_num_reqs + 1,
277
+ device=device)
278
+
279
+ # This is updated each time the batch constituents change.
280
+ self.sampling_metadata = self._make_sampling_metadata()
281
+
282
+ self.pooling_params: dict[str, PoolingParams] = {}
283
+
284
+ @property
285
+ def req_ids(self) -> list[str]:
286
+ # None elements should only be present transiently
287
+ # while performing state updates to the batch.
288
+ return cast(list[str], self._req_ids)
289
+
290
+ def add_request(
291
+ self,
292
+ request: "CachedRequestState",
293
+ req_index: Optional[int] = None,
294
+ ) -> None:
295
+ if req_index is None:
296
+ req_index = self.num_reqs
297
+ assert req_index < self.max_num_reqs
298
+
299
+ req_id = request.req_id
300
+ if req_index == len(self._req_ids):
301
+ self._req_ids.append(req_id)
302
+ self.req_output_token_ids.append(request.output_token_ids)
303
+ else:
304
+ self._req_ids[req_index] = req_id
305
+ self.req_output_token_ids[req_index] = request.output_token_ids
306
+
307
+ self.req_id_to_index[req_id] = req_index
308
+
309
+ # Copy the prompt token ids and output token ids.
310
+ num_prompt_tokens = len(request.prompt_token_ids)
311
+ self.num_prompt_tokens[req_index] = num_prompt_tokens
312
+ self.token_ids_cpu[
313
+ req_index, :num_prompt_tokens] = request.prompt_token_ids
314
+ start_idx = num_prompt_tokens
315
+ end_idx = start_idx + len(request.output_token_ids)
316
+ self.token_ids_cpu[req_index,
317
+ start_idx:end_idx] = request.output_token_ids
318
+ # Number of token ids in token_ids_cpu.
319
+ # NOTE(woosuk): This may include spec decode tokens.
320
+ self.num_tokens[req_index] = request.num_tokens
321
+ # Number of tokens without spec decode tokens.
322
+ self.num_tokens_no_spec[req_index] = request.num_tokens
323
+
324
+ self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
325
+ self.block_table.add_row(request.block_ids, req_index)
326
+
327
+ if sampling_params := request.sampling_params:
328
+ if self.is_spec_decode and is_spec_decode_unsupported(
329
+ sampling_params):
330
+ self.spec_decode_unsupported_reqs.add(req_id)
331
+ if sampling_params.sampling_type == SamplingType.GREEDY:
332
+ # Avoid later division by zero.
333
+ self.temperature_cpu[req_index] = -1.0
334
+ self.greedy_reqs.add(req_id)
335
+ else:
336
+ self.temperature_cpu[req_index] = sampling_params.temperature
337
+ self.random_reqs.add(req_id)
338
+
339
+ self.top_p_cpu[req_index] = sampling_params.top_p
340
+ if sampling_params.top_p < 1:
341
+ self.top_p_reqs.add(req_id)
342
+ top_k = sampling_params.top_k
343
+ if 0 < top_k < self.vocab_size:
344
+ self.top_k_reqs.add(req_id)
345
+ else:
346
+ top_k = self.vocab_size
347
+ self.top_k_cpu[req_index] = top_k
348
+ self.min_p_cpu[req_index] = sampling_params.min_p
349
+ self.frequency_penalties_cpu[
350
+ req_index] = sampling_params.frequency_penalty
351
+ if sampling_params.min_p > _SAMPLING_EPS:
352
+ self.min_p_reqs.add(req_id)
353
+ if sampling_params.frequency_penalty != 0.0:
354
+ self.frequency_penalties_reqs.add(req_id)
355
+ self.presence_penalties_cpu[
356
+ req_index] = sampling_params.presence_penalty
357
+ if sampling_params.presence_penalty != 0.0:
358
+ self.presence_penalties_reqs.add(req_id)
359
+ self.repetition_penalties_cpu[
360
+ req_index] = sampling_params.repetition_penalty
361
+ if sampling_params.repetition_penalty != 1.0:
362
+ self.repetition_penalties_reqs.add(req_id)
363
+ if sampling_params.min_tokens:
364
+ self.min_tokens[req_index] = (
365
+ sampling_params.min_tokens,
366
+ sampling_params.all_stop_token_ids)
367
+
368
+ if sampling_params.extra_args and "top_n_sigma" in sampling_params.extra_args:
369
+ self.top_n_sigma_cpu[
370
+ req_index] = sampling_params.extra_args["top_n_sigma"]
371
+ self.top_n_sigma_reqs.add(req_id)
372
+ else:
373
+ self.top_n_sigma_cpu[req_index] = -1
374
+
375
+ # NOTE(woosuk): self.generators should not include the requests that
376
+ # do not have their own generator.
377
+ if request.generator is not None:
378
+ self.generators[req_index] = request.generator
379
+
380
+ if sampling_params.logprobs is not None:
381
+ self.num_logprobs[req_id] = sampling_params.logprobs
382
+ if sampling_params.prompt_logprobs is not None:
383
+ self.num_prompt_logprobs[
384
+ req_id] = sampling_params.prompt_logprobs
385
+ if sampling_params.logit_bias is not None:
386
+ self.logit_bias[req_index] = sampling_params.logit_bias
387
+
388
+ if sampling_params.allowed_token_ids:
389
+ self.has_allowed_token_ids.add(req_id)
390
+ if self.allowed_token_ids_mask_cpu_tensor is None:
391
+ # Lazy allocation for this tensor, which can be large.
392
+ # False means we don't fill with -inf.
393
+ self.allowed_token_ids_mask = torch.zeros(
394
+ self.max_num_reqs,
395
+ self.vocab_size,
396
+ dtype=torch.bool,
397
+ device=self.device)
398
+ self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
399
+ self.max_num_reqs,
400
+ self.vocab_size,
401
+ dtype=torch.bool,
402
+ device="cpu")
403
+ self.allowed_token_ids_mask_cpu_tensor[req_index] = True
404
+ # False means we don't fill with -inf.
405
+ self.allowed_token_ids_mask_cpu_tensor[req_index][
406
+ sampling_params.allowed_token_ids] = False
407
+
408
+ if sampling_params.bad_words_token_ids:
409
+ self.bad_words_token_ids[
410
+ req_index] = sampling_params.bad_words_token_ids
411
+ else:
412
+ assert request.pooling_params is not None
413
+ self.pooling_params[req_id] = request.pooling_params
414
+
415
+ # Add request lora ID
416
+ if request.lora_request:
417
+ lora_id = request.lora_request.lora_int_id
418
+ if lora_id not in self.lora_id_to_request_ids:
419
+ self.lora_id_to_request_ids[lora_id] = set()
420
+
421
+ self.request_lora_mapping[req_index] = lora_id
422
+ self.lora_id_to_request_ids[lora_id].add(request.req_id)
423
+ self.lora_id_to_lora_request[lora_id] = request.lora_request
424
+ else:
425
+ # No LoRA
426
+ self.request_lora_mapping[req_index] = 0
427
+
428
+ def remove_request(self, req_id: str) -> Optional[int]:
429
+ """This method must always be followed by a call to condense()."""
430
+
431
+ req_index = self.req_id_to_index.pop(req_id, None)
432
+ if req_index is None:
433
+ return None
434
+ self._req_ids[req_index] = None
435
+ self.req_output_token_ids[req_index] = None
436
+
437
+ self.greedy_reqs.discard(req_id)
438
+ self.random_reqs.discard(req_id)
439
+ self.top_p_reqs.discard(req_id)
440
+ self.top_k_reqs.discard(req_id)
441
+ self.min_p_reqs.discard(req_id)
442
+ self.min_tokens.pop(req_index, None)
443
+ self.frequency_penalties_reqs.discard(req_id)
444
+ self.presence_penalties_reqs.discard(req_id)
445
+ self.repetition_penalties_reqs.discard(req_id)
446
+ self.spec_decode_unsupported_reqs.discard(req_id)
447
+ self.top_n_sigma_reqs.discard(req_id)
448
+ self.generators.pop(req_index, None)
449
+ self.num_logprobs.pop(req_id, None)
450
+ self.num_prompt_logprobs.pop(req_id, None)
451
+ self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
452
+
453
+ # LoRA
454
+ lora_id = self.request_lora_mapping[req_index]
455
+ if lora_id != 0:
456
+ self.lora_id_to_request_ids[lora_id].discard(req_id)
457
+ if len(self.lora_id_to_request_ids[lora_id]) == 0:
458
+ self.lora_id_to_request_ids.pop(lora_id)
459
+ self.lora_id_to_lora_request.pop(lora_id)
460
+ self.request_lora_mapping[req_index] = 0
461
+
462
+ self.logit_bias[req_index] = None
463
+ self.has_allowed_token_ids.discard(req_id)
464
+ if self.allowed_token_ids_mask_cpu_tensor is not None:
465
+ # False means we don't fill with -inf.
466
+ self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
467
+ self.bad_words_token_ids.pop(req_index, None)
468
+ self.pooling_params.pop(req_id, None)
469
+ return req_index
470
+
471
+ def swap_states(self, i1: int, i2: int) -> None:
472
+ old_id_i1 = self._req_ids[i1]
473
+ old_id_i2 = self._req_ids[i2]
474
+ self._req_ids[i1], self._req_ids[i2] =\
475
+ self._req_ids[i2], self._req_ids[i1] # noqa
476
+ self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
477
+ self.req_output_token_ids[i2], self.req_output_token_ids[i1]
478
+ assert old_id_i1 is not None and old_id_i2 is not None
479
+ self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
480
+ self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
481
+ self.num_tokens[i1], self.num_tokens[i2] =\
482
+ self.num_tokens[i2], self.num_tokens[i1]
483
+ self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
484
+ self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
485
+ self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
486
+ self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
487
+ self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
488
+ self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
489
+ self.temperature_cpu[i1], self.temperature_cpu[i2] =\
490
+ self.temperature_cpu[i2], self.temperature_cpu[i1]
491
+ self.top_p_cpu[i1], self.top_p_cpu[i2] =\
492
+ self.top_p_cpu[i2], self.top_p_cpu[i1]
493
+ self.top_k_cpu[i1], self.top_k_cpu[i2] =\
494
+ self.top_k_cpu[i2], self.top_k_cpu[i1]
495
+ self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
496
+ self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
497
+ self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
498
+ self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
499
+ self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
500
+ self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
501
+ self.min_p_cpu[i1], self.min_p_cpu[i2] =\
502
+ self.min_p_cpu[i2], self.min_p_cpu[i1]
503
+ self.top_n_sigma_cpu[i1], self.top_n_sigma_cpu[i2] =\
504
+ self.top_n_sigma_cpu[i2], self.top_n_sigma_cpu[i1]
505
+
506
+ # NOTE: the following is unsafe
507
+ # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
508
+ # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
509
+ # instead, we need to temporarily copy the data for one of the indices
510
+ # TODO(lucas): optimize this by only copying valid indices
511
+ tmp = self.token_ids_cpu[i1, ...].copy()
512
+ self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
513
+ self.token_ids_cpu[i2, ...] = tmp
514
+
515
+ swap_dict_values(self.generators, i1, i2)
516
+ swap_dict_values(self.min_tokens, i1, i2)
517
+ swap_dict_values(self.bad_words_token_ids, i1, i2)
518
+
519
+ self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
520
+ self.request_lora_mapping[i2], self.request_lora_mapping[i1]
521
+ self.logit_bias[i1], self.logit_bias[i2] =\
522
+ self.logit_bias[i2], self.logit_bias[i1]
523
+
524
+ if self.allowed_token_ids_mask_cpu_tensor is not None:
525
+ self.allowed_token_ids_mask_cpu_tensor[i1], \
526
+ self.allowed_token_ids_mask_cpu_tensor[i2] =\
527
+ self.allowed_token_ids_mask_cpu_tensor[i2], \
528
+ self.allowed_token_ids_mask_cpu_tensor[i1]
529
+ self.block_table.swap_row(i1, i2)
530
+
531
+ def condense(self, empty_req_indices: list[int]) -> None:
532
+ """Move non-empty requests down into lower, empty indices.
533
+
534
+ Args:
535
+ empty_req_indices: empty batch indices, sorted descending.
536
+ """
537
+ num_reqs = self.num_reqs
538
+ if num_reqs == 0:
539
+ # The batched states are empty.
540
+ self._req_ids.clear()
541
+ self.req_output_token_ids.clear()
542
+ return
543
+
544
+ # NOTE(woosuk): This function assumes that the empty_req_indices
545
+ # is sorted in descending order.
546
+ last_req_index = num_reqs + len(empty_req_indices) - 1
547
+ while empty_req_indices:
548
+ # Find the largest non-empty index.
549
+ while last_req_index in empty_req_indices:
550
+ last_req_index -= 1
551
+
552
+ # Find the smallest empty index.
553
+ empty_index = empty_req_indices.pop()
554
+ if empty_index >= last_req_index:
555
+ break
556
+
557
+ # Swap the states.
558
+ req_id = self._req_ids[last_req_index]
559
+ output_token_ids = self.req_output_token_ids[last_req_index]
560
+ assert req_id is not None
561
+ self._req_ids[empty_index] = req_id
562
+ self._req_ids[last_req_index] = None
563
+ self.req_output_token_ids[empty_index] = output_token_ids
564
+ self.req_output_token_ids[last_req_index] = None
565
+ self.req_id_to_index[req_id] = empty_index
566
+
567
+ num_tokens = self.num_tokens[last_req_index]
568
+ self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
569
+ last_req_index, :num_tokens]
570
+ self.num_tokens[empty_index] = num_tokens
571
+ self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
572
+ last_req_index]
573
+ self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
574
+ last_req_index]
575
+ self.num_computed_tokens_cpu[
576
+ empty_index] = self.num_computed_tokens_cpu[last_req_index]
577
+ self.block_table.move_row(last_req_index, empty_index)
578
+ self.temperature_cpu[empty_index] = self.temperature_cpu[
579
+ last_req_index]
580
+ self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
581
+ self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
582
+ self.frequency_penalties_cpu[
583
+ empty_index] = self.frequency_penalties_cpu[last_req_index]
584
+ self.presence_penalties_cpu[
585
+ empty_index] = self.presence_penalties_cpu[last_req_index]
586
+ self.repetition_penalties_cpu[
587
+ empty_index] = self.repetition_penalties_cpu[last_req_index]
588
+ self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
589
+ self.top_n_sigma_cpu[
590
+ empty_index] = self.top_n_sigma_cpu[last_req_index]
591
+ generator = self.generators.pop(last_req_index, None)
592
+ if generator is not None:
593
+ self.generators[empty_index] = generator
594
+
595
+ min_token = self.min_tokens.pop(last_req_index, None)
596
+ if min_token is not None:
597
+ self.min_tokens[empty_index] = min_token
598
+
599
+ self.request_lora_mapping[empty_index] = self.request_lora_mapping[
600
+ last_req_index]
601
+
602
+ self.logit_bias[empty_index] = self.logit_bias[last_req_index]
603
+
604
+ if self.allowed_token_ids_mask_cpu_tensor is not None:
605
+ self.allowed_token_ids_mask_cpu_tensor[
606
+ empty_index] = self.allowed_token_ids_mask_cpu_tensor[
607
+ last_req_index]
608
+
609
+ bad_words_token_ids = self.bad_words_token_ids.pop(
610
+ last_req_index, None)
611
+ if bad_words_token_ids is not None:
612
+ self.bad_words_token_ids[empty_index] = bad_words_token_ids
613
+ # Decrement last_req_index since it is now empty.
614
+ last_req_index -= 1
615
+
616
+ # Trim lists to the batch size.
617
+ del self._req_ids[self.num_reqs:]
618
+ del self.req_output_token_ids[self.num_reqs:]
619
+
620
+ def refresh_sampling_metadata(self):
621
+ self.sampling_metadata = self._make_sampling_metadata()
622
+
623
+ def _make_sampling_metadata(self) -> Union[SamplingMetadata, SamplingMetadataTopNSigma]:
624
+ num_reqs = self.num_reqs
625
+ if not self.all_greedy:
626
+ temperature = copy_slice(self.temperature_cpu_tensor,
627
+ self.temperature, num_reqs)
628
+ else:
629
+ temperature = None
630
+ if not self.no_top_p:
631
+ copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
632
+ if not self.no_top_k:
633
+ copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
634
+ if not self.no_min_p:
635
+ copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)
636
+
637
+ if not self.no_penalties:
638
+ # Since syncing these tensors is expensive, only copy them
639
+ # if necessary i.e. if there are requests which require
640
+ # penalties to be applied during sampling.
641
+ copy_slice(self.frequency_penalties_cpu_tensor,
642
+ self.frequency_penalties, num_reqs)
643
+ copy_slice(self.presence_penalties_cpu_tensor,
644
+ self.presence_penalties, num_reqs)
645
+ copy_slice(self.repetition_penalties_cpu_tensor,
646
+ self.repetition_penalties, num_reqs)
647
+
648
+ if not self.no_top_n_sigma:
649
+ copy_slice(self.top_n_sigma_cpu_tensor,
650
+ self.top_n_sigma, num_reqs)
651
+
652
+
653
+ needs_prompt_token_ids = (not self.no_penalties or
654
+ (self.num_reqs > 0
655
+ and self.logits_processing_needs_token_ids))
656
+ if needs_prompt_token_ids:
657
+ # The prompt tokens are used only for applying penalties or
658
+ # step pooling during the sampling/pooling process.
659
+ # Hence copy these tensors only when there are requests which
660
+ # need penalties/step_pooler to be applied.
661
+ prompt_token_ids = self._make_prompt_token_ids_tensor()
662
+ else:
663
+ prompt_token_ids = None
664
+
665
+ allowed_token_ids_mask: Optional[torch.Tensor] = None
666
+ if not self.no_allowed_token_ids:
667
+ assert self.allowed_token_ids_mask is not None
668
+ copy_slice(self.allowed_token_ids_mask_cpu_tensor,
669
+ self.allowed_token_ids_mask, num_reqs)
670
+ allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
671
+
672
+ return SamplingMetadataTopNSigma(
673
+ temperature=temperature,
674
+ all_greedy=self.all_greedy,
675
+ all_random=self.all_random,
676
+ top_p=None if self.no_top_p else self.top_p[:num_reqs],
677
+ top_k=None if self.no_top_k else self.top_k[:num_reqs],
678
+ generators=self.generators,
679
+ max_num_logprobs=self.max_num_logprobs,
680
+ prompt_token_ids=prompt_token_ids,
681
+ frequency_penalties=self.frequency_penalties[:num_reqs],
682
+ presence_penalties=self.presence_penalties[:num_reqs],
683
+ repetition_penalties=self.repetition_penalties[:num_reqs],
684
+ top_n_sigma=self.top_n_sigma[:num_reqs],
685
+ output_token_ids=cast(list[list[int]], self.req_output_token_ids),
686
+ no_penalties=self.no_penalties,
687
+ no_top_n_sigma=self.no_top_n_sigma,
688
+ allowed_token_ids_mask=allowed_token_ids_mask,
689
+ bad_words_token_ids=self.bad_words_token_ids,
690
+ logitsprocs=self.logitsprocs,
691
+ )
692
+
693
+ @property
694
+ def pooling_metadata(self) -> PoolingMetadata:
695
+ if len(self.pooling_params) == 0:
696
+ pooling_params = []
697
+ else:
698
+ # Note: for now this assumes that all requests in the batch
699
+ # are either sampling or pooling requests
700
+ assert len(self.req_ids) == len(self.pooling_params)
701
+ pooling_params = [
702
+ self.pooling_params[req_id] for req_id in self.req_ids
703
+ ]
704
+
705
+ return PoolingMetadata(
706
+ prompt_lens=torch.from_numpy(
707
+ self.num_prompt_tokens[:self.num_reqs]).to(self.device),
708
+ prompt_token_ids=self.sampling_metadata.prompt_token_ids,
709
+ pooling_params=pooling_params,
710
+ )
711
+
712
+ def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
713
+ max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
714
+ prompt_token_ids_cpu_tensor = torch.empty(
715
+ (self.num_reqs, max_prompt_len),
716
+ device="cpu",
717
+ dtype=torch.int64,
718
+ pin_memory=self.pin_memory,
719
+ )
720
+ prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
721
+ prompt_token_ids[:] = self.token_ids_cpu[:self.
722
+ num_reqs, :max_prompt_len]
723
+ # Use the value of vocab_size as a pad since we don't have a
724
+ # token_id of this value.
725
+ for i in range(self.num_reqs):
726
+ prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
727
+ return prompt_token_ids_cpu_tensor.to(device=self.device,
728
+ non_blocking=True)
729
+
730
+ def make_lora_inputs(
731
+ self, num_scheduled_tokens: np.ndarray
732
+ ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
733
+ """
734
+ Given the num_scheduled_tokens for each request in the batch, return
735
+ data structures used to activate the current LoRAs.
736
+ Returns:
737
+ 1. prompt_lora_mapping: A tuple of size self.num_reqs where,
738
+ prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
739
+ 2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
740
+ where, token_lora_mapping[i] is the LoRA id to use for ith token.
741
+ 3. lora_requests: Set of relevant LoRA requests.
742
+ """
743
+
744
+ req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
745
+ prompt_lora_mapping = tuple(req_lora_mapping)
746
+ token_lora_mapping = tuple(
747
+ req_lora_mapping.repeat(num_scheduled_tokens))
748
+ active_lora_requests: set[LoRARequest] = set(
749
+ self.lora_id_to_lora_request.values())
750
+
751
+ return prompt_lora_mapping, token_lora_mapping, active_lora_requests
752
+
753
+ @property
754
+ def num_reqs(self) -> int:
755
+ return len(self.req_id_to_index)
756
+
757
+ @property
758
+ def all_greedy(self) -> bool:
759
+ return len(self.random_reqs) == 0
760
+
761
+ @property
762
+ def all_random(self) -> bool:
763
+ return len(self.greedy_reqs) == 0
764
+
765
+ @property
766
+ def no_top_p(self) -> bool:
767
+ return len(self.top_p_reqs) == 0
768
+
769
+ @property
770
+ def no_top_k(self) -> bool:
771
+ return len(self.top_k_reqs) == 0
772
+
773
+ @property
774
+ def no_min_p(self) -> bool:
775
+ return len(self.min_p_reqs) == 0
776
+
777
+ @property
778
+ def no_penalties(self) -> bool:
779
+ return (len(self.presence_penalties_reqs) == 0
780
+ and len(self.frequency_penalties_reqs) == 0
781
+ and len(self.repetition_penalties_reqs) == 0)
782
+ @property
783
+ def no_top_n_sigma(self) -> bool:
784
+ return len(self.top_n_sigma_reqs) == 0
785
+
786
+ @property
787
+ def max_num_logprobs(self) -> Optional[int]:
788
+ return max(self.num_logprobs.values()) if self.num_logprobs else None
789
+
790
+ @property
791
+ def no_prompt_logprob(self) -> bool:
792
+ return not self.num_prompt_logprobs
793
+
794
+ @property
795
+ def no_allowed_token_ids(self) -> bool:
796
+ return len(self.has_allowed_token_ids) == 0
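The batch code above threads a per-request `top_n_sigma` value through `SamplingParams.extra_args` and disables it with `-1` when absent. A hedged sketch of how a request would carry that knob (assuming a vLLM build whose `SamplingParams` exposes `extra_args`, as this file expects; the threshold value is illustrative only):

```python
# Illustration only: how a top_n_sigma value reaches InputBatch.add_request().
from vllm import SamplingParams

params = SamplingParams(
    temperature=0.7,
    top_p=0.95,
    extra_args={"top_n_sigma": 1.5},  # read in add_request(); -1 means "disabled"
)

# Inside add_request(), roughly:
#   if params.extra_args and "top_n_sigma" in params.extra_args:
#       self.top_n_sigma_cpu[req_index] = params.extra_args["top_n_sigma"]
#       self.top_n_sigma_reqs.add(req_id)
#   else:
#       self.top_n_sigma_cpu[req_index] = -1
```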
model-00002-of-000062.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e29a512e3737d1826c80a2277a8b42021878847753aadbe5e1ae2a2df3d7f8d
3
+ size 1242564208
model-00003-of-000062.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f63aa17d947032e0a524b5798eee3becbfc9a9b6f8a352ead3232e7b34bb289
3
+ size 1242564208
model-00005-of-000062.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e3907f683d7f8382d2a792304155e8533ffa3a94dd4bb5ff825124b0dba3835
3
+ size 24650809648
model-00045-of-000062.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:97458250006f5949c65b338a26503738200c5fb2415f4cc664a6b224aa9dce70
3
+ size 24650810432
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_openpangu_moe.py ADDED
@@ -0,0 +1,653 @@
1
+ # coding=utf-8
2
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
3
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from transformers.activations import ACT2FN
24
+ from transformers.cache_utils import Cache, DynamicCache
25
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutputWithPast,
28
+ CausalLMOutputWithPast,
29
+ )
30
+ from transformers.modeling_utils import PreTrainedModel
31
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
32
+ from transformers.utils.import_utils import is_torch_fx_available
33
+
34
+ from .configuration_openpangu_moe import PanguUltraMoEConfig
35
+
36
+ if is_torch_fx_available():
37
+ if not is_torch_greater_or_equal_than_1_13:
38
+ import torch.fx
39
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
40
+
41
+
42
+ class PanguUltraMoERMSNorm(nn.Module):
43
+ def __init__(self, hidden_dim, epsilon=1e-5):
44
+ super().__init__()
45
+ self.weight = nn.Parameter(torch.empty(hidden_dim))
46
+ self.epsilon = epsilon
47
+
48
+ def forward(self, input_x):
49
+ origin_dtype = input_x.dtype
50
+ var = input_x.to(torch.float32).pow(2).mean(-1, keepdim=True)
51
+ input_x = input_x * torch.rsqrt(var + self.epsilon)
52
+ output_x = self.weight * input_x
53
+ return output_x.to(origin_dtype)
54
+
55
+
56
+ class PanguUltraMoERotaryEmbedding(nn.Module):
57
+ def __init__(
58
+ self, dim, max_position_embeddings=131072, base=25600000.0, device=None
59
+ ):
60
+ super().__init__()
61
+ self.dim = dim
62
+ self.max_position_embeddings = max_position_embeddings
63
+ self.base = base
64
+ self._set_cache(
65
+ seq_len=max_position_embeddings,
66
+ device=device,
67
+ dtype=torch.get_default_dtype(),
68
+ )
69
+
70
+ def _set_cache(self, seq_len, device, dtype):
71
+ self.max_seq_len_cached = seq_len
72
+ dim = self.dim
73
+
74
+ inv_freq = 1.0 / (
75
+ self.base
76
+ ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
77
+ )
78
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
79
+
80
+ t = torch.arange(seq_len, device=device, dtype=torch.float32)
81
+
82
+ freqs = torch.outer(t, inv_freq)
83
+
84
+ emb = torch.cat((freqs, freqs), dim=-1)
85
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
86
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
87
+
88
+ def forward(self, x, kv_len, max_seq_len=None):
89
+ if max_seq_len is None:
90
+ self._set_cache(seq_len=kv_len, device=x.device, dtype=x.dtype)
91
+ elif max_seq_len > self.max_seq_len_cached:
92
+ self._set_cache(seq_len=max_seq_len, device=x.device, dtype=x.dtype)
93
+
94
+ batch_size = x.shape[0]
95
+ seq_len = x.shape[1]
96
+ if seq_len == 1:
97
+ cos = (
98
+ torch.index_select(self.cos_cached, dim=0, index=kv_len)
99
+ .unsqueeze(1)
100
+ .unsqueeze(1)
101
+ )
102
+ sin = (
103
+ torch.index_select(self.sin_cached, dim=0, index=kv_len)
104
+ .unsqueeze(1)
105
+ .unsqueeze(1)
106
+ )
107
+ else:
108
+ cos = (
109
+ self.cos_cached[:seq_len]
110
+ .unsqueeze(0)
111
+ .unsqueeze(2)
112
+ .repeat(batch_size, 1, 1, 1)
113
+ )
114
+ sin = (
115
+ self.sin_cached[:seq_len]
116
+ .unsqueeze(0)
117
+ .unsqueeze(2)
118
+ .repeat(batch_size, 1, 1, 1)
119
+ )
120
+
121
+ cos = cos[0, :, 0, :]
122
+ sin = sin[0, :, 0, :]
123
+ return (
124
+ cos.to(dtype=x.dtype),
125
+ sin.to(dtype=x.dtype),
126
+ )
127
+
128
+
129
+ def rotate_half(x):
130
+ """Rotates half the hidden dims of the input."""
131
+ x1 = x[..., : x.shape[-1] // 2]
132
+ x2 = x[..., x.shape[-1] // 2 :]
133
+ return torch.cat((-x2, x1), dim=-1)
134
+
135
+
136
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
137
+ """Applies Rotary Position Embedding to the query and key tensors.
138
+
139
+ Args:
140
+ q (`torch.Tensor`): The query tensor.
141
+ k (`torch.Tensor`): The key tensor.
142
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
143
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
144
+ position_ids (`torch.Tensor`):
145
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
146
+ used to pass offsetted position ids when working with a KV-cache.
147
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
148
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
149
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
150
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
151
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
152
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
153
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
154
+ Returns:
155
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
156
+ """
157
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
158
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
159
+
160
+ b, h, s, d = q.shape
161
+ q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
162
+
163
+ b, h, s, d = k.shape
164
+ k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
165
+
166
+ q_embed = (q * cos) + (rotate_half(q) * sin)
167
+ k_embed = (k * cos) + (rotate_half(k) * sin)
168
+ return q_embed, k_embed
169
+
170
+
171
+ class MLP(nn.Module):
172
+ def __init__(self, config, hidden_size=None, intermediate_size=None):
173
+ super().__init__()
174
+ self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
175
+ self.intermediate_size = (
176
+ config.intermediate_size if intermediate_size is None else intermediate_size
177
+ )
178
+
179
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
180
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
181
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
182
+ self.act_fn = ACT2FN[config.hidden_act]
183
+
184
+ def forward(self, x):
185
+ output = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
186
+ return output
187
+
188
+
189
+ class MoEGate(nn.Module):
190
+ def __init__(self, config):
191
+ super().__init__()
192
+ self.top_k = config.num_experts_per_tok
193
+ self.routed_scaling_factor = config.routed_scaling_factor
194
+
195
+ self.norm_topk_prob = config.norm_topk_prob
196
+ self.weight = nn.Parameter(
197
+ torch.empty((config.num_routed_experts, config.hidden_size))
198
+ )
199
+
200
+ def forward(self, hidden_states):
201
+ bsz, seq_len, h = hidden_states.shape
202
+ hidden_states = hidden_states.view(-1, h)
203
+ logits = F.linear(
204
+ hidden_states.to(torch.float32), self.weight.to(torch.float32), None
205
+ )
206
+ scores = logits.sigmoid()
207
+ scores_for_choice = scores.view(bsz * seq_len, -1)
208
+ _, topk_idx = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
209
+ topk_weight = scores.gather(1, topk_idx)
210
+
211
+ if self.top_k > 1 and self.norm_topk_prob:
212
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
213
+ topk_weight = topk_weight / denominator
214
+ topk_weight = topk_weight * self.routed_scaling_factor
215
+
216
+ return topk_idx, topk_weight
217
+
218
+
219
+ class PanguUltraMoE(nn.Module):
220
+ def __init__(self, config):
221
+ super().__init__()
222
+ self.num_shared_experts = config.num_shared_experts
223
+ self.num_routed_experts = config.num_routed_experts
224
+ self.experts = nn.ModuleList(
225
+ [
226
+ MLP(config, intermediate_size=config.moe_intermediate_size)
227
+ for i in range(self.num_routed_experts)
228
+ ]
229
+ )
230
+ self.gate = MoEGate(config)
231
+ if self.num_shared_experts is not None:
232
+ intermediate_size = config.moe_intermediate_size * self.num_shared_experts
233
+ self.shared_experts = MLP(
234
+ config=config, intermediate_size=intermediate_size
235
+ )
236
+
237
+ def forward(self, hidden_states):
238
+ if self.num_shared_experts is not None:
239
+ shared_output = self.shared_experts(hidden_states)
240
+ input_shape = hidden_states.shape
241
+ topk_ids, topk_weight = self.gate(hidden_states)
242
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
243
+ counts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
244
+ counts.scatter_(1, topk_ids, 1)
245
+ tokens_per_expert = counts.sum(dim=0)
246
+ idxs = topk_ids.view(-1).argsort()
247
+ sorted_tokens = hidden_states[idxs // topk_ids.shape[1]]
248
+ tokens_per_expert = tokens_per_expert.cpu().numpy()
249
+
250
+ output_hidden_states = []
251
+ start_idx = 0
252
+ for i, num_tokens in enumerate(tokens_per_expert):
253
+ end_idx = start_idx + num_tokens
254
+ if num_tokens == 0:
255
+ continue
256
+ expert = self.experts[i]
257
+ tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
258
+ expert_out = expert(tokens_for_this_expert)
259
+ output_hidden_states.append(expert_out)
260
+ start_idx = end_idx
261
+
262
+ if len(output_hidden_states) > 0:
263
+ cat_hidden_states = torch.cat(output_hidden_states, dim=0)
264
+ else:
265
+ cat_hidden_states = sorted_tokens.new_empty(0)
266
+
267
+ final_hidden_states = torch.empty_like(cat_hidden_states)
268
+ final_hidden_states[idxs] = cat_hidden_states
269
+ final_out = final_hidden_states.view(*topk_ids.shape, -1).to(topk_weight.dtype)
270
+ final_out = (
271
+ final_out.mul_(topk_weight.unsqueeze(dim=-1))
272
+ .sum(dim=1)
273
+ .to(final_hidden_states.dtype)
274
+ ).view(*input_shape)
275
+ if self.num_shared_experts is not None:
276
+ final_out = final_out + shared_output
277
+ return final_out
278
+
279
+
280
+ class PanguUltraMoEAttention(nn.Module):
281
+ def __init__(self, config: PanguUltraMoEConfig, layer_idx: Optional[int] = None):
282
+ super().__init__()
283
+ self.layer_idx = layer_idx
284
+
285
+ self.attention_dropout = config.attention_dropout
286
+ self.hidden_size = config.hidden_size
287
+ self.num_heads = config.num_attention_heads
288
+ self.attention_q_lora_dim = config.attention_q_lora_dim
289
+ self.attention_qk_rope_dim = config.attention_qk_rope_dim
290
+ self.attention_kv_lora_dim = config.attention_kv_lora_dim
291
+ self.attention_v_dim = config.attention_v_dim
292
+ self.attention_qk_dim = config.attention_qk_dim
293
+ self.q_head_dim = config.attention_qk_dim + config.attention_qk_rope_dim
294
+
295
+ if self.attention_q_lora_dim is None:
296
+ self.q_proj = nn.Linear(
297
+ self.hidden_size, self.num_heads * self.q_head_dim, bias=False
298
+ )
299
+ else:
300
+ self.q_a_proj = nn.Linear(
301
+ self.hidden_size, config.attention_q_lora_dim, bias=False
302
+ )
303
+ self.q_a_layernorm = PanguUltraMoERMSNorm(config.attention_q_lora_dim)
304
+ self.q_b_proj = nn.Linear(
305
+ config.attention_q_lora_dim,
306
+ self.num_heads * self.q_head_dim,
307
+ bias=False,
308
+ )
309
+ self.kv_a_proj_with_mqa = nn.Linear(
310
+ self.hidden_size,
311
+ config.attention_kv_lora_dim + config.attention_qk_rope_dim,
312
+ bias=False,
313
+ )
314
+ self.kv_a_layernorm = PanguUltraMoERMSNorm(config.attention_kv_lora_dim)
315
+ self.kv_b_proj = nn.Linear(
316
+ config.attention_kv_lora_dim,
317
+ self.num_heads * (config.attention_qk_dim + self.attention_v_dim),
318
+ bias=False,
319
+ )
320
+ self.o_proj = nn.Linear(
321
+ self.num_heads * self.attention_v_dim,
322
+ self.hidden_size,
323
+ bias=False,
324
+ )
325
+ self.rotary_emb = PanguUltraMoERotaryEmbedding(
326
+ self.attention_qk_rope_dim,
327
+ max_position_embeddings=config.max_position_embeddings,
328
+ base=config.rope_theta,
329
+ )
330
+ self.softmax_scale = self.q_head_dim ** (-0.5)
331
+
332
+ def forward(
333
+ self,
334
+ hidden_states: torch.Tensor,
335
+ attention_mask: Optional[torch.Tensor] = None,
336
+ position_ids: Optional[torch.LongTensor] = None,
337
+ past_key_value: Optional[Cache] = None,
338
+ output_attentions: bool = False,
339
+ use_cache: bool = False,
340
+ **kwargs,
341
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
342
+ bsz, q_len, _ = hidden_states.size()
343
+
344
+ q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
345
+ q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
346
+ q_nope, q_pe = torch.split(
347
+ q, [self.attention_qk_dim, self.attention_qk_rope_dim], dim=-1
348
+ )
349
+
350
+ latent_kv = self.kv_a_proj_with_mqa(hidden_states)
351
+ kv_a, k_pe = torch.split(
352
+ latent_kv, [self.attention_kv_lora_dim, self.attention_qk_rope_dim], dim=-1
353
+ )
354
+ k_pe = k_pe.view(bsz, q_len, 1, self.attention_qk_rope_dim).transpose(1, 2)
355
+ kv = (
356
+ self.kv_b_proj(self.kv_a_layernorm(kv_a))
357
+ .view(
358
+ bsz, q_len, self.num_heads, self.attention_qk_dim + self.attention_v_dim
359
+ )
360
+ .transpose(1, 2)
361
+ )
362
+ kv_seq_len = kv.shape[-2]
363
+ if past_key_value is not None:
364
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
365
+ cos, sin = self.rotary_emb(kv, kv_seq_len)
366
+ q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
367
+
368
+ k_nope, value = torch.split(
369
+ kv, [self.attention_qk_dim, self.attention_v_dim], dim=-1
370
+ )
371
+
372
+ def concat_nope_pe(nope, pe):
373
+ states = torch.empty(
374
+ [bsz, self.num_heads, q_len, self.q_head_dim],
375
+ dtype=nope.dtype,
376
+ device=nope.device,
377
+ )
378
+ states[:, :, :, : self.attention_qk_dim] = nope
379
+ states[:, :, :, self.attention_qk_dim :] = pe
380
+ return states
381
+
382
+ query = concat_nope_pe(q_nope, q_pe)
383
+ key = concat_nope_pe(k_nope, k_pe)
384
+
385
+ if past_key_value is not None:
386
+ key, value = past_key_value.update(
387
+ key, value, self.layer_idx, {"sin": sin, "cos": cos}
388
+ )
389
+
390
+ attn_weights = (
391
+ torch.matmul(query, key.transpose(2, 3)) * self.softmax_scale
392
+ + attention_mask
393
+ )
394
+ attn_weights = nn.functional.softmax(
395
+ attn_weights, dim=-1, dtype=torch.float32
396
+ ).to(query.dtype)
397
+ attn_weights = nn.functional.dropout(
398
+ attn_weights, p=self.attention_dropout, training=self.training
399
+ )
400
+ attn_output = torch.matmul(attn_weights, value)
401
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
402
+ attn_output = self.o_proj(attn_output)
403
+
404
+ return attn_output, past_key_value
405
+
406
+
407
+ class PanguUltraMoEDecoderLayer(nn.Module):
408
+ def __init__(self, config: PanguUltraMoEConfig, layer_idx: int):
409
+ super().__init__()
410
+ self.hidden_size = config.hidden_size
411
+ self.self_attn = PanguUltraMoEAttention(config=config, layer_idx=layer_idx)
412
+
413
+ self.mlp = (
414
+ PanguUltraMoE(config)
415
+ if (
416
+ config.num_routed_experts is not None
417
+ and layer_idx >= config.num_dense_layers
418
+ )
419
+ else MLP(config)
420
+ )
421
+ self.input_layernorm = PanguUltraMoERMSNorm(
422
+ config.hidden_size, epsilon=config.rms_norm_eps
423
+ )
424
+ self.post_attention_layernorm = PanguUltraMoERMSNorm(
425
+ config.hidden_size, epsilon=config.rms_norm_eps
426
+ )
427
+ if getattr(config, "sandwich_norm", False):
428
+ self.sandwich_norm = True
429
+ self.pre_mlp_layernorm = PanguUltraMoERMSNorm(
430
+ config.hidden_size, epsilon=config.rms_norm_eps
431
+ )
432
+ self.post_mlp_layernorm = PanguUltraMoERMSNorm(
433
+ config.hidden_size, epsilon=config.rms_norm_eps
434
+ )
435
+ else:
436
+ self.sandwich_norm = False
437
+
438
+ def forward(
439
+ self,
440
+ hidden_states: torch.Tensor,
441
+ attention_mask: Optional[torch.Tensor] = None,
442
+ position_ids: Optional[torch.LongTensor] = None,
443
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
444
+ use_cache: Optional[bool] = False,
445
+ **kwargs,
446
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
447
+ residual = hidden_states
448
+ hidden_states = self.input_layernorm(hidden_states)
449
+
450
+ hidden_states, present_key_value = self.self_attn(
451
+ hidden_states=hidden_states,
452
+ attention_mask=attention_mask,
453
+ position_ids=position_ids,
454
+ past_key_value=past_key_value,
455
+ use_cache=use_cache,
456
+ **kwargs,
457
+ )
458
+ if self.sandwich_norm:
459
+ hidden_states = self.post_attention_layernorm(hidden_states)
460
+ hidden_states = residual + hidden_states
461
+ residual = hidden_states
462
+ hidden_states = self.pre_mlp_layernorm(hidden_states)
463
+ else:
464
+ hidden_states = residual + hidden_states
465
+ residual = hidden_states
466
+ hidden_states = self.post_attention_layernorm(hidden_states)
467
+
468
+ hidden_states = self.mlp(hidden_states)
469
+
470
+ if self.sandwich_norm:
471
+ hidden_states = self.post_mlp_layernorm(hidden_states)
472
+ hidden_states = residual + hidden_states
473
+
474
+ return (hidden_states, present_key_value)
475
+
476
+
477
+ class PanguUltraMoEPreTrainedModel(PreTrainedModel):
478
+ config_class = PanguUltraMoEConfig
479
+ base_model_prefix = "model"
480
+ supports_gradient_checkpointing = True
481
+ _no_split_modules = ["PanguUltraMoEDecoderLayer"]
482
+ _skip_keys_device_placement = "past_key_values"
483
+ _supports_cache_class = True
484
+
485
+ def _init_weights(self, module):
486
+ std = self.config.initializer_range
487
+ self._initialize_linear(module, std)
488
+ self._initialize_embedding(module, std)
489
+
490
+ def _initialize_linear(self, module, std):
491
+ if isinstance(module, nn.Linear):
492
+ module.weight.data.normal_(mean=0.0, std=std)
493
+ if module.bias is not None:
494
+ module.bias.data.zero_()
495
+
496
+ def _initialize_embedding(self, module, std):
497
+ if isinstance(module, nn.Embedding):
498
+ module.weight.data.normal_(mean=0.0, std=std)
499
+ if module.padding_idx is not None:
500
+ module.weight.data[module.padding_idx].zero_()
501
+
502
+
503
+ class PanguUltraMoEModel(PanguUltraMoEPreTrainedModel):
504
+ def __init__(self, config: PanguUltraMoEConfig):
505
+ super().__init__(config)
506
+
507
+ self.vocab_size = config.vocab_size
508
+ self.hidden_size = config.hidden_size
509
+ self.padding_idx = config.pad_token_id
510
+ self.layer_num = config.num_hidden_layers
511
+ self.epsilon = config.rms_norm_eps
512
+
513
+ self.embed_tokens = nn.Embedding(
514
+ self.vocab_size, self.hidden_size, self.padding_idx
515
+ )
516
+ self.layers = nn.ModuleList(
517
+ [PanguUltraMoEDecoderLayer(config, idx) for idx in range(self.layer_num)]
518
+ )
519
+ self.norm = PanguUltraMoERMSNorm(self.hidden_size, epsilon=self.epsilon)
520
+ self.gradient_checkpointing = False
521
+
522
+ self.post_init()
523
+
524
+ def get_input_embeddings(self):
525
+ return self.embed_tokens
526
+
527
+ def set_input_embeddings(self, value):
528
+ self.embed_tokens = value
529
+
530
+ def forward(
531
+ self,
532
+ input_ids: torch.LongTensor = None,
533
+ attention_mask: Optional[torch.Tensor] = None,
534
+ position_ids: Optional[torch.LongTensor] = None,
535
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
536
+ inputs_embeds: Optional[torch.FloatTensor] = None,
537
+ use_cache: Optional[bool] = None,
538
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
539
+ if input_ids is not None and inputs_embeds is not None:
540
+ raise ValueError("You have to specify input_ids or inputs_embeds.")
541
+
542
+ if input_ids is not None:
543
+ hidden_states = self.embed_tokens(input_ids)
544
+ batch_size, seq_length = input_ids.size()
545
+ else:
546
+ hidden_states = inputs_embeds
547
+ batch_size, seq_length = inputs_embeds.size()
548
+
549
+ if position_ids is None:
550
+ position_ids = torch.arange(
551
+ seq_length, dtype=torch.long, device=hidden_states.device
552
+ ).unsqueeze(0)
553
+
554
+ past_key_values_length = 0
555
+ if use_cache:
556
+ use_legacy_cache = not isinstance(past_key_values, Cache)
557
+ if use_legacy_cache:
558
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
559
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
560
+ position_ids += past_key_values_length
561
+
562
+ attention_mask = _prepare_4d_causal_attention_mask(
563
+ attention_mask,
564
+ (batch_size, seq_length),
565
+ hidden_states,
566
+ past_key_values_length,
567
+ )
568
+
569
+ for decoder_layer in self.layers:
570
+ hidden_states, present_key_value = decoder_layer(
571
+ hidden_states,
572
+ attention_mask=attention_mask,
573
+ position_ids=position_ids,
574
+ past_key_value=past_key_values,
575
+ use_cache=use_cache,
576
+ )
577
+
578
+ hidden_states = self.norm(hidden_states)
579
+
580
+ if use_cache and use_legacy_cache:
581
+ present_key_value = present_key_value.to_legacy_cache()
582
+
583
+ return BaseModelOutputWithPast(
584
+ last_hidden_state=hidden_states,
585
+ past_key_values=present_key_value,
586
+ )
587
+
588
+
589
+ class PanguUltraMoEForCausalLM(PanguUltraMoEPreTrainedModel):
590
+ _tied_weights_keys = ["lm_head.weight"]
591
+
592
+ def __init__(self, config):
593
+ super().__init__(config)
594
+ self.model = PanguUltraMoEModel(config)
595
+ self.vocab_size = config.vocab_size
596
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
597
+
598
+ self.post_init()
599
+
600
+ def get_input_embeddings(self):
601
+ return self.model.embed_tokens
602
+
603
+ def set_input_embeddings(self, value):
604
+ self.model.embed_tokens = value
605
+
606
+ def get_output_embeddings(self):
607
+ return self.lm_head
608
+
609
+ def set_output_embeddings(self, new_embeddings):
610
+ self.lm_head = new_embeddings
611
+
612
+ def set_decoder(self, decoder):
613
+ self.model = decoder
614
+
615
+ def get_decoder(self):
616
+ return self.model
617
+
618
+ def forward(
619
+ self,
620
+ input_ids: torch.LongTensor = None,
621
+ attention_mask: Optional[torch.Tensor] = None,
622
+ position_ids: Optional[torch.LongTensor] = None,
623
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
624
+ inputs_embeds: Optional[torch.FloatTensor] = None,
625
+ labels: Optional[torch.LongTensor] = None,
626
+ use_cache: Optional[bool] = None,
627
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
628
+
629
+ outputs = self.model(
630
+ input_ids=input_ids,
631
+ attention_mask=attention_mask,
632
+ position_ids=position_ids,
633
+ past_key_values=past_key_values,
634
+ inputs_embeds=inputs_embeds,
635
+ use_cache=use_cache,
636
+ )
637
+
638
+ logits = self.lm_head(outputs[0])
639
+ logits = logits.float()
640
+
641
+ loss = None
642
+ if labels is not None:
643
+ loss = self.loss_function(
644
+ logits=logits, labels=labels, vocab_size=self.vocab_size
645
+ )
646
+
647
+ return CausalLMOutputWithPast(
648
+ loss=loss,
649
+ logits=logits,
650
+ past_key_values=outputs.past_key_values,
651
+ hidden_states=outputs.hidden_states,
652
+ attentions=outputs.attentions,
653
+ )
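For orientation, a minimal usage sketch of `PanguUltraMoEForCausalLM` through the `transformers` auto classes follows. The checkpoint path, dtype, and device settings are placeholders/assumptions, not the repository's documented workflow (see the `inference/` scripts for that), and the full 718B MoE checkpoint will not fit on a single device, so this is illustrative only.

```python
# Illustrative sketch: load the custom classes shipped in this repo via trust_remote_code.
# "path/to/openPangu-Ultra-MoE-718B" is a placeholder for a local checkpoint directory.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/openPangu-Ultra-MoE-718B"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,  # assumption about the weight dtype
    device_map="auto",
    trust_remote_code=True,      # picks up PanguUltraMoEForCausalLM from modeling_openpangu_moe.py
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```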
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "[unused10]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
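Note that the end-of-sequence token is mapped to `[unused10]` rather than `</s>`, matching the end-of-turn marker used by the chat template later in this commit. A quick verification sketch, assuming a tokenizer already loaded as in the example above:

```python
# Sketch: check the special-token mapping declared in special_tokens_map.json.
# Assumes `tokenizer` was loaded with AutoTokenizer.from_pretrained(..., trust_remote_code=True).
assert tokenizer.bos_token == "<s>"
assert tokenizer.eos_token == "[unused10]"   # end-of-turn marker, not "</s>"
assert tokenizer.pad_token == "<unk>"
assert tokenizer.unk_token == "<unk>"
print(tokenizer.convert_tokens_to_ids(tokenizer.eos_token))  # id used to stop generation
```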
tokenization_openpangu.py ADDED
@@ -0,0 +1,273 @@
+ # coding=utf-8
+ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ from shutil import copyfile
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import sentencepiece as spm
+
+ from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+ VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
+
+ PRETRAINED_VOCAB_FILES_MAP = {}
+
+
+ def convert_bool(string):
+     if isinstance(string, str):
+         if string.lower() == "true":
+             return True
+         elif string.lower() == "false":
+             return False
+         else:
+             return string
+     else:
+         return string
+
+
+ class PanguUltraMoETokenizer(PreTrainedTokenizer):
+     """
+     Construct a tokenizer based on SentencePiece.
+
+     Args:
+         vocab_file (`str`):
+             Path to the vocabulary file.
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+     model_input_names = ["input_ids", "attention_mask"]
+     _auto_class = "AutoTokenizer"
+
+     def __init__(
+         self,
+         vocab_file,
+         unk_token="<unk>",
+         bos_token="<s>",
+         eos_token="</s>",
+         pad_token="</s>",
+         sp_model_kwargs: Optional[Dict[str, Any]] = None,
+         add_bos_token=True,
+         add_eos_token=False,
+         decode_with_prefix_space=False,
+         clean_up_tokenization_spaces=False,
+         **kwargs,
+     ):
+         self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+         self.sp_model.Load(vocab_file)
+         super().__init__(
+             bos_token=bos_token,
+             eos_token=eos_token,
+             unk_token=unk_token,
+             pad_token=pad_token,
+             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+             **kwargs,
+         )
+         self.vocab_file = vocab_file
+         self.add_bos_token = convert_bool(add_bos_token)
+         self.add_eos_token = add_eos_token
+         self.decode_with_prefix_space = decode_with_prefix_space
+         self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+         self.sp_model.Load(vocab_file)
+         self._no_prefix_space_tokens = None
+
+     """ Initialisation"""
+
+     @property
+     def no_prefix_space_tokens(self):
+         if self._no_prefix_space_tokens is None:
+             vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
+             self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")}
+         return self._no_prefix_space_tokens
+
+     @property
+     def vocab_size(self):
+         """Returns vocab size"""
+         return self.sp_model.get_piece_size()
+
+     @property
+     def bos_token_id(self) -> Optional[int]:
+         return self.sp_model.bos_id()
+
+     @property
+     def eos_token_id(self) -> Optional[int]:
+         return super().eos_token_id
+
+     def get_vocab(self):
+         """Returns vocab as a dict"""
+         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+         vocab.update(self.added_tokens_encoder)
+         return vocab
+
+     def _tokenize(self, text):
+         """Returns a tokenized string."""
+         return self.sp_model.encode(text, out_type=str)
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) in an id using the vocab."""
+         return self.sp_model.piece_to_id(token)
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) in a token (str) using the vocab."""
+         token = self.sp_model.IdToPiece(index)
+         return token
+
+     def _maybe_add_prefix_space(self, tokens, decoded):
+         if tokens and tokens[0] not in self.no_prefix_space_tokens:
+             return " " + decoded
+         else:
+             return decoded
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (string) in a single string."""
+         current_sub_tokens = []
+         out_string = ""
+         prev_is_special = False
+         for token in tokens:
+             # make sure that special tokens are not decoded using sentencepiece model
+             if token in self.all_special_tokens:
+                 # Decode the current sub-tokens first
+                 if current_sub_tokens:
+                     out_string += self.sp_model.decode(current_sub_tokens)
+                     current_sub_tokens = []
+                 # Append the special token without adding extra spaces
+                 out_string += token
+                 prev_is_special = True
+             else:
+                 current_sub_tokens.append(token)
+                 prev_is_special = False
+         # Decode any remaining sub-tokens
+         if current_sub_tokens:
+             out_string += self.sp_model.decode(current_sub_tokens)
+         # Clean up leading and trailing spaces
+         if self.clean_up_tokenization_spaces:
+             out_string = self.clean_up_tokenization(out_string)
+         out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
+         return out_string[1:]
+
+     # Override decode to set spaces_between_special_tokens to False as default
+     def decode(self,
+                token_ids,
+                spaces_between_special_tokens: bool = False,
+                **kwargs):
+         return super().decode(
+             token_ids=token_ids,
+             spaces_between_special_tokens=spaces_between_special_tokens,
+             **kwargs,
+         )
+
+     def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         """
+         Save the vocabulary and special tokens file to a directory.
+
+         Args:
+             save_directory (`str`):
+                 The directory in which to save the vocabulary.
+
+         Returns:
+             `Tuple(str)`: Paths to the files saved.
+         """
+         if not os.path.isdir(save_directory):
+             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+             return ("",)
+         out_vocab_file = os.path.join(
+             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+         )
+
+         if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+             copyfile(self.vocab_file, out_vocab_file)
+         elif not os.path.isfile(self.vocab_file):
+             with open(out_vocab_file, "wb") as fi:
+                 content_spiece_model = self.sp_model.serialized_model_proto()
+                 fi.write(content_spiece_model)
+
+         return (out_vocab_file,)
+
+     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+         if self.add_bos_token:
+             bos_token_ids = [self.bos_token_id]
+         else:
+             bos_token_ids = []
+
+         output = bos_token_ids + token_ids_0
+
+         if token_ids_1 is not None:
+             output = output + token_ids_1
+
+         if self.add_eos_token:
+             output = output + [self.eos_token_id]
+
+         return output
+
+     def get_special_tokens_mask(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+     ) -> List[int]:
+         """
+         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` method.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+
+         Returns:
+             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+             )
+
+         if token_ids_1 is None:
+             return [1] + ([0] * len(token_ids_0)) + [1]
+         return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+     def create_token_type_ids_from_sequences(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         """
+         Create a mask from the two sequences passed to be used in a sequence-pair classification task. This model does
+         not make use of token type ids, therefore a list of zeros is returned.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+
+         Returns:
+             `List[int]`: List of zeros.
+         """
+         eos = [self.eos_token_id]
+
+         if token_ids_1 is None:
+             return len(token_ids_0 + eos) * [0]
+         return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
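A round-trip sketch for the SentencePiece-backed tokenizer above; the checkpoint path is a placeholder, and note that `tokenizer_config.json` below ships with `add_bos_token`/`add_eos_token` disabled, so no special tokens are inserted here unless enabled explicitly.

```python
# Sketch: encode/decode round trip with PanguUltraMoETokenizer.
# "path/to/openPangu-Ultra-MoE-718B" is a placeholder for a local directory containing tokenizer.model.
from transformers import AutoTokenizer

model_path = "path/to/openPangu-Ultra-MoE-718B"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

ids = tokenizer("openPangu tokenizer test")["input_ids"]
print(ids)
print(tokenizer.convert_ids_to_tokens(ids))  # SentencePiece pieces, word-initial pieces prefixed with "▁"
print(tokenizer.decode(ids))                 # should reproduce the input text
```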
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"add_bos_token": false, "add_eos_token": false, "add_prefix_space": true, "added_tokens_decoder": {"0": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "1": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "2": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45806": {"content": "<|User|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45813": {"content": "<|Bot|>:", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45830": {"content": "[unused0]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45840": {"content": "[unused1]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45846": {"content": "[unused2]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45849": {"content": "[unused3]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45861": {"content": "[unused4]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45866": {"content": "[unused5]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45874": {"content": "[unused6]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45883": {"content": "[unused7]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45884": {"content": "[unused8]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45887": {"content": "[unused9]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45892": {"content": "[unused10]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45920": {"content": "[unused11]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45932": {"content": "[unused12]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45938": {"content": "[unused13]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45953": {"content": "[unused14]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45968": {"content": "[unused15]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45974": {"content": "[unused16]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45982": {"content": "[unused17]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "45986": {"content": "[unused18]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46005": {"content": "[unused19]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46007": {"content": "[unused20]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46014": {"content": "[unused21]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, 
"special": true}, "46017": {"content": "[unused22]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46028": {"content": "[unused23]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46032": {"content": "[unused24]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46081": {"content": "[unused25]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46086": {"content": "[unused26]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46101": {"content": "[unused27]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46183": {"content": "[unused28]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46230": {"content": "[unused29]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46245": {"content": "[unused30]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "46257": {"content": "[unused31]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144208": {"content": "[unused32]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}, "144209": {"content": "[unused33]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}}, "auto_map": {"AutoTokenizer": ["tokenization_openpangu.PanguUltraMoETokenizer", null]}, "bos_token": "<s>", "clean_up_tokenization_spaces": false, "eos_token": "[unused10]", "legacy": true, "model_max_length": 1000000000000000019884624838656, "pad_token": "<unk>", "sp_model_kwargs": {}, "spaces_between_special_tokens": false, "tokenizer_class": "PanguUltraMoETokenizer", "unk_token": "<unk>", "use_default_system_prompt": false, "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '[unused9]系统:[unused10]' }}{% endif %}{% if message['role'] == 'system' %}{{ '[unused9]系统:' + message['content'] + '[unused10]' }}{% endif %}{% if message['role'] == 'assistant' %}{{'[unused9]助手:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'tool' %}{{'[unused9]工具:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'function' %}{{'[unused9]方法:' + message['content'] + '[unused10]'}}{% endif %}{% if message['role'] == 'user' %}{{'[unused9]用户:' + message['content'] + '[unused10]'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[unused9]助手:' }}{% endif %}"}