villekuosmanen commited on
Commit
541bd02
·
verified ·
1 Parent(s): 44d1456

Upload SAE model weights, config, and training state

Browse files
Files changed (4) hide show
  1. README.md +86 -0
  2. config.json +36 -0
  3. model.safetensors +3 -0
  4. training_state.pt +3 -0
README.md ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - physical-ai-interpretability-sae
5
+ - LeRobot
6
+ - Robotics
7
+ datasets:
8
+ - villekuosmanen/drop_footbag_into_dice_tower
9
+ - villekuosmanen/drop_footbag_into_dice_tower_continuous
10
+ - villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.0.0
11
+ - villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.1.0
12
+ - villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.2.0
13
+ - villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.3.0
14
+ - villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.4.0
15
+ - villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.5.0
16
+ - villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.6.0
17
+ - villekuosmanen/eval_footbag_11Sep
18
+ library_name: physical-ai-interpretability
19
+ ---
20
+
21
+ # Sparse Autoencoder (SAE) Model
22
+
23
+ This model is a Sparse Autoencoder trained for interpretability analysis of robotics policies using the LeRobot framework.
24
+
25
+ ## Model Details
26
+
27
+ - **Architecture**: Multi-modal Sparse Autoencoder
28
+ - **Training Dataset**: `[villekuosmanen/drop_footbag_into_dice_tower, villekuosmanen/drop_footbag_into_dice_tower_continuous, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.0.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.1.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.2.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.3.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.4.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.5.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.6.0, villekuosmanen/eval_footbag_11Sep]`
29
+ - **Base Policy**: LeRobot ACT policy
30
+ - **Layer Target**: `model.encoder.layers.3.norm2`
31
+ - **Tokens**: 77
32
+ - **Token Dimension**: 128
33
+ - **Feature Dimension**: 12320
34
+ - **Expansion Factor**: 1.25
35
+
36
+ ## Training Configuration
37
+
38
+ - **Learning Rate**: 0.0001
39
+ - **Batch Size**: 16
40
+ - **L1 Penalty**: 0.3
41
+ - **Epochs**: 20
42
+ - **Optimizer**: adam
43
+
44
+ ## Usage
45
+
46
+ ```python
47
+ from src.sae.trainer import load_sae_from_hub
48
+
49
+ # Load model from Hub
50
+ model = load_sae_from_hub("villekuosmanen/drop_footbag_into_dice_tower_ood_sae_success")
51
+
52
+ # Or load using builder
53
+ from src.sae.builder import SAEBuilder
54
+ builder = SAEBuilder(device='cuda')
55
+ model = builder.load_from_hub("villekuosmanen/drop_footbag_into_dice_tower_ood_sae_success")
56
+ ```
57
+
58
+ ## Out-of-Distribution Detection
59
+
60
+ This SAE model can be used for OOD detection with LeRobot policies:
61
+
62
+ ```python
63
+ from src.ood import OODDetector
64
+
65
+ # Create OOD detector with Hub-loaded SAE
66
+ ood_detector = OODDetector(
67
+ policy=your_policy,
68
+ sae_hub_repo_id="villekuosmanen/drop_footbag_into_dice_tower_ood_sae_success"
69
+ )
70
+
71
+ # Fit threshold and use for detection
72
+ ood_detector.fit_ood_threshold_to_validation_dataset(validation_dataset)
73
+ is_ood, error = ood_detector.is_out_of_distribution(observation)
74
+ ```
75
+
76
+ ## Files
77
+
78
+ - `model.safetensors`: The trained SAE model weights
79
+ - `config.json`: Training and model configuration
80
+ - `training_state.pt`: Complete training state (optimizer, scheduler, metrics)
81
+ - `ood_params.json`: OOD detection parameters (if fitted)
82
+ ```
83
+
84
+ ## Framework
85
+
86
+ This model was trained using the [physical-ai-interpretability](https://github.com/your-repo/physical-ai-interpretability) framework with [LeRobot](https://github.com/huggingface/lerobot).
config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "num_tokens": 77,
3
+ "token_dim": 128,
4
+ "expansion_factor": 1.25,
5
+ "activation_fn": "relu",
6
+ "use_token_sampling": true,
7
+ "fixed_tokens": [
8
+ 0,
9
+ 1
10
+ ],
11
+ "sampling_strategy": "block_average",
12
+ "sampling_stride": 8,
13
+ "max_sampled_tokens": 200,
14
+ "block_size": 8,
15
+ "batch_size": 16,
16
+ "learning_rate": 0.0001,
17
+ "num_epochs": 20,
18
+ "validation_split": 0.1,
19
+ "l1_penalty": 0.3,
20
+ "optimizer": "adam",
21
+ "weight_decay": 1e-05,
22
+ "lr_schedule": "constant",
23
+ "warmup_epochs": 2,
24
+ "gradient_clip_norm": 1.0,
25
+ "early_stopping_patience": 10,
26
+ "early_stopping_min_delta": 1e-05,
27
+ "log_every": 5,
28
+ "save_every": 1000,
29
+ "validate_every": 500,
30
+ "device": "cuda",
31
+ "repo_id": "[villekuosmanen/drop_footbag_into_dice_tower, villekuosmanen/drop_footbag_into_dice_tower_continuous, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.0.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.1.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.2.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.3.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.4.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.5.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.6.0, villekuosmanen/eval_footbag_11Sep]",
32
+ "repo_hash": "e78b65d9",
33
+ "layer_name": "model.encoder.layers.3.norm2",
34
+ "activation_cache_path": "/home/ville/.cache/physical_ai_interpretability/sae_activations",
35
+ "experiment_name": "sae_eval_footbag_11Sep_e78b65d9"
36
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7786da09e811c9dc29a8df815d0f45e4d536f687b6642cb5256ee5a92a60174
3
+ size 971496408
training_state.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d435aafa93018ab4cb184c496b19c52d3df2af5291a9f1cfc4986ed4e515f5a
3
+ size 1942998303