# LWM-Temporal: Sparse Spatio-Temporal Attention for Wireless Channel Representation Learning

## Model Description
LWM-Temporal is a physics-informed foundation model for wireless channels that learns universal spatio-temporal representations from sequences of channel state information (CSI). It extends the Large Wireless Model (LWM) family by explicitly modeling temporal evolution, user mobility, and Doppler–delay dynamics while remaining backward-compatible with single-snapshot tasks. The model operates in the angle–delay–time (AD–t) domain and builds sparsity and motion priors directly into its Sparse Spatio-Temporal Attention (SSTA) mechanism and masking objectives. It is intended as a general-purpose feature extractor for diverse wireless downstream tasks.
## Key Features
- Universal feature extraction: Handles single snapshots (LoS/NLoS classification, clustering, compression) and channel sequences (beam tracking, prediction, scatterer dynamics) with the same backbone.
- Sparse Spatio-Temporal Attention (SSTA): Restricts attention to physically plausible neighborhoods in AD–t, yielding near-linear complexity, better interpretability, and improved generalization under mobility.
- Physics-informed design: Angle–delay representation exposes multipath geometry; temporal corridors follow Doppler-driven motion; energy-aware routing prioritizes dominant paths.
- Self-supervised pretraining: Masked channel modeling on complex-valued channels; reconstruction is applied only to masked patches; normalized losses emphasize high-energy components.
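The SSTA neighborhood rule can be illustrated with a toy boolean attention mask. The helper below is a hypothetical sketch, not the library's implementation: it assumes a flattened `(T, H, W)` token grid, same-frame attention within a Chebyshev window, and attention to the co-located token in frames at a set of temporal offsets, mirroring the `same_frame_window` and `temporal_offsets` hyperparameters used later in the tutorial.

```python
import numpy as np

def sparse_attention_mask(T, H, W, same_frame_window=1, temporal_offsets=(-1, 1)):
    """Boolean (N, N) mask: entry [i, j] is True if token i may attend to token j.

    Illustrative only: same-frame attention within a Chebyshev window, plus
    attention to the co-located patch in frames at the given temporal offsets.
    """
    coords = [(t, h, w) for t in range(T) for h in range(H) for w in range(W)]
    N = len(coords)
    mask = np.zeros((N, N), dtype=bool)
    for i, (t, h, w) in enumerate(coords):
        for j, (t2, h2, w2) in enumerate(coords):
            # spatial neighborhood in the same frame
            same_frame = t2 == t and max(abs(h2 - h), abs(w2 - w)) <= same_frame_window
            # co-located token in a nearby frame (temporal corridor)
            temporal = (t2 - t) in temporal_offsets and (h2, w2) == (h, w)
            if same_frame or temporal:
                mask[i, j] = True
    return mask

mask = sparse_attention_mask(T=3, H=4, W=4)
print("tokens:", mask.shape[0], "max keys per query:", int(mask.sum(axis=1).max()))
```

Because each query attends to a constant-size neighborhood rather than all `T*H*W` tokens, the attention cost grows near-linearly with sequence length.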
## Architecture Overview
- Input: Complex channel sequences shaped `(T, N_antennas, N_subcarriers)` (or `(T, H, W)` grids), transformed to the angle–delay domain; real and imaginary (or magnitude/phase) components are processed jointly.
- Tokenization: Patch-based tokenization over AD grids (tubelets across time). Optional CLS tokens.
- Backbone: Transformer encoder with sparse spatio-temporal attention, learned or RoPE positional encoding, and optional dynamic routing (`routing_topk_*`).
- Output: Dense token embeddings plus an optional CLS embedding; a reconstruction head is used during pretraining and prediction tasks.
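For a uniform linear array with uniformly spaced subcarriers, the angle–delay representation mentioned above is conventionally obtained via a 2D DFT of the antenna–subcarrier channel matrix. The sketch below assumes that standard transform; the package's `AngleDelayProcessor` may use different normalization or trimming.

```python
import numpy as np

def to_angle_delay(H_freq):
    """Map a (N_ant, N_sub) frequency-domain channel to the angle-delay domain.

    DFT across the antenna axis resolves angle (beamspace); inverse DFT across
    the subcarrier axis resolves delay taps. Conventional unitary-style scaling.
    """
    beamspace = np.fft.fft(H_freq, axis=0) / np.sqrt(H_freq.shape[0])   # angle axis
    return np.fft.ifft(beamspace, axis=1) * np.sqrt(H_freq.shape[1])    # delay axis

# Synthetic single-path channel: energy concentrates in one angle-delay bin,
# which is the sparsity SSTA exploits.
N_ant, N_sub = 32, 32
a = np.exp(1j * 2 * np.pi * np.arange(N_ant) * 5 / N_ant)    # steering vector, angle bin 5
d = np.exp(-1j * 2 * np.pi * np.arange(N_sub) * 7 / N_sub)   # phase ramp, delay bin 7
AD = np.abs(to_angle_delay(np.outer(a, d)))
peak = np.unravel_index(AD.argmax(), AD.shape)
print("dominant angle-delay bin:", peak)
```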
## Pretraining Method
- Objective: Masked channel modeling on complex channels in AD–t.
- Masking: Frame-local rectangles, spatio-temporal tubes (following motion), pilot-lattice combs, or uniform random masking.
- Loss: Normalized MSE on masked tokens only to avoid dominance of low-energy regions.
- Optimization: AdamW with LR `5e-4`, warmup + cosine decay, mixed precision; multi-GPU training supported.
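The normalized masked loss can be sketched as follows. This numpy version is illustrative only; it assumes the same convention as the `masked_nmse_loss` used later in the tutorial (error computed over masked tokens only, normalized by the energy of their targets), which is what keeps low-energy regions from dominating.

```python
import numpy as np

def masked_nmse(recon, target, mask, eps=1e-8):
    """NMSE restricted to masked tokens: ||recon - target||^2 / (||target||^2 + eps).

    recon, target: (N, D) real-valued token arrays (e.g. stacked Re/Im parts).
    mask: (N,) boolean, True where the token was masked and must be predicted.
    """
    diff = (recon - target)[mask]
    ref = target[mask]
    return float((diff ** 2).sum() / ((ref ** 2).sum() + eps))

rng = np.random.default_rng(0)
target = rng.normal(size=(16, 4))   # 16 tokens x 4 features
mask = rng.random(16) < 0.6         # True = token was masked out
print(masked_nmse(target, target, mask))                 # perfect reconstruction -> 0.0
print(masked_nmse(np.zeros_like(target), target, mask))  # all-zero guess -> ~1.0
```

Normalizing by the masked targets' own energy makes an all-zero prediction score exactly ~1.0 regardless of channel scale, giving a scale-free baseline across samples.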
## Training Data
Pretraining uses a dynamic wireless digital twin pipeline (ray-traced scenarios) with multiple cities/environments, realistic trajectories, Doppler/delay/angle dynamics, and diverse LoS/NLoS scattering. Data covers varied velocities and geometries to promote robust generalization.
## Intended Use
- Suitable: Channel prediction, beam tracking/selection, CSI compression and feedback, LoS/NLoS classification, clustering/zoning, digital twin calibration, mobility-aware sensing.
- Not intended: End-to-end decoding/demodulation, symbol-level recovery, or non-wireless domains without adaptation.
## Checkpoints and Files
- Default checkpoint: `checkpoints/{config.json,pytorch_model.bin}` (HF-compatible directory). Load with `LWMBackbone.from_pretrained("checkpoints")` or `LWMHFModel.from_pretrained("checkpoints")`.
- Package: Python package under `LWMTemporal/`, with CLI entry points in `LWMTemporal/cli/` and tasks in `LWMTemporal/tasks/`.
## How to Use (from the tutorial workflow)
These snippets mirror the tutorial workflow and assume the repo root is on `PYTHONPATH`.
### 1) Generate a dynamic scenario (DeepMIMO-based)

```python
from pathlib import Path

from LWMTemporal.data.scenario_generation import (
    AntennaArrayConfig, DynamicScenarioGenerator, GridConfig,
    ScenarioGenerationConfig, ScenarioSamplingConfig, TrafficConfig,
)

ROOT = Path.cwd()
DATA_DIR = ROOT / "examples/data"
FULL_DATA_DIR = ROOT / "examples/full_data"
FIG_DIR = ROOT / "examples/figs"
SCENARIO_DIR = ROOT / "deepmimo_scenarios"

cfg = ScenarioGenerationConfig(
    scenario="city_1_losangeles_3p5",
    antenna=AntennaArrayConfig(tx_horizontal=32, tx_vertical=1, subcarriers=32),
    sampling=ScenarioSamplingConfig(time_steps=11, sample_dt=1e-3),
    traffic=TrafficConfig(
        num_vehicles=180,
        num_pedestrians=20,
        turn_probability=0.1,
        vehicle_speed_range=(0 / 3.6, 108 / 3.6),  # m/s (0-108 km/h)
    ),
    grid=GridConfig(road_width=2.0, road_center_spacing=8.0),
    output_dir=DATA_DIR,
    full_output_dir=FULL_DATA_DIR,
    figures_dir=FIG_DIR / "environment",
    scenarios_dir=SCENARIO_DIR,
    deepmimo_max_paths=25,
)

scenario = DynamicScenarioGenerator(cfg).generate(overwrite=True)
print("Scenario cached to:", scenario.output_path)
```
### 2) Inspect cached payload

```python
import pickle

DATA_PATH = scenario.output_path
with DATA_PATH.open("rb") as handle:
    payload = pickle.load(handle)

print("channel_discrete shape:", payload["channel_discrete"].shape)
```
### 3) Visualize angle–delay–time evolution

```python
import torch
from IPython.display import Image, display

from LWMTemporal.data.angle_delay import AngleDelayProcessor, AngleDelayConfig
from examples.ad_temporal_evolution import pick_bins, plot_curves

processor = AngleDelayProcessor(AngleDelayConfig(keep_percentage=0.25))
ANGLE_DELAY_DIR = FIG_DIR / "angle_delay"
ANGLE_DELAY_DIR.mkdir(parents=True, exist_ok=True)

ue_idx = 10
channel = torch.tensor(payload["channel_discrete"][ue_idx])
volume = processor.forward(channel)
volume_trimmed, _ = processor.truncate_delay_bins(volume)

gif_path = ANGLE_DELAY_DIR / "discrete_angle_delay.gif"
processor.save_angle_delay_gif(volume_trimmed, gif_path, fps=6, show=True)

picks = pick_bins(volume_trimmed, k=3, coords=None)
curves_path = ANGLE_DELAY_DIR / "discrete_curves.png"
plot_curves(volume_trimmed, picks, curves_path, title="Discrete bins")
display(Image(filename=curves_path))
```
### 4) Masked channel modeling sanity check

```python
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from LWMTemporal.data.datasets import AngleDelayDatasetConfig, AngleDelaySequenceDataset
from LWMTemporal.models.lwm import LWMBackbone, LWMConfig, masked_nmse_loss
from LWMTemporal.tasks.pretraining import MaskArgs, MaskGenerator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_cfg = AngleDelayDatasetConfig(
    raw_path=scenario.output_path,
    keep_percentage=0.25,
    normalize="per_sample_rms",
    cache_dir=Path("cache"),
    patch_size=(1, 1),
    phase_mode="real_imag",
)
dataset = AngleDelaySequenceDataset(data_cfg)
loader = DataLoader(dataset, batch_size=16, shuffle=False)

example_shape = dataset[0]["shape"]
tokens_per_sample = int(example_shape[0] * example_shape[1] * example_shape[2])

cfg = LWMConfig(
    patch_size=(1, 1),
    phase_mode="real_imag",
    embed_dim=32,
    depth=12,
    num_heads=8,
    mlp_ratio=4.0,
    same_frame_window=2,
    temporal_offsets=(-4, -3, -2, -1, 1, 2, 3),
    temporal_spatial_window=2,
    temporal_drift_h=1,
    temporal_drift_w=1,
    routing_topk_enable=True,
    topk_per_head=True,
    max_seq_len=tokens_per_sample,
)
backbone = LWMBackbone.from_pretrained("checkpoints/pytorch_model.bin", config=cfg).to(device).eval()
masker = MaskGenerator(MaskArgs(mask_ratio=0.6, mask_mode="random"))

with torch.no_grad():
    for batch in loader:
        tokens = batch["tokens"].to(device)
        base_mask = batch["base_mask"].to(device)
        T, H, W = batch["shape"][0]
        B, _, _ = tokens.shape
        mask = torch.stack([masker(int(T), int(H), int(W), device).view(-1) for _ in range(B)])
        mask = torch.logical_or(mask, base_mask)
        corrupted = tokens.masked_fill(mask.unsqueeze(-1), 0.0)
        recon = backbone.forward_tokens(corrupted, mask, int(T), int(H), int(W))["reconstruction"]
        nmse = masked_nmse_loss(recon, tokens, mask)
        print("NMSE:", nmse.item())
        break
```
### 5) Channel prediction fine-tuning / evaluation

```python
from pathlib import Path

import torch

from LWMTemporal.tasks.channel_prediction import (
    ChannelPredictionArgs, ChannelPredictionTrainer,
    DatasetArgs, ModelArgs, TrainingArgs, PredictionArgs,
)

channel_args = ChannelPredictionArgs(
    dataset=DatasetArgs(
        data_path=scenario.output_path,
        keep_percentage=0.25,
        normalize="per_sample_rms",
        seed=0,
        train_limit=50,
        val_limit=50,
    ),
    model=ModelArgs(
        patch_size=(1, 1),
        phase_mode="real_imag",
        embed_dim=32,
        depth=12,
        num_heads=8,
        mlp_ratio=4.0,
        same_frame_window=2,
        temporal_offsets=(-1, -2, -3, -4, -5, -6, -7),
        temporal_spatial_window=2,
        temporal_drift_h=1,
        temporal_drift_w=1,
        routing_topk_enable=True,
        routing_topk_fraction=0.2,
        routing_topk_max=32,
        pretrained=Path("checkpoints/pytorch_model.bin"),
    ),
    training=TrainingArgs(
        device="cuda" if torch.cuda.is_available() else "cpu",
        epochs=10,
        batch_size=2,
        lr=1e-4,
        weight_decay=1e-4,
        warmup_ratio=0.1,
        save_dir=ROOT / "models",
        inference_only=True,
        inference_split="val",
        use_wandb=False,
        use_dataparallel=True,
    ),
    prediction=PredictionArgs(Tpast=10, horizon=1, viz_dir=FIG_DIR / "predictions"),
)

trainer = ChannelPredictionTrainer(channel_args)
trainer.train()
```
### 6) Self-supervised pretraining recipe

```python
import torch

from LWMTemporal.tasks.pretraining import (
    DataArgs, MaskArgs, CurriculumArgs, AugmentationArgs,
    OptimizationArgs, ModelArgs as PretrainModelArgs,
    LoggingArgs, PretrainingArgs, PretrainingTrainer,
)

pretrain_args = PretrainingArgs(
    data=DataArgs(data_dir=DATA_DIR, keep_percentage=0.25),
    mask=MaskArgs(mask_ratio=0.6, mask_mode="auto"),
    curriculum=CurriculumArgs(strategy="mask", warmup_epochs=2, min_mask_ratio=0.4, max_mask_ratio=0.6),
    augment=AugmentationArgs(phase_p=0.0, amp_p=0.0, awgn_p=0.0),
    optim=OptimizationArgs(
        device="cuda" if torch.cuda.is_available() else "cpu",
        epochs=3,
        batch_size=8,
        lr=1e-4,
        save_dir=ROOT / "checkpoints",
        save_prefix="tutorial_pretrain",
        use_dataparallel=False,
    ),
    model=PretrainModelArgs(embed_dim=32, depth=12, num_heads=8, max_seq_len=5120),
    logging=LoggingArgs(log_dir=ROOT / "logs/tutorial"),
)

trainer = PretrainingTrainer(pretrain_args)
trainer.train()
```
## Citation

```bibtex
@misc{alikhani2026lwmtemporalsparsespatiotemporalattention,
      title={LWM-Temporal: Sparse Spatio-Temporal Attention for Wireless Channel Representation Learning},
      author={Sadjad Alikhani and Akshay Malhotra and Shahab Hamidi-Rad and Ahmed Alkhateeb},
      year={2026},
      eprint={2603.10024},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2603.10024},
}
```
## Contact

- LWM Team, Wireless Intelligence Lab, Arizona State University: lwmwireless@gmail.com
- Issues/PRs: open on this repository.