Papers
arxiv:2603.09555

Compiler-First State Space Duality and Portable O(1) Autoregressive Caching for Inference

Published on Mar 10 · Submitted by Santoni on Mar 11
Abstract

Mamba-2's state space model is implemented using XLA-optimized primitives without custom CUDA or Triton kernels, enabling cross-platform deployment and achieving high performance on TPU.

AI-generated summary

State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical O(1) state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill (15% MFU) and up to 64% bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.
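The structural conditions the abstract names (diagonal state, chunkable recurrence, static control flow) can be illustrated with a minimal, hypothetical sketch: a diagonal linear recurrence expressed with `jax.lax.scan` and standard elementwise primitives only. This is not the paper's implementation (see the linked repository for that); the shapes and names here are illustrative.

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch of a diagonal SSM recurrence:
#   h_t = a_t * h_{t-1} + b_t * x_t,   y_t = <c_t, h_t>
# Because the state is diagonal, the step is purely elementwise plus a
# dot product, and lax.scan gives XLA static control flow to fuse --
# the kind of structure the paper argues makes custom kernels optional.

def ssm_scan(a, b, c, x):
    # a, b, c: (T, N) per-step diagonal decay / input / readout terms
    # x: (T,) scalar input channel; returns y: (T,) outputs
    def step(h, inputs):
        a_t, b_t, c_t, x_t = inputs
        h = a_t * h + b_t * x_t      # diagonal state update, elementwise
        y = jnp.dot(c_t, h)          # linear readout of the state
        return h, y

    h0 = jnp.zeros(a.shape[1])
    _, ys = jax.lax.scan(step, h0, (a, b, c, x))
    return ys
```

The same scan body can be chunked over the sequence dimension, which is what makes the recurrence compatible with the chunked prefill path described above.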

Community

Paper author · Paper submitter

State-space models like Mamba-2 are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware.

This work introduces a completely kernel-free implementation of the state space duality algorithm in JAX. By mapping the algorithm directly to XLA compiler passes and realising the architecture's theoretical O(1) state management as a compiled on-device cache requiring no host synchronisation during generation, we achieve true cross-platform portability.

Key Highlights:

  • True Portability: The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source.
  • Kernel-Free Performance: Proves that standard XLA fusion and tiling passes can effectively optimise the SSD algorithm without relying on bespoke hardware intrinsics.
  • Zero Host-Device Bottleneck: Carries the autoregressive state cache entirely on-device via compiled control flow, avoiding the massive latency penalties of Python host loops.
  • Numerical Parity: Greedy decoding matches the PyTorch/CUDA reference token-for-token, ensuring drop-in correctness.
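The "zero host-device bottleneck" point can be sketched as follows: keep the O(1) recurrent state in the carry of a compiled `jax.lax.scan`, so the entire greedy decode loop is one device program with no per-token Python iteration or host synchronisation. This is an illustrative toy, not the paper's code; `decode_step` and all shapes stand in for the model's real per-token forward pass.

```python
from functools import partial

import jax
import jax.numpy as jnp

def decode_step(state, token, decay, inp, out):
    # Hypothetical stand-in for a single-token forward pass:
    # one O(1) cache update plus a greedy readout.
    state = decay * state + inp * token   # constant-size state, no KV growth
    logits = out @ state                  # project state to the vocabulary
    return state, jnp.argmax(logits)      # greedy next token

@partial(jax.jit, static_argnames="steps")
def generate(state0, first_token, decay, inp, out, steps=8):
    # The recurrent cache rides in the scan carry, so generation
    # compiles to a single on-device loop: no host round-trips per token.
    def step(carry, _):
        state, token = carry
        state, token = decode_step(state, token.astype(state.dtype),
                                   decay, inp, out)
        return (state, token), token

    (_, _), tokens = jax.lax.scan(step, (state0, first_token),
                                  None, length=steps)
    return tokens
```

A Python `for` loop over tokens would instead launch one program per step and pull state back to the host, which is the latency penalty the highlight refers to.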

[Figure 1: scaling results]


