# HamJEPA / HamSIGReg (PyTorch checkpoints)

This repository provides PyTorch checkpoints for the models studied in *Beyond Isotropy in JEPAs: Hamiltonian Geometry and Symplectic Prediction*.

It includes two matched-compute training objectives:

- HamJEPA (MV-HJEPA): a JEPA-style self-supervised learner that imposes phase-space structure by representing each sample as a concatenated latent state (q, p) and training a symplectic (leapfrog) rollout predictor in latent space.
- HamSIGReg (SIGReg+tokens baseline): a LeJEPA-style two-view predictor baseline with a SIGReg regularizer, using the same tokenized backbone interface.

Both runs use the same backbone family, data pipeline (two global crops), optimizer family, and training schedule, so that comparisons primarily reflect the geometric inductive bias and objective rather than extra compute.


## What’s in this repo

### Checkpoints

Checkpoints live in checkpoints/.

Released naming convention:

- `checkpoints/imagenet_hjepa_mv_epoch_XXX.pth`
- `checkpoints/imagenet_sigreg_tokens_epoch_XXX.pth`

where XXX is the epoch index (e.g., 005, 010, ..., 045 for the 45-epoch ImageNet-100 runs).


### Checkpoint format

Each .pth file is a torch.save(...) dictionary with (at least) the following keys:

- `epoch`: int
- `config`: training configuration dict
- `encoder`: state_dict for the encoder
- `projector`: state_dict for the projector (identity in released configs)
- `regularizer`: state_dict for the regularizer module
- `predictor`: state_dict for the predictor (present for HamJEPA, absent for the SIGReg baseline)

Security note: `torch.load` uses Python pickle under the hood, so loading a checkpoint can execute arbitrary code. Only load checkpoints you trust; where the checkpoint contents allow it, prefer `torch.load(..., weights_only=True)`.


## Architecture (high level)

Backbone: ResNet-18 in token mode.

Tokenization/readout (ImageNet-100 configs):

- ResNet `layer3` feature map
- 1×1 projection to `token_d_f=32`
- Adaptive pooling to 8×8 (`token_hw=8`)
- Flattened tokens: 8×8×32 = 2048 dims per image
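The readout above can be sketched as a small PyTorch module. This is a hypothetical illustration (`TokenReadout` and its argument names are not the repo's actual identifiers); it only assumes that ResNet-18's `layer3` outputs 256 channels, which is true for the standard torchvision architecture.

```python
import torch
import torch.nn as nn

# Illustrative sketch of the token readout: 1x1 projection to 32 channels,
# adaptive pooling to 8x8, then flattening to a 2048-d vector per image.
class TokenReadout(nn.Module):
    def __init__(self, in_ch=256, token_d_f=32, token_hw=8):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, token_d_f, kernel_size=1)  # 1x1 projection
        self.pool = nn.AdaptiveAvgPool2d(token_hw)              # pool to 8x8

    def forward(self, feat):                   # feat: (B, 256, H, W) from layer3
        tokens = self.pool(self.proj(feat))    # (B, 32, 8, 8)
        return tokens.flatten(1)               # (B, 8*8*32) = (B, 2048)

readout = TokenReadout()
z = readout(torch.randn(4, 256, 14, 14))       # layer3 is 14x14 at 224 input
print(z.shape)  # torch.Size([4, 2048])
```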

Phase-space interface (when split_qp=true):

- q ∈ R^{1024}
- p ∈ R^{1024}
- z = [q; p] ∈ R^{2048}

Projector: identity (no extra MLP head in released runs).

In SIGReg, q/p are merely a partition of the feature vector (no Hamiltonian semantics are enforced).
In HamJEPA, the objective and rollout explicitly treat (q, p) as a phase-space state.
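The phase-space interface amounts to splitting the 2048-d token vector into two 1024-d halves; a minimal sketch (the helper name `split_qp` mirrors the `split_qp` config flag but is otherwise illustrative):

```python
import torch

# Split z = [q; p] into its two 1024-d halves and verify the round trip.
def split_qp(z):
    q, p = z.chunk(2, dim=-1)   # q, p ∈ R^1024 each
    return q, p

z = torch.randn(4, 2048)
q, p = split_qp(z)
print(q.shape, p.shape)  # torch.Size([4, 1024]) torch.Size([4, 1024])
assert torch.equal(torch.cat([q, p], dim=-1), z)  # z = [q; p]
```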


## Training setup (ImageNet-100)

- Dataset: ImageNet-100 (100-class subset)
- Views: 2 global crops, no local crops
- Crop resolution: 224
- Optimizer: AdamW
- Precision: bf16 mixed precision
- Schedule: warmup + cosine
- Epochs: 45

### HamJEPA specifics

- Symplectic leapfrog latent rollout predictor
- Short-horizon multi-step latent prediction
- Matching loss applied to the configured latent state (q or (q, p), per config)
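A leapfrog rollout in latent space can be sketched as follows, under the assumption of a separable Hamiltonian H(q, p) = |p|²/2 + V(q) with a learned force network f(q) ≈ -∇V(q). All names (`LatentLeapfrog`, `force`) and the specific force architecture are illustrative, not the repo's implementation:

```python
import torch
import torch.nn as nn

# Sketch of a symplectic leapfrog step and a short multi-step rollout in latent
# space. Each step is half-kick / drift / half-kick, which preserves the
# symplectic structure for any force field.
class LatentLeapfrog(nn.Module):
    def __init__(self, dim=1024, dt=0.1):
        super().__init__()
        self.force = nn.Sequential(nn.Linear(dim, dim), nn.Tanh(), nn.Linear(dim, dim))
        self.dt = dt

    def step(self, q, p):
        p_half = p + 0.5 * self.dt * self.force(q)            # half kick
        q_next = q + self.dt * p_half                         # drift
        p_next = p_half + 0.5 * self.dt * self.force(q_next)  # half kick
        return q_next, p_next

    def rollout(self, q, p, n_steps=3):   # short-horizon multi-step prediction
        states = []
        for _ in range(n_steps):
            q, p = self.step(q, p)
            states.append((q, p))
        return states

pred = LatentLeapfrog()
states = pred.rollout(torch.randn(4, 1024), torch.randn(4, 1024), n_steps=3)
print(len(states), states[-1][0].shape)  # 3 torch.Size([4, 1024])
```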

### SIGReg+tokens specifics

- Two-view JEPA-style prediction + SIGReg regularization
- SIGReg uses a sliced characteristic-function discrepancy (Epps–Pulley style)
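The sliced discrepancy can be sketched as follows, assuming each random 1-D projection of the embeddings is scored against a standard Gaussian with the closed-form, N(0,1)-weighted Epps–Pulley characteristic-function statistic. This is an illustrative reconstruction, not the repo's exact regularizer; function names are hypothetical.

```python
import torch

torch.manual_seed(0)

# Closed-form Epps–Pulley statistic for a 1-D sample x, testing agreement of
# the empirical characteristic function with exp(-t^2/2), weighted by a
# standard normal density in t. Nonnegative by construction.
def epps_pulley_1d(x):
    n = x.shape[0]
    pair_sq = (x[:, None] - x[None, :]) ** 2
    term1 = torch.exp(-0.5 * pair_sq).mean()   # |empirical CF|^2 contribution
    term2 = torch.exp(-0.25 * x ** 2).mean()   # cross term with exp(-t^2/2)
    return n * (term1 - (2 ** 0.5) * term2 + 3 ** -0.5)

def sliced_discrepancy(z, n_slices=16):
    dirs = torch.randn(z.shape[1], n_slices)
    dirs = dirs / dirs.norm(dim=0, keepdim=True)   # random unit directions
    proj = z @ dirs                                # (n, n_slices) 1-D slices
    return torch.stack([epps_pulley_1d(proj[:, s]) for s in range(n_slices)]).mean()

loss = sliced_discrepancy(torch.randn(256, 128))
print(loss.item())   # small for (approximately) isotropic Gaussian embeddings
```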

## Evaluation protocol

Frozen-encoder features are evaluated with:

- Linear probe top-1
- kNN top-1@k=20 (cosine kNN)

Unless noted otherwise, diagnostics use raw features (no post-hoc whitening).
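The cosine kNN probe can be sketched as follows (illustrative; the repo's evaluation scripts may differ in details such as distance-weighted voting):

```python
import torch
import torch.nn.functional as F

# Cosine-similarity kNN top-1 at k=20 on frozen features: normalize, take the
# k most similar training points, and majority-vote their labels.
def knn_top1(train_feats, train_labels, test_feats, test_labels, k=20):
    tr = F.normalize(train_feats, dim=1)
    te = F.normalize(test_feats, dim=1)
    sims = te @ tr.T                       # cosine similarities (n_test, n_train)
    idx = sims.topk(k, dim=1).indices      # indices of the k nearest neighbors
    votes = train_labels[idx]              # neighbor labels, (n_test, k)
    preds = votes.mode(dim=1).values       # majority vote per test point
    return (preds == test_labels).float().mean().item()

# toy usage with random features
tr, trl = torch.randn(200, 64), torch.randint(0, 5, (200,))
te, tel = torch.randn(50, 64), torch.randint(0, 5, (50,))
print(knn_top1(tr, trl, te, tel, k=20))
```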


## Results

### ImageNet-100 (45 epochs, strict best-over-q/p/(q,p) comparison)

Best SIGReg readout and best HamJEPA readout are chosen independently for each metric.

- Linear probe top-1: 25.54 → 32.08 (+6.54 points)
- kNN top-1@20: 20.94 → 24.92 (+3.98 points)

| Method | Best variant (LP) | LP top-1 (%) | Best variant (kNN@20) | kNN@20 top-1 (%) |
|---|---|---|---|---|
| SIGReg+tokens | qp | 25.54 | qp | 20.94 |
| HamJEPA (MV-HJEPA) | p | 32.08 | q | 24.92 |

(For reference, best-over-k kNN top-1 is 21.70 → 25.96, +4.26 points.)

### CIFAR-100 (30 epochs, 3 seeds: 42/45/49)

Reported values are seed-mean frozen-feature results.

- Linear probe top-1: 30.43 → 34.18 (+3.75 points)
- kNN top-1@20: 26.56 → 31.45 (+4.89 points)

| Method | Best-eval variant | Linear probe top-1 (%) | kNN top-1@20 (%) |
|---|---|---|---|
| SIGReg baseline | readout | 30.43 | 26.56 |
| HamJEPA (MV-HJEPA) | qp (probe) / q (kNN@20) | 34.18 | 31.45 |

### CIFAR-100 (80 epochs)

Longer-horizon single-run setting (same minimal/headless regime):

- Linear probe top-1: 33.95 → 44.59 (+10.64 points)
- kNN top-1@20: 27.98 → 34.43 (+6.45 points)

## How to load a checkpoint

```python
import torch

ckpt_path = "checkpoints/imagenet_hjepa_mv_epoch_045.pth"
ckpt = torch.load(ckpt_path, map_location="cpu")

print(ckpt.keys())    # dict_keys(['epoch', 'config', 'encoder', 'projector', 'regularizer', 'predictor', ...])
print(ckpt["epoch"])  # e.g. 45
encoder_state = ckpt["encoder"]
```
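Since the `predictor` state_dict is stored only for HamJEPA checkpoints, its presence also tells the two objectives apart. A small helper (the function name is illustrative):

```python
# Distinguish checkpoint types by whether a predictor state_dict was saved.
def checkpoint_kind(ckpt):
    return "HamJEPA (MV-HJEPA)" if "predictor" in ckpt else "SIGReg+tokens"

# e.g., after ckpt = torch.load(..., map_location="cpu"):
# print(checkpoint_kind(ckpt), "| epoch", ckpt["epoch"])
```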

---

## `eval_runs/` contents (artifact index)

This repo also includes full evaluation artifacts under `eval_runs/` for reproducibility.

### CIFAR-100

- `eval_runs/cifar100_hjepa_mv_80_epoch`
- `eval_runs/cifar100_sigreg_tokens_80_epoch`
- `eval_runs/cifar100_seed_42_30_epoch`
- `eval_runs/cifar100_seed_45_30_epoch`
- `eval_runs/cifar100_seed_49_30_epoch`

Notes:
- `*_80_epoch` folders are the long-run CIFAR-100 evaluations.
- `cifar100_seed_{42,45,49}_30_epoch` are the 30-epoch seed runs used for mean±std reporting.

### ImageNet-100 (main + checkpoint sweeps)

Main 45-epoch evals:
- `eval_runs/imagenet_hjepa_mv`
- `eval_runs/imagenet_sigreg_tokens`

Intermediate checkpoint evals:
- `eval_runs/imagenet_hjepa_mv_epoch_010`
- `eval_runs/imagenet_hjepa_mv_epoch_020`
- `eval_runs/imagenet_hjepa_mv_epoch_030`
- `eval_runs/imagenet_hjepa_mv_epoch_040`
- `eval_runs/imagenet_sigreg_tokens_epoch_010`
- `eval_runs/imagenet_sigreg_tokens_epoch_020`
- `eval_runs/imagenet_sigreg_tokens_epoch_030`
- `eval_runs/imagenet_sigreg_tokens_epoch_040`

### ImageNet-100 linear-probe LR sweeps

HamJEPA LR sweep:
- `eval_runs/imagenet_hjepa_mv_lp_lr0p01`
- `eval_runs/imagenet_hjepa_mv_lp_lr0p03`
- `eval_runs/imagenet_hjepa_mv_lp_lr0p1`
- `eval_runs/imagenet_hjepa_mv_lp_lr0p2`
- `eval_runs/imagenet_hjepa_mv_lp_lr0p3`

SIGReg LR sweep:
- `eval_runs/imagenet_sigreg_tokens_lp_lr0p01`
- `eval_runs/imagenet_sigreg_tokens_lp_lr0p03`
- `eval_runs/imagenet_sigreg_tokens_lp_lr0p1`
- `eval_runs/imagenet_sigreg_tokens_lp_lr0p2`
- `eval_runs/imagenet_sigreg_tokens_lp_lr0p3`

### Per-folder contents (typical)

Each `eval_runs/...` folder contains:
- `metrics.json` (primary machine-readable results)
- plots in `plots/` (summary, kNN sweep, covariance/cosine/norm diagnostics)
- auxiliary files produced by the evaluation scripts

This structure is intended so that every reported number can be traced directly to a corresponding `metrics.json`.
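One way to sweep all of these artifacts at once is to gather every run's `metrics.json` into a single dict keyed by folder name. A hedged sketch (the keys inside each `metrics.json` are not specified above, so each file is returned as-is):

```python
import json
from pathlib import Path

# Collect {run_name: parsed metrics.json} for every eval folder under eval_root.
def load_metrics(eval_root="eval_runs"):
    return {p.parent.name: json.loads(p.read_text())
            for p in Path(eval_root).glob("*/metrics.json")}

# e.g. load_metrics()["imagenet_hjepa_mv"] -> dict of that run's reported metrics
```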