This repository contains the following SAEs:
- qwen2_0.5b_sae_post_residual_layer_12
- qwen2_0.5b_sae_post_residual_layer_16

Load these SAEs using SAELens as below:
```python
from sae_lens import SAE

sae, cfg_dict, sparsity = SAE.from_pretrained("NoamDiamant52/qwen2_0.5b_sae", "<sae_id>")
```
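After loading, the SAE can be run directly on activations from the base model. The snippet below is a minimal sketch rather than part of the upstream documentation: it assumes TransformerLens can load `Qwen/Qwen2.5-0.5B`, that the model and SAE sit on the same device (CPU here, the default for `SAE.from_pretrained`), and that `sae` is the object returned above.

```python
# Minimal usage sketch (illustrative; assumes TransformerLens support for Qwen2.5-0.5B
# and that the SAE loaded above lives on the CPU, its default device).
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("Qwen/Qwen2.5-0.5B", device="cpu")

with torch.no_grad():
    # Cache activations for a short prompt and pick out the SAE's training hook point.
    _, cache = model.run_with_cache("Once upon a time, there was a little girl.")
    acts = cache[sae.cfg.hook_name]

    feature_acts = sae.encode(acts)   # sparse feature activations
    recon = sae.decode(feature_acts)  # reconstructed activations

print("fraction of active features:", (feature_acts > 0).float().mean().item())
```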
These SAEs were trained with the following configuration:
```python
import torch
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

device = "cuda" if torch.cuda.is_available() else "cpu"

for name in ["blocks.0.hook_resid_post", "blocks.4.hook_resid_post", "blocks.8.hook_resid_post", "blocks.12.hook_resid_post", "blocks.16.hook_resid_post"]:
    total_training_steps = 30_000  # more steps would probably help
    batch_size = 4096
    total_training_tokens = total_training_steps * batch_size
    num = int(name.split(".")[1])  # layer index parsed from the hook name

    lr_warm_up_steps = 0
    lr_decay_steps = total_training_steps // 5  # 20% of training
    l1_warm_up_steps = total_training_steps // 20  # 5% of training

    cfg = LanguageModelSAERunnerConfig(
        # Data Generating Function (Model + Training Distribution)
        model_name="Qwen/Qwen2.5-0.5B",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
        hook_name=name,  # a valid hook point (see https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
        hook_layer=num,  # the layer index of the hook point
        d_in=896,  # the width of the residual stream in Qwen2.5-0.5B
        dataset_path="NoamDiamant52/TinyStories_tokenized",  # a tokenized TinyStories dataset on Hugging Face
        is_dataset_tokenized=True,
        streaming=True,  # stream the dataset rather than pre-downloading it
        # SAE Parameters
        mse_loss_normalization=None,  # we won't normalize the MSE loss
        expansion_factor=16,  # SAE width = expansion_factor * d_in; larger gives better stats but slower training
        b_dec_init_method="zeros",  # the geometric median could also be used to initialize the decoder bias
        apply_b_dec_to_input=False,  # we won't subtract the decoder bias from the input
        normalize_sae_decoder=False,
        scale_sparsity_penalty_by_decoder_norm=True,
        decoder_heuristic_init=True,
        init_encoder_as_decoder_transpose=True,
        normalize_activations="expected_average_only_in",
        # Training Parameters
        lr=5e-5,  # lower is generally better; this is set fairly high to keep training fast
        adam_beta1=0.9,  # Adam betas (defaults)
        adam_beta2=0.999,
        lr_scheduler_name="constant",  # constant learning rate with warmup; other schedules may work better
        lr_warm_up_steps=lr_warm_up_steps,  # helps avoid too many dead features initially
        lr_decay_steps=lr_decay_steps,  # helps avoid overfitting
        l1_coefficient=5,  # controls how sparse the feature activations are
        l1_warm_up_steps=l1_warm_up_steps,  # helps avoid too many dead features initially
        lp_norm=1.0,  # use the L1 penalty (not an Lp norm with p < 1)
        train_batch_size_tokens=batch_size,
        context_size=512,  # length of the prompts fed to the model; longer is better but slower
        # Activation Store Parameters
        n_batches_in_buffer=64,  # controls how many activations we store / shuffle
        training_tokens=total_training_tokens,  # ~123 million tokens (30,000 steps * 4096 tokens per step)
        store_batch_size_prompts=16,
        # Resampling protocol
        use_ghost_grads=False,  # we don't use ghost grads anymore
        feature_sampling_window=1000,  # controls how often feature sparsity stats are reported
        dead_feature_window=1000,  # would affect resampling or ghost grads if we were using them
        dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using them
        # WANDB
        log_to_wandb=True,  # always use wandb unless you are just testing code
        wandb_project="sae_lens_tutorial",
        wandb_log_frequency=30,
        eval_every_n_wandb_logs=20,
        # Misc
        device=device,
        seed=42,
        n_checkpoints=0,
        checkpoint_path="checkpoints",
        dtype="float32",
    )

    sparse_autoencoder = SAETrainingRunner(cfg).run()
    sparse_autoencoder.save_model(f"post_residual_layer_{num}")
```
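As a rough post-training sanity check, reconstruction quality on a held-out prompt can be measured. The snippet below is a hypothetical addition, not part of the original training script: it assumes a `HookedTransformer` for `Qwen/Qwen2.5-0.5B` is available as `model` on the same device as the SAE, and that the object returned by `SAETrainingRunner.run()` exposes the same `encode`/`decode` interface as `SAE`.

```python
# Hypothetical sanity check (not from the original script): reuses `name` and
# `sparse_autoencoder` from the training loop and a HookedTransformer `model`
# loaded as in the earlier snippet.
import torch

with torch.no_grad():
    _, cache = model.run_with_cache("Once upon a time, there was a little girl.")
    acts = cache[name]  # activations at the trained hook point
    recon = sparse_autoencoder.decode(sparse_autoencoder.encode(acts))

    mse = (recon - acts).pow(2).mean()
    fvu = mse / acts.var()  # fraction of variance unexplained (rough estimate)
    print(f"{name}: MSE={mse.item():.4f}  FVU={fvu.item():.4f}")
```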