Abstract
Multi-Head Low-Rank Attention (MLRA) addresses the long-context inference bottleneck in large language models by making latent states partitionable, which enables efficient 4-way tensor-parallel decoding.
Long-context inference in large language models is bottlenecked by Key--Value (KV) cache loading during the decoding stage, where the sequential nature of generation requires repeatedly transferring the KV cache from off-chip High-Bandwidth Memory (HBM) to on-chip Static Random-Access Memory (SRAM) at each step. While Multi-Head Latent Attention (MLA) significantly reduces the total KV cache size, it suffers from a sharding bottleneck during distributed decoding via Tensor Parallelism (TP). Since its single latent head cannot be partitioned, each device is forced to redundantly load the complete KV cache for every token, consuming excessive memory traffic and diminishing TP benefits such as weight sharding. In this work, we propose Multi-Head Low-Rank Attention (MLRA), which enables partitionable latent states for efficient 4-way TP decoding. Extensive experiments show that MLRA achieves state-of-the-art perplexity and downstream task performance, while also delivering a 2.8× decoding speedup over MLA. Code is available at https://github.com/SongtaoLiu0823/MLRA. Pretrained weights, along with the training and evaluation data, are available at https://huggingface.co/Soughing/MLRA.
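The traffic argument above can be made concrete with a small back-of-the-envelope sketch. The numbers below (sequence length, latent dimension, head count) are illustrative assumptions, not figures from the paper; the point is only that a shardable multi-head latent cache lets each TP rank load a 1/TP slice, while MLA's single latent head forces every rank to load the full cache.

```python
# Illustrative sketch (assumed shapes, not from the paper): per-rank
# KV-cache bytes read from HBM for one decoding step under tensor
# parallelism (TP).

def kv_bytes_per_device(seq_len, latent_dim, n_latent_heads, tp,
                        shardable, dtype_bytes=2):
    """Bytes of cached latent state one TP rank loads per generated token.

    If the cache is shardable along the head axis (MLRA-style), each rank
    loads only its slice; a single unsplittable latent head (MLA-style)
    forces every rank to load the entire cache.
    """
    total = seq_len * n_latent_heads * latent_dim * dtype_bytes
    return total // tp if shardable else total

seq_len, tp = 128_000, 4
# MLA: one latent head, cannot be partitioned across ranks.
mla = kv_bytes_per_device(seq_len, latent_dim=512, n_latent_heads=1,
                          tp=tp, shardable=False)
# MLRA (hypothetical config with the same total cache size): four latent
# heads, split one per rank.
mlra = kv_bytes_per_device(seq_len, latent_dim=128, n_latent_heads=4,
                           tp=tp, shardable=True)
print(f"MLA  per-rank load: {mla / 1e6:.0f} MB")   # full cache on every rank
print(f"MLRA per-rank load: {mlra / 1e6:.0f} MB")  # 1/4 of the cache per rank
```

Under these assumed shapes the total cache is identical, but the per-rank memory traffic drops by the TP degree, which is exactly the redundancy the abstract describes.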
Community
[ICLR 2026] Introducing a strong new attention mechanism, "Multi-Head Low-Rank Attention", accepted to the International Conference on Learning Representations (ICLR) 2026!
Many leading open-source LLMs (e.g., DeepSeek, Kimi, GLM) have adopted Multi-Head Latent Attention (MLA) because it delivers strong model quality with a much smaller KV cache. However, MLA runs into a major bottleneck at inference time: its KV cache cannot be sharded with Tensor Parallelism (TP), which prevents scalable distributed decoding.
We introduce Multi-Head Low-Rank Attention (MLRA). MLRA preserves MLA's low-rank KV compression while making the KV cache natively compatible with 4-way TP, enabling efficient multi-GPU decoding without redundant KV loading. We pretrain MLRA from scratch at 2.9B parameters on 100B tokens, achieving state-of-the-art model quality among 11 existing attention mechanisms under our training setting. On the systems side, we benchmark decoding latency and throughput in a DeepSeek-V3-scale setup and observe leading decoding efficiency.
Crucially, MLRA can scale! In MLA, increasing the KV latent-head dimension often makes high-performance decoding kernels hard to deploy, and simply increasing the number of heads under a fixed activation/parameter budget can degrade quality. By contrast, MLRA’s multi-branch low-rank architecture supports substantially more heads while remaining TP-friendly and kernel-efficient. We release the full training pipeline, pretrained weights, and a high-performance decoding kernel (based on FlashAttention-3) to make MLRA easy to reproduce and deploy.
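The TP-friendliness claimed above comes down to how the latent cache is laid out: with multiple latent heads, the cache tensor can be split cleanly along the head axis, one contiguous slice per rank. A minimal sketch, with assumed (not paper-specified) shapes:

```python
import numpy as np

# Hedged sketch with illustrative shapes: partitioning a multi-head
# latent cache (MLRA-style) along the head axis for 4-way TP, versus a
# single-latent-head cache (MLA-style) whose head axis cannot be split.

seq_len, n_heads, head_dim, tp = 16, 8, 64, 4

# MLRA-style cache: [seq, heads, dim]; the head axis divides evenly
# across TP ranks, so each rank stores and loads only its own slice.
mlra_cache = np.zeros((seq_len, n_heads, head_dim), dtype=np.float16)
shards = np.split(mlra_cache, tp, axis=1)
print("per-rank shard shape:", shards[0].shape)  # (16, 2, 64)

# MLA-style cache: head axis has size 1, which 4 ranks cannot divide;
# every rank would have to hold the whole tensor.
mla_cache = np.zeros((seq_len, 1, head_dim), dtype=np.float16)
print("MLA head axis divisible by tp:", mla_cache.shape[1] % tp == 0)  # False
```

This is the structural difference the post points to: adding more low-rank heads keeps each shard a plain contiguous slice, which is what makes the layout both TP-friendly and amenable to high-performance decoding kernels.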
paper: https://arxiv.org/pdf/2603.02188
blog: https://songtaoliu0823.github.io/mlra/
code: https://github.com/SongtaoLiu0823/MLRA
data & weights: https://huggingface.co/Soughing/MLRA