Sparse Autoencoders (SAEs) trained on all 36 layers (18 encoder + 18 decoder) of google/t5gemma-2-270m-270m for mechanistic interpretability and activation steering.
```python
from huggingface_hub import hf_hub_download
import torch

# Load an SAE checkpoint (here: encoder layer 0)
sae_path = hf_hub_download(
    repo_id="mindchain/t5gemma2-sae-all-layers",
    filename="encoder/sae_encoder_00.pt"
)
sae = torch.load(sae_path, map_location="cpu")

# Forward pass through the SAE
activations = ...  # Your activations, shape [..., 640]
features = torch.relu(activations.float() @ sae['W_enc'] + sae['b_enc'])
reconstruction = features @ sae['W_dec'] + sae['b_dec']
```
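
The `activations` placeholder above is whatever the hook point the SAEs were trained on (`self_attn.o_proj`) produces. A minimal capture sketch, assuming the base model resolves through the generic `AutoTokenizer` / `AutoModelForSeq2SeqLM` classes in a recent transformers release and exposes the module path used later on this card; the prompt is only a dummy input:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-2-270m-270m")
model = AutoModelForSeq2SeqLM.from_pretrained("google/t5gemma-2-270m-270m")
model.eval()

captured = {}

def capture_hook(module, inputs, output):
    # Keep the raw o_proj output for the SAE forward pass
    captured["acts"] = output[0] if isinstance(output, tuple) else output

# Hook point matching the training config: encoder layer 0, self_attn.o_proj
handle = model.model.encoder.layers[0].self_attn.o_proj.register_forward_hook(capture_hook)

inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
with torch.no_grad():
    model.generate(**inputs, max_new_tokens=8)
handle.remove()

activations = captured["acts"]  # [batch, seq_len, 640]
features = torch.relu(activations.float() @ sae['W_enc'] + sae['b_enc'])
```
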
| Property | Value |
|---|---|
| Base Model | google/t5gemma-2-270m-270m |
| Architecture | T5 Text-to-Text (Encoder-Decoder) |
| Total Parameters | ~540M (270M encoder + 270M decoder) |
| Hidden Size (d_model) | 640 |
| FFN Dimension | 2,560 |
| Attention Heads | 10 |
| Vocabulary Size | 32,128 |
| Encoder Layers | 18 |
| Decoder Layers | 18 |

| Property | Value |
|---|---|
| SAE Input Dimension (d_in) | 640 |
| SAE Hidden Dimension (d_sae) | 4,096 |
| Expansion Factor | 6.4x |
| L1 Coefficient | 0.01 |
| Training Epochs | 5 |
| Batch Size | 2 |
| Learning Rate | 1e-4 |
| Optimizer | AdamW |
| Hook Point | self_attn.o_proj |
| Precision (Model) | float16 |
| Precision (SAE) | float32 |
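
The hyperparameters above correspond to a standard SAE objective: mean-squared reconstruction error plus an L1 sparsity penalty on the feature activations, weighted by the 0.01 coefficient and optimized with AdamW at lr 1e-4. The following is an illustrative sketch of a single training step under those settings, not the released training code; the `SparseAutoencoder` class and the incoming activation batches are assumptions:

```python
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """Illustrative SAE matching the table: d_in=640, d_sae=4096."""
    def __init__(self, d_in=640, d_sae=4096):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_in, d_sae) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_in) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(d_in))

    def forward(self, x):
        features = torch.relu(x @ self.W_enc + self.b_enc)
        reconstruction = features @ self.W_dec + self.b_dec
        return features, reconstruction

sae_model = SparseAutoencoder()
optimizer = torch.optim.AdamW(sae_model.parameters(), lr=1e-4)
l1_coefficient = 0.01

def training_step(activations):
    # activations: float32, shape [batch, seq_len, 640]
    features, reconstruction = sae_model(activations)
    recon_loss = (reconstruction - activations).pow(2).mean()
    sparsity_loss = l1_coefficient * features.abs().mean()
    loss = recon_loss + sparsity_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```
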
| Component | Layers | Files |
|---|---|---|
| Encoder | 0-17 | encoder/sae_encoder_00.pt - sae_encoder_17.pt |
| Decoder | 0-17 | decoder/sae_decoder_00.pt - sae_decoder_17.pt |
| Total | 36 | 36 checkpoint files |
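
To work with the full collection rather than a single layer, the checkpoints can be fetched in a loop over both stacks. A sketch assuming only the file layout shown in the table above:

```python
from huggingface_hub import hf_hub_download
import torch

saes = {"encoder": {}, "decoder": {}}
for component in ("encoder", "decoder"):
    for layer in range(18):
        path = hf_hub_download(
            repo_id="mindchain/t5gemma2-sae-all-layers",
            filename=f"{component}/sae_{component}_{layer:02d}.pt",
        )
        saes[component][layer] = torch.load(path, map_location="cpu")

print(sum(len(v) for v in saes.values()), "SAEs loaded")  # 36
```
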
```python
class SteeringHook:
    def __init__(self, sae, feature_idx, strength):
        self.sae = sae
        self.feature_idx = feature_idx
        self.strength = strength

    def __call__(self, module, input, output):
        acts = output[0] if isinstance(output, tuple) else output
        acts_f32 = acts.float()
        # Encode, scale the target feature, then decode back to the model's space
        features = torch.relu(acts_f32 @ self.sae['W_enc'] + self.sae['b_enc'])
        features[:, :, self.feature_idx] *= (1 + self.strength)
        modified = features @ self.sae['W_dec'] + self.sae['b_dec']
        # Returning a value from a forward hook replaces the module's output
        if isinstance(output, tuple):
            return (modified.to(output[0].dtype),) + output[1:]
        return modified.to(output.dtype)

# Install hook on the layer the SAE was trained for (encoder layer 0, self_attn.o_proj)
hook = SteeringHook(sae, feature_idx=123, strength=0.5)
handle = model.model.encoder.layers[0].self_attn.o_proj.register_forward_hook(hook)
```
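
A hedged usage sketch, continuing from the loading example earlier on this card (the `model`, `tokenizer`, and `handle` variables are assumed from there): generate with the steering hook installed, then remove it so later runs are unsteered.

```python
inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
with torch.no_grad():
    steered_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(steered_ids[0], skip_special_tokens=True))

# Remove the hook afterwards, otherwise every later forward pass stays steered
handle.remove()
```
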
```python
{
    'model_name': 'google/t5gemma-2-270m-270m',
    'layer_type': 'encoder',   # or 'decoder'
    'layer_idx': 0,            # 0-17
    'd_in': 640,
    'd_sae': 4096,
    'W_enc': ...,              # float32 tensor, shape [640, 4096]
    'b_enc': ...,              # float32 tensor, shape [4096]
    'W_dec': ...,              # float32 tensor, shape [4096, 640]
    'b_dec': ...,              # float32 tensor, shape [640]
    'history': {'loss': [...], 'l0': [...]}
}
```
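
A small sanity check, assuming only the key layout shown above, that verifies the weight shapes of a loaded checkpoint before use:

```python
ckpt = torch.load(sae_path, map_location="cpu")

print(ckpt['model_name'], ckpt['layer_type'], ckpt['layer_idx'])
assert ckpt['W_enc'].shape == (ckpt['d_in'], ckpt['d_sae'])   # [640, 4096]
assert ckpt['W_dec'].shape == (ckpt['d_sae'], ckpt['d_in'])   # [4096, 640]
assert ckpt['b_enc'].shape == (ckpt['d_sae'],)
assert ckpt['b_dec'].shape == (ckpt['d_in'],)
print("final training loss:", ckpt['history']['loss'][-1])
```
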
MIT License
Trained by mindchain
Training Date: December 2025
Keywords: SAE, sparse autoencoder, T5, T5Gemma, interpretability, mechanistic interpretability, activation steering, feature visualization, Neuronpedia, Gemma Scope, explainable AI, XAI, model steering, feature engineering, representation learning, dictionary learning, layer interpretation, hidden state analysis, NLP, natural language processing, text-to-text, language model, LLM, large language model, LLM interpretability, encoder-decoder, transformer, neural network interpretation, sparse representations, features, embeddings