LWM-Temporal: Sparse Spatio-Temporal Attention for Wireless Channel Representation Learning

Model Description

LWM-Temporal is a physics-informed foundation model for wireless channels that learns universal spatio-temporal representations from sequences of channel state information (CSI). It extends the Large Wireless Model (LWM) family by explicitly modeling temporal evolution, user mobility, and Doppler–delay dynamics while remaining backward-compatible with single-snapshot tasks. The model operates in the angle–delay–time (AD–t) domain and builds sparsity and motion priors directly into its Sparse Spatio-Temporal Attention (SSTA) mechanism and masking objectives. It is intended as a general-purpose feature extractor for diverse wireless downstream tasks.

Key Features

  • Universal feature extraction: Handles single snapshots (LoS/NLoS classification, clustering, compression) and channel sequences (beam tracking, prediction, scatterer dynamics) with the same backbone.
  • Sparse Spatio-Temporal Attention (SSTA): Restricts attention to physically plausible neighborhoods in AD–t, yielding near-linear complexity, better interpretability, and improved generalization under mobility.
  • Physics-informed design: Angle–delay representation exposes multipath geometry; temporal corridors follow Doppler-driven motion; energy-aware routing prioritizes dominant paths.
  • Self-supervised pretraining: Masked channel modeling on complex-valued channels; reconstruction is applied only to masked patches; normalized losses emphasize high-energy components.
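
The SSTA neighborhoods described above can be illustrated with a small sketch. This is a hypothetical reconstruction, not the library's implementation: each query token attends to a local window in its own frame plus, for each temporal offset, a window around a drift-shifted center in the neighboring frame.

```python
import numpy as np

def ssta_mask(T, H, W, same_frame_window=2, temporal_offsets=(-1, 1),
              temporal_spatial_window=1, drift_h=0, drift_w=0):
    """Boolean (N, N) attention mask over a T x H x W token grid.

    Hypothetical illustration of SSTA's neighborhoods: each query
    attends to a local window in its own frame and, for each temporal
    offset dt, to a window around a drift-shifted center in frame t+dt.
    """
    N = T * H * W
    idx = np.arange(N)
    t = idx // (H * W)
    h = (idx // W) % H
    w = idx % W
    mask = np.zeros((N, N), dtype=bool)
    for q in range(N):
        # same-frame spatial neighborhood (always includes the query itself)
        mask[q] |= (t == t[q]) & (np.abs(h - h[q]) <= same_frame_window) \
                   & (np.abs(w - w[q]) <= same_frame_window)
        # temporal corridors that follow an assumed motion drift
        for dt in temporal_offsets:
            if 0 <= t[q] + dt < T:
                ch, cw = h[q] + dt * drift_h, w[q] + dt * drift_w
                mask[q] |= (t == t[q] + dt) \
                           & (np.abs(h - ch) <= temporal_spatial_window) \
                           & (np.abs(w - cw) <= temporal_spatial_window)
    return mask

m = ssta_mask(T=3, H=4, W=4, same_frame_window=1)
print(m.shape, m.sum())  # far sparser than the dense 48 x 48 pattern
```

Because each row of the mask has only O(window) nonzeros, attention cost grows roughly linearly in the number of tokens rather than quadratically, which is the source of the near-linear complexity claim.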

Architecture Overview

  • Input: Complex channel sequences shaped (T, N_antennas, N_subcarriers) (or (T, H, W) grids), transformed to the angle–delay domain; real and imaginary parts (or magnitude/phase) are processed jointly.
  • Tokenization: Patch-based tokenization over AD grids (tubelets across time). Optional CLS tokens.
  • Backbone: Transformer encoder with sparse spatio-temporal attention, learned or RoPE positional encoding, and optional dynamic routing (routing_topk_*).
  • Output: Dense token embeddings plus optional CLS; reconstruction head used during pretraining and prediction tasks.
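
For the common ULA/OFDM setting, the angle–delay transform amounts to a DFT across the antenna axis and an inverse DFT across subcarriers; real and imaginary parts are then stacked as token features. The sketch below assumes this convention and the simplest tokenization (patch_size=(1, 1), phase_mode="real_imag"); the library's AngleDelayProcessor may differ in normalization and layout.

```python
import torch

def to_angle_delay(channel: torch.Tensor) -> torch.Tensor:
    """channel: complex (T, N_ant, N_sc) spatial-frequency CSI.

    One common convention (assumed here; the library's normalization
    may differ): a DFT across the antenna axis exposes angle bins,
    an inverse DFT across subcarriers exposes delay bins.
    """
    angle = torch.fft.fft(channel, dim=1, norm="ortho")
    return torch.fft.ifft(angle, dim=2, norm="ortho")

def tokenize(volume: torch.Tensor) -> torch.Tensor:
    """Flatten a complex (T, H, W) volume into (T*H*W, 2) real tokens,
    matching patch_size=(1, 1) with phase_mode='real_imag'."""
    flat = volume.reshape(-1)
    return torch.stack([flat.real, flat.imag], dim=-1)

h = torch.randn(11, 32, 32, dtype=torch.complex64)  # T=11 frames
tokens = tokenize(to_angle_delay(h))
print(tokens.shape)  # torch.Size([11264, 2])
```

With orthonormal transforms the mapping is energy-preserving, which is what allows the energy-normalized losses below to weight high-energy multipath components meaningfully.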

Pretraining Method

  • Objective: Masked channel modeling on complex channels in AD–t.
  • Masking: Frame-local rectangles, spatio-temporal tubes (following motion), pilot-lattice combs, or uniform random masking.
  • Loss: Normalized MSE on masked tokens only to avoid dominance of low-energy regions.
  • Optimization: AdamW, LR 5e-4, warmup + cosine decay, mixed precision, multi-GPU supported.
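
The normalized loss can be sketched as follows. This is an illustration of the idea, not necessarily identical to the library's masked_nmse_loss: the reconstruction error over masked tokens is divided by the energy of the corresponding targets.

```python
import torch

def masked_nmse(recon, target, mask, eps=1e-8):
    """NMSE computed on masked tokens only.

    recon, target: (B, N, D); mask: (B, N) bool, True = masked.
    Sketch of the idea; the library's masked_nmse_loss may differ
    in its exact normalization.
    """
    m = mask.unsqueeze(-1).float()
    err = ((recon - target) ** 2 * m).sum()
    ref = ((target ** 2) * m).sum()
    # Normalizing by the masked target energy keeps low-energy
    # regions from dominating the objective.
    return err / (ref + eps)

torch.manual_seed(0)
target = torch.randn(2, 16, 4)
mask = torch.rand(2, 16) < 0.6
print(masked_nmse(target, target, mask).item())  # 0.0 for a perfect reconstruction
```

A constant-zero prediction scores an NMSE of about 1.0 under this definition, so values well below 1 indicate the model is recovering structure in the masked regions.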

Training Data

Pretraining uses a dynamic wireless digital twin pipeline (ray-traced scenarios) with multiple cities/environments, realistic trajectories, Doppler/delay/angle dynamics, and diverse LoS/NLoS scattering. Data covers varied velocities and geometries to promote robust generalization.

Intended Use

  • Suitable: Channel prediction, beam tracking/selection, CSI compression and feedback, LoS/NLoS classification, clustering/zoning, digital twin calibration, mobility-aware sensing.
  • Not intended: End-to-end decoding/demodulation, symbol-level recovery, or non-wireless domains without adaptation.

Checkpoints and Files

  • Default checkpoint: checkpoints/{config.json,pytorch_model.bin} (HF-compatible directory). Load with LWMBackbone.from_pretrained("checkpoints") or LWMHFModel.from_pretrained("checkpoints").
  • Package: Python package under LWMTemporal/ with CLI entry points in LWMTemporal/cli/ and tasks in LWMTemporal/tasks/.

How to Use (from the tutorial workflow)

These snippets mirror the tutorial workflow and assume the repo root is on PYTHONPATH. They are sequential: later steps reuse objects defined in earlier ones (e.g., ROOT, FIG_DIR, scenario).

1) Generate a dynamic scenario (DeepMIMO-based)

from pathlib import Path
from LWMTemporal.data.scenario_generation import (
    AntennaArrayConfig, DynamicScenarioGenerator, GridConfig,
    ScenarioGenerationConfig, ScenarioSamplingConfig, TrafficConfig,
)

ROOT = Path.cwd()
DATA_DIR = ROOT / "examples/data"
FULL_DATA_DIR = ROOT / "examples/full_data"
FIG_DIR = ROOT / "examples/figs"
SCENARIO_DIR = ROOT / "deepmimo_scenarios"

cfg = ScenarioGenerationConfig(
    scenario="city_1_losangeles_3p5",
    antenna=AntennaArrayConfig(tx_horizontal=32, tx_vertical=1, subcarriers=32),
    sampling=ScenarioSamplingConfig(time_steps=11, sample_dt=1e-3),
    traffic=TrafficConfig(num_vehicles=180, num_pedestrians=20, turn_probability=0.1, vehicle_speed_range=(0/3.6, 108/3.6)),
    grid=GridConfig(road_width=2.0, road_center_spacing=8.0),
    output_dir=DATA_DIR,
    full_output_dir=FULL_DATA_DIR,
    figures_dir=FIG_DIR / "environment",
    scenarios_dir=SCENARIO_DIR,
    deepmimo_max_paths=25,
)
scenario = DynamicScenarioGenerator(cfg).generate(overwrite=True)
print("Scenario cached to:", scenario.output_path)

2) Inspect cached payload

import pickle
DATA_PATH = scenario.output_path
with DATA_PATH.open("rb") as handle:
    payload = pickle.load(handle)
print("channel_discrete shape:", payload["channel_discrete"].shape)

3) Visualize angle–delay–time evolution

import torch
from IPython.display import Image, display
from LWMTemporal.data.angle_delay import AngleDelayProcessor, AngleDelayConfig
from examples.ad_temporal_evolution import pick_bins, plot_curves

processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=0.25))
ANGLE_DELAY_DIR = FIG_DIR / "angle_delay"
ANGLE_DELAY_DIR.mkdir(parents=True, exist_ok=True)
ue_idx = 10
channel = torch.tensor(payload["channel_discrete"][ue_idx])
volume = processor.forward(channel)
volume_trimmed, _ = processor.truncate_delay_bins(volume)
gif_path = ANGLE_DELAY_DIR / "discrete_angle_delay.gif"
processor.save_angle_delay_gif(volume_trimmed, gif_path, fps=6, show=True)
picks = pick_bins(volume_trimmed, k=3, coords=None)
curves_path = ANGLE_DELAY_DIR / "discrete_curves.png"
plot_curves(volume_trimmed, picks, curves_path, title="Discrete bins")
display(Image(filename=curves_path))

4) Masked channel modeling sanity check

import torch
from torch.utils.data import DataLoader
from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
from LWMTemporal.models.lwm import LWMBackbone, LWMConfig, masked_nmse_loss
from LWMTemporal.tasks.pretraining import MaskArgs, MaskGenerator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_cfg = AngleDelayDatasetConfig(
    raw_path=scenario.output_path,
    keep_percentage=0.25,
    normalize="per_sample_rms",
    cache_dir=Path("cache"),
    patch_size=(1, 1),
    phase_mode="real_imag",
)
dataset = AngleDelaySequenceDataset(data_cfg)
loader = DataLoader(dataset, batch_size=16, shuffle=False)
example_shape = dataset[0]["shape"]
tokens_per_sample = int(example_shape[0] * example_shape[1] * example_shape[2])

cfg = LWMConfig(
    patch_size=(1, 1),
    phase_mode="real_imag",
    embed_dim=32,
    depth=12,
    num_heads=8,
    mlp_ratio=4.0,
    same_frame_window=2,
    temporal_offsets=(-4, -3, -2, -1, 1, 2, 3),
    temporal_spatial_window=2,
    temporal_drift_h=1,
    temporal_drift_w=1,
    routing_topk_enable=True,
    topk_per_head=True,
    max_seq_len=tokens_per_sample,
)
backbone = LWMBackbone.from_pretrained("checkpoints/pytorch_model.bin", config=cfg).to(device).eval()
masker = MaskGenerator(MaskArgs(mask_ratio=0.6, mask_mode="random"))

with torch.no_grad():
    for batch in loader:
        tokens = batch["tokens"].to(device)
        base_mask = batch["base_mask"].to(device)
        T, H, W = batch["shape"][0]
        B, _, _ = tokens.shape
        mask = torch.stack([masker(int(T), int(H), int(W), device).view(-1) for _ in range(B)])
        mask = torch.logical_or(mask, base_mask)
        corrupted = tokens.masked_fill(mask.unsqueeze(-1), 0.0)
        recon = backbone.forward_tokens(corrupted, mask, int(T), int(H), int(W))["reconstruction"]
        nmse = masked_nmse_loss(recon, tokens, mask)
        print("NMSE:", nmse.item())
        break

5) Channel prediction fine-tuning / evaluation

import torch
from pathlib import Path
from LWMTemporal.tasks.channel_prediction import (
    ChannelPredictionArgs, ChannelPredictionTrainer,
    DatasetArgs, ModelArgs, TrainingArgs, PredictionArgs,
)

channel_args = ChannelPredictionArgs(
    dataset=DatasetArgs(
        data_path=scenario.output_path,
        keep_percentage=0.25,
        normalize="per_sample_rms",
        seed=0,
        train_limit=50,
        val_limit=50,
    ),
    model=ModelArgs(
        patch_size=(1, 1),
        phase_mode="real_imag",
        embed_dim=32,
        depth=12,
        num_heads=8,
        mlp_ratio=4.0,
        same_frame_window=2,
        temporal_offsets=(-1, -2, -3, -4, -5, -6, -7),
        temporal_spatial_window=2,
        temporal_drift_h=1,
        temporal_drift_w=1,
        routing_topk_enable=True,
        routing_topk_fraction=0.2,
        routing_topk_max=32,
        pretrained=Path("checkpoints/pytorch_model.bin"),
    ),
    training=TrainingArgs(
        device="cuda" if torch.cuda.is_available() else "cpu",
        epochs=10,
        batch_size=2,
        lr=1e-4,
        weight_decay=1e-4,
        warmup_ratio=0.1,
        save_dir=ROOT / "models",
        inference_only=True,
        inference_split="val",
        use_wandb=False,
        use_dataparallel=True,
    ),
    prediction=PredictionArgs(Tpast=10, horizon=1, viz_dir=FIG_DIR / "predictions"),
)
trainer = ChannelPredictionTrainer(channel_args)
trainer.train()

6) Self-supervised pretraining recipe

import torch
from LWMTemporal.tasks.pretraining import (
    DataArgs, MaskArgs, CurriculumArgs, AugmentationArgs,
    OptimizationArgs, ModelArgs as PretrainModelArgs,
    LoggingArgs, PretrainingArgs, PretrainingTrainer,
)

pretrain_args = PretrainingArgs(
    data=DataArgs(data_dir=DATA_DIR, keep_percentage=0.25),
    mask=MaskArgs(mask_ratio=0.6, mask_mode="auto"),
    curriculum=CurriculumArgs(strategy="mask", warmup_epochs=2, min_mask_ratio=0.4, max_mask_ratio=0.6),
    augment=AugmentationArgs(phase_p=0.0, amp_p=0.0, awgn_p=0.0),
    optim=OptimizationArgs(
        device="cuda" if torch.cuda.is_available() else "cpu",
        epochs=3,
        batch_size=8,
        lr=1e-4,
        save_dir=ROOT / "checkpoints",
        save_prefix="tutorial_pretrain",
        use_dataparallel=False,
    ),
    model=PretrainModelArgs(embed_dim=32, depth=12, num_heads=8, max_seq_len=5120),
    logging=LoggingArgs(log_dir=ROOT / "logs/tutorial"),
)
trainer = PretrainingTrainer(pretrain_args)
trainer.train()

Citation

@misc{alikhani2026lwmtemporalsparsespatiotemporalattention,
      title={LWM-Temporal: Sparse Spatio-Temporal Attention for Wireless Channel Representation Learning}, 
      author={Sadjad Alikhani and Akshay Malhotra and Shahab Hamidi-Rad and Ahmed Alkhateeb},
      year={2026},
      eprint={2603.10024},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2603.10024}, 
}

Contact

  • LWM Team, Wireless Intelligence Lab, Arizona State University
    lwmwireless@gmail.com
  • Issues/PRs: open on this repository.