Upload SAE model weights, config, and training state
- README.md +86 -0
- config.json +36 -0
- model.safetensors +3 -0
- training_state.pt +3 -0
README.md
ADDED
@@ -0,0 +1,86 @@
+---
+license: mit
+tags:
+- physical-ai-interpretability-sae
+- LeRobot
+- Robotics
+datasets:
+- villekuosmanen/drop_footbag_into_dice_tower
+- villekuosmanen/drop_footbag_into_dice_tower_continuous
+- villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.0.0
+- villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.1.0
+- villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.2.0
+- villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.3.0
+- villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.4.0
+- villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.5.0
+- villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.6.0
+- villekuosmanen/eval_footbag_11Sep
+library_name: physical-ai-interpretability
+---
+
+# Sparse Autoencoder (SAE) Model
+
+This model is a Sparse Autoencoder trained for interpretability analysis of robotics policies using the LeRobot framework.
+
+## Model Details
+
+- **Architecture**: Multi-modal Sparse Autoencoder
+- **Training Dataset**: `[villekuosmanen/drop_footbag_into_dice_tower, villekuosmanen/drop_footbag_into_dice_tower_continuous, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.0.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.1.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.2.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.3.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.4.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.5.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.6.0, villekuosmanen/eval_footbag_11Sep]`
+- **Base Policy**: LeRobot ACT policy
+- **Layer Target**: `model.encoder.layers.3.norm2`
+- **Tokens**: 77
+- **Token Dimension**: 128
+- **Feature Dimension**: 12320
+- **Expansion Factor**: 1.25
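
The feature dimension follows from the other three values in the list above: the SAE input appears to be the 77 tokens flattened into one vector of 77 × 128 = 9856 values, scaled by the expansion factor. The sketch below reproduces that arithmetic. It also estimates the weight payload under the assumption of one linear encoder and one linear decoder, each with a bias, stored as float32. That architecture is an assumption, not something the README states, but the estimate lands 344 bytes below the 971496408-byte `model.safetensors` in this commit, which would be consistent with a small safetensors header.

```python
# Values copied from the Model Details list above.
num_tokens = 77
token_dim = 128
expansion_factor = 1.25

# Assumption: the SAE sees all tokens flattened into a single input vector.
input_dim = num_tokens * token_dim
feature_dim = int(input_dim * expansion_factor)

# Assumption: one linear encoder and one linear decoder, each with a bias,
# stored as 4-byte float32 values.
encoder_params = input_dim * feature_dim + feature_dim
decoder_params = feature_dim * input_dim + input_dim
payload_bytes = 4 * (encoder_params + decoder_params)

print(input_dim, feature_dim, payload_bytes)  # 9856 12320 971496064
```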
+
+## Training Configuration
+
+- **Learning Rate**: 0.0001
+- **Batch Size**: 16
+- **L1 Penalty**: 0.3
+- **Epochs**: 20
+- **Optimizer**: adam
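
The L1 penalty weights the sparsity term of the objective. The README does not spell the objective out, so the following is only the common ReLU sparse-autoencoder formulation, written under the assumption that it applies here. The symbols $W_e, b_e, W_d, b_d$ are illustrative, and whether the terms are summed or averaged over the batch and feature axes is not known from this commit.

```latex
f = \mathrm{ReLU}(W_e x + b_e), \qquad
\hat{x} = W_d f + b_d, \qquad
\mathcal{L} = \lVert x - \hat{x} \rVert_2^2 + \lambda \lVert f \rVert_1,
\qquad \lambda = 0.3
```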
+
+## Usage
+
+```python
+from src.sae.trainer import load_sae_from_hub
+
+# Load model from Hub
+model = load_sae_from_hub("villekuosmanen/drop_footbag_into_dice_tower_ood_sae_success")
+
+# Or load using builder
+from src.sae.builder import SAEBuilder
+builder = SAEBuilder(device='cuda')
+model = builder.load_from_hub("villekuosmanen/drop_footbag_into_dice_tower_ood_sae_success")
+```
+
+## Out-of-Distribution Detection
+
+This SAE model can be used for OOD detection with LeRobot policies:
+
+```python
+from src.ood import OODDetector
+
+# Create OOD detector with Hub-loaded SAE
+ood_detector = OODDetector(
+    policy=your_policy,
+    sae_hub_repo_id="villekuosmanen/drop_footbag_into_dice_tower_ood_sae_success"
+)
+
+# Fit threshold and use for detection
+ood_detector.fit_ood_threshold_to_validation_dataset(validation_dataset)
+is_ood, error = ood_detector.is_out_of_distribution(observation)
+```
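
The internals of `OODDetector` are not part of this commit. As a rough sketch of the general idea, assume the detector scores an observation by its SAE reconstruction error and fits the threshold as a high percentile of the errors seen on in-distribution validation data. Under that assumption, the two calls above could behave roughly like the following. The percentile, the helper names, and the toy error values are all hypothetical.

```python
import math

def fit_threshold(validation_errors, percentile=99.0):
    """Return the given percentile of the errors (nearest-rank method)."""
    ordered = sorted(validation_errors)
    rank = max(1, math.ceil(percentile * len(ordered) / 100.0))
    return ordered[rank - 1]

def is_out_of_distribution(error, threshold):
    """Flag an observation whose reconstruction error exceeds the threshold."""
    return error > threshold, error

# Toy reconstruction errors standing in for a validation set (hypothetical).
validation_errors = [0.10, 0.12, 0.09, 0.11, 0.13, 0.10, 0.14, 0.12, 0.11, 0.10]
threshold = fit_threshold(validation_errors, percentile=90.0)

print(is_out_of_distribution(0.11, threshold))  # in-distribution example
print(is_out_of_distribution(0.40, threshold))  # flagged example
```

A percentile-based threshold keeps the false-positive rate on validation data roughly fixed, which is one reason it is a common choice for this kind of detector.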
+
+## Files
+
+- `model.safetensors`: The trained SAE model weights
+- `config.json`: Training and model configuration
+- `training_state.pt`: Complete training state (optimizer, scheduler, metrics)
+- `ood_params.json`: OOD detection parameters (if fitted)
+
+
+## Framework
+
+This model was trained using the [physical-ai-interpretability](https://github.com/your-repo/physical-ai-interpretability) framework with [LeRobot](https://github.com/huggingface/lerobot).
config.json
ADDED
@@ -0,0 +1,36 @@
+{
+  "num_tokens": 77,
+  "token_dim": 128,
+  "expansion_factor": 1.25,
+  "activation_fn": "relu",
+  "use_token_sampling": true,
+  "fixed_tokens": [
+    0,
+    1
+  ],
+  "sampling_strategy": "block_average",
+  "sampling_stride": 8,
+  "max_sampled_tokens": 200,
+  "block_size": 8,
+  "batch_size": 16,
+  "learning_rate": 0.0001,
+  "num_epochs": 20,
+  "validation_split": 0.1,
+  "l1_penalty": 0.3,
+  "optimizer": "adam",
+  "weight_decay": 1e-05,
+  "lr_schedule": "constant",
+  "warmup_epochs": 2,
+  "gradient_clip_norm": 1.0,
+  "early_stopping_patience": 10,
+  "early_stopping_min_delta": 1e-05,
+  "log_every": 5,
+  "save_every": 1000,
+  "validate_every": 500,
+  "device": "cuda",
+  "repo_id": "[villekuosmanen/drop_footbag_into_dice_tower, villekuosmanen/drop_footbag_into_dice_tower_continuous, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.0.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.1.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.2.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.3.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.4.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.5.0, villekuosmanen/dAgger_drop_footbag_into_dice_tower_1.6.0, villekuosmanen/eval_footbag_11Sep]",
+  "repo_hash": "e78b65d9",
+  "layer_name": "model.encoder.layers.3.norm2",
+  "activation_cache_path": "/home/ville/.cache/physical_ai_interpretability/sae_activations",
+  "experiment_name": "sae_eval_footbag_11Sep_e78b65d9"
+}
model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7786da09e811c9dc29a8df815d0f45e4d536f687b6642cb5256ee5a92a60174
+size 971496408
training_state.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d435aafa93018ab4cb184c496b19c52d3df2af5291a9f1cfc4986ed4e515f5a
+size 1942998303