T5Gemma 2 Sparse Autoencoders (All 36 Layers)

Sparse Autoencoders (SAEs) trained on all 36 layers of google/t5gemma-2-270m-270m for mechanistic interpretability and activation steering.


Quick Start

from huggingface_hub import hf_hub_download
import torch

# Download and load the SAE for encoder layer 0
sae_path = hf_hub_download(
    repo_id="mindchain/t5gemma2-sae-all-layers",
    filename="encoder/sae_encoder_00.pt"
)
sae = torch.load(sae_path, map_location="cpu")

# Forward pass (SAE weights are float32, so upcast the activations)
activations = ...  # your self_attn.o_proj activations, shape [..., 640]
features = torch.relu(activations.float() @ sae['W_enc'] + sae['b_enc'])
reconstruction = features @ sae['W_dec'] + sae['b_dec']
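
The Quick Start leaves `activations` as a placeholder. One way to obtain them is a PyTorch forward hook on the hooked module (`self_attn.o_proj`, per the SAE configuration). The sketch below shows the pattern on a hypothetical stand-in `nn.Linear(640, 640)` and a randomly initialised SAE dictionary that uses the checkpoint's key names, so it runs without downloading anything. On the real model you would register the hook on the corresponding `o_proj` module; its exact attribute path is an assumption you should check with `print(model)`.

```python
import torch
import torch.nn as nn

D_IN, D_SAE = 640, 4096

# Hypothetical stand-in for an attention output projection (self_attn.o_proj).
o_proj = nn.Linear(D_IN, D_IN)

# Random SAE weights with the same keys as the released checkpoints.
torch.manual_seed(0)
sae = {
    'W_enc': torch.randn(D_IN, D_SAE) * 0.02,
    'b_enc': torch.zeros(D_SAE),
    'W_dec': torch.randn(D_SAE, D_IN) * 0.02,
    'b_dec': torch.zeros(D_IN),
}

captured = []

def capture_hook(module, inputs, output):
    # Store a detached copy of the module's output for later SAE analysis.
    captured.append(output.detach())

handle = o_proj.register_forward_hook(capture_hook)
with torch.no_grad():
    o_proj(torch.randn(1, 7, D_IN))  # [batch, seq_len, d_model]
handle.remove()  # remove the hook once the activations are captured

activations = captured[0]
features = torch.relu(activations.float() @ sae['W_enc'] + sae['b_enc'])
reconstruction = features @ sae['W_dec'] + sae['b_dec']
print(activations.shape, features.shape, reconstruction.shape)
```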

Model Specifications

Property               Value
Base Model             google/t5gemma-2-270m-270m
Architecture           T5 Text-to-Text (Encoder-Decoder)
Total Parameters       ~540M (270M encoder + 270M decoder)
Hidden Size (d_model)  640
FFN Dimension          2,560
Attention Heads        10
Vocabulary Size        32,128
Encoder Layers         18
Decoder Layers         18

SAE Configuration

Property                      Value
SAE Input Dimension (d_in)    640
SAE Hidden Dimension (d_sae)  4,096
Expansion Factor              6.4x (4,096 / 640)
L1 Coefficient                0.01
Training Epochs               5
Batch Size                    2
Learning Rate                 1e-4
Optimizer                     AdamW
Hook Point                    self_attn.o_proj
Precision (Model)             float16
Precision (SAE)               float32
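
The configuration lists an L1 coefficient, optimizer, and learning rate but not the exact loss. A common ReLU-SAE objective, assumed here rather than confirmed by this card, is reconstruction MSE plus the L1 coefficient times the L1 norm of the feature activations. The sketch runs one AdamW step on random data with the listed hyperparameters; the `SparseAutoencoder` class, its initialisation, the loss normalisation, and the batch shape are illustrative assumptions.

```python
import torch
import torch.nn as nn

D_IN, D_SAE = 640, 4096
L1_COEFF, LR = 0.01, 1e-4  # values from the SAE configuration table

class SparseAutoencoder(nn.Module):
    """Minimal ReLU SAE using the checkpoint's parameter names (assumed layout)."""

    def __init__(self, d_in, d_sae):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_in, d_sae) * 0.02)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_in) * 0.02)
        self.b_dec = nn.Parameter(torch.zeros(d_in))

    def forward(self, x):
        features = torch.relu(x @ self.W_enc + self.b_enc)
        reconstruction = features @ self.W_dec + self.b_dec
        return features, reconstruction

def sae_loss(x, features, reconstruction, l1_coeff=L1_COEFF):
    mse = (reconstruction - x).pow(2).mean()
    # L1 penalty: sum of |features| per token, averaged over tokens (assumed).
    l1 = features.abs().sum(dim=-1).mean()
    return mse + l1_coeff * l1

torch.manual_seed(0)
sae = SparseAutoencoder(D_IN, D_SAE)
optimizer = torch.optim.AdamW(sae.parameters(), lr=LR)

x = torch.randn(2, 16, D_IN)  # stand-in for float32 o_proj activations
features, reconstruction = sae(x)
loss = sae_loss(x, features, reconstruction)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss={loss.item():.4f}")
```

Training the SAE in float32 while the base model runs in float16 matches the precision rows above; the upcast happens before the activations reach the SAE.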

Coverage (36 SAEs)

Component  Layers  Files
Encoder    0-17    encoder/sae_encoder_00.pt to sae_encoder_17.pt
Decoder    0-17    decoder/sae_decoder_00.pt to sae_decoder_17.pt
Total      36      36 checkpoint files
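
The checkpoint filenames follow a fixed pattern: a component folder plus a zero-padded two-digit layer index. The helper below is a hypothetical convenience function (not part of any library) that builds the path from the Coverage table's naming scheme and rejects out-of-range layers.

```python
def sae_filename(layer_type: str, layer_idx: int) -> str:
    """Return the repo-relative path of one SAE checkpoint."""
    if layer_type not in ("encoder", "decoder"):
        raise ValueError(f"layer_type must be 'encoder' or 'decoder', got {layer_type!r}")
    if not 0 <= layer_idx <= 17:
        raise ValueError(f"layer_idx must be in 0-17, got {layer_idx}")
    return f"{layer_type}/sae_{layer_type}_{layer_idx:02d}.pt"

# All 36 checkpoints: encoder layers first, then decoder layers.
all_files = [sae_filename(t, i) for t in ("encoder", "decoder") for i in range(18)]
print(all_files[0], all_files[-1], len(all_files))
# encoder/sae_encoder_00.pt decoder/sae_decoder_17.pt 36
```

The result can be passed as `filename` to `hf_hub_download(repo_id="mindchain/t5gemma2-sae-all-layers", filename=...)`, as in the Quick Start.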

Use Cases

  1. Mechanistic Interpretability - Understand what features represent
  2. Activation Steering - Modify model behavior by steering features
  3. Feature Analysis - Find concept-specific features
  4. Model Interventions - Ablate or enhance specific capabilities
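
For feature analysis, a simple starting point is to find the token positions where a given feature fires most strongly. The sketch below uses random activations and SAE weights in place of real ones, and the token strings are hypothetical placeholders; with the real model you would pass activations captured at `self_attn.o_proj` for one sequence together with the tokenizer's tokens. `top_activating_tokens` is an illustrative helper, not an existing API.

```python
import torch

def top_activating_tokens(activations, sae, feature_idx, tokens, k=3):
    """Return the k (token, activation) pairs with the largest value of one feature.

    activations: [seq_len, d_in] tensor for a single sequence.
    tokens: list of seq_len token strings aligned with activations.
    """
    features = torch.relu(activations.float() @ sae['W_enc'] + sae['b_enc'])
    values = features[:, feature_idx]
    top = torch.topk(values, k=min(k, values.numel()))  # sorted, largest first
    return [(tokens[i], v) for v, i in zip(top.values.tolist(), top.indices.tolist())]

torch.manual_seed(0)
d_in, d_sae, seq_len = 640, 4096, 6
sae = {
    'W_enc': torch.randn(d_in, d_sae) * 0.05,
    'b_enc': torch.zeros(d_sae),
}
tokens = ["The", "cat", "sat", "on", "the", "mat"]  # hypothetical tokens
acts = torch.randn(seq_len, d_in)
for token, value in top_activating_tokens(acts, sae, feature_idx=123, tokens=tokens):
    print(f"{token!r}: {value:.3f}")
```

Running the same helper over many sequences and keeping the global top examples is the usual next step before labelling what a feature represents.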

Activation Steering

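The hook below encodes the hooked module's output with the SAE, scales one feature by (1 + strength), and replaces the output with the decoded result.

import torch

class SteeringHook:
    """Forward hook that amplifies one SAE feature in a module's output."""

    def __init__(self, sae, feature_idx, strength):
        self.sae = sae
        self.feature_idx = feature_idx
        self.strength = strength

    @torch.no_grad()
    def __call__(self, module, input, output):
        acts = output[0] if isinstance(output, tuple) else output
        # Use this hook's own SAE tensors, on the same device as the activations
        sae = {k: v.to(acts.device) for k, v in self.sae.items() if torch.is_tensor(v)}
        features = torch.relu(acts.float() @ sae['W_enc'] + sae['b_enc'])
        features[..., self.feature_idx] *= (1 + self.strength)
        modified = (features @ sae['W_dec'] + sae['b_dec']).to(acts.dtype)
        # A forward hook that returns a value replaces the module's output
        if isinstance(output, tuple):
            return (modified,) + tuple(output[1:])
        return modified

# Install the hook on encoder layer 0 (module path assumed; check print(model))
hook = SteeringHook(sae, feature_idx=123, strength=0.5)
handle = model.model.encoder.layers[0].self_attn.o_proj.register_forward_hook(hook)
# ... run the model ...
handle.remove()  # restore the unsteered model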
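
The hook above replaces the attention output with the SAE reconstruction, which also discards whatever the SAE fails to reconstruct (the training details report a cosine similarity of about 0.80). A common alternative, not the method used in this card, is to add only the change caused by the steered feature: scaling feature i by (1 + strength) shifts the reconstruction by strength * f_i * W_dec[i], so that delta can be added to the original activations. The sketch demonstrates this on a hypothetical stand-in `nn.Linear` with random SAE weights; `DeltaSteeringHook` is an illustrative name.

```python
import torch
import torch.nn as nn

class DeltaSteeringHook:
    """Add strength * f_i * W_dec[i] to a module's output (error-preserving steering)."""

    def __init__(self, sae, feature_idx, strength):
        self.sae = sae
        self.feature_idx = feature_idx
        self.strength = strength

    @torch.no_grad()
    def __call__(self, module, inputs, output):
        acts = output[0] if isinstance(output, tuple) else output
        x = acts.float()
        W_enc = self.sae['W_enc'].to(x.device)
        b_enc = self.sae['b_enc'].to(x.device)
        W_dec = self.sae['W_dec'].to(x.device)
        # Only the steered feature's activation is needed: f_i has shape [...]
        f_i = torch.relu(x @ W_enc[:, self.feature_idx] + b_enc[self.feature_idx])
        delta = self.strength * f_i.unsqueeze(-1) * W_dec[self.feature_idx]
        steered = (x + delta).to(acts.dtype)
        if isinstance(output, tuple):
            return (steered,) + tuple(output[1:])
        return steered

torch.manual_seed(0)
d_in, d_sae = 640, 4096
sae = {
    'W_enc': torch.randn(d_in, d_sae) * 0.05,
    'b_enc': torch.zeros(d_sae),
    'W_dec': torch.randn(d_sae, d_in) * 0.05,
    'b_dec': torch.zeros(d_in),
}
o_proj = nn.Linear(d_in, d_in)  # hypothetical stand-in for self_attn.o_proj
x = torch.randn(1, 5, d_in)

with torch.no_grad():
    baseline = o_proj(x)
handle = o_proj.register_forward_hook(DeltaSteeringHook(sae, feature_idx=123, strength=0.5))
with torch.no_grad():
    steered = o_proj(x)
handle.remove()
print((steered - baseline).norm().item())
```

With strength 0 this hook leaves the output unchanged, whereas the full-reconstruction hook still swaps in the lossy SAE reconstruction.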

Checkpoint Structure

{
    'model_name': 'google/t5gemma-2-270m-270m',
    'layer_type': str,    # 'encoder' or 'decoder'
    'layer_idx': int,     # 0-17
    'd_in': 640,
    'd_sae': 4096,
    'W_enc': Tensor,      # float32, shape [640, 4096]
    'b_enc': Tensor,      # shape [4096]
    'W_dec': Tensor,      # shape [4096, 640]
    'b_dec': Tensor,      # shape [640]
    'history': {'loss': [...], 'l0': [...]},  # training history
}
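
Because the checkpoints are plain dictionaries, a quick shape check after `torch.load` catches mismatched files before they reach a hook. `validate_sae_checkpoint` is a hypothetical helper written against the structure listed above; here it is exercised on a synthetic dictionary rather than a downloaded file.

```python
import torch

def validate_sae_checkpoint(ckpt: dict) -> None:
    """Raise ValueError if ckpt does not match the documented SAE layout."""
    d_in, d_sae = ckpt['d_in'], ckpt['d_sae']
    expected = {
        'W_enc': (d_in, d_sae),
        'b_enc': (d_sae,),
        'W_dec': (d_sae, d_in),
        'b_dec': (d_in,),
    }
    for key, shape in expected.items():
        actual = tuple(ckpt[key].shape)
        if actual != shape:
            raise ValueError(f"{key}: expected shape {shape}, got {actual}")
    if ckpt['layer_type'] not in ('encoder', 'decoder'):
        raise ValueError(f"unexpected layer_type {ckpt['layer_type']!r}")
    if not 0 <= ckpt['layer_idx'] <= 17:
        raise ValueError(f"unexpected layer_idx {ckpt['layer_idx']}")

# Synthetic checkpoint with the documented keys (zeros instead of trained weights)
ckpt = {
    'model_name': 'google/t5gemma-2-270m-270m',
    'layer_type': 'encoder',
    'layer_idx': 0,
    'd_in': 640,
    'd_sae': 4096,
    'W_enc': torch.zeros(640, 4096),
    'b_enc': torch.zeros(4096),
    'W_dec': torch.zeros(4096, 640),
    'b_dec': torch.zeros(640),
    'history': {'loss': [], 'l0': []},
}
validate_sae_checkpoint(ckpt)  # passes silently
```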

Training Details

  • Dataset: 1,500 diverse text samples, 5 epochs
  • Hook Point: self_attn.o_proj (attention output)
  • Final Loss: ~0.0014
  • Final L0: ~1367 of 4096 features active on average (~33% of features active, i.e. ~67% sparsity)
  • Reconstruction MSE: 0.105
  • Reconstruction Cosine Similarity: 0.80
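
The card reports L0, MSE, and cosine similarity without giving formulas. The sketch below shows one standard way to compute them per token and then average, which is an assumption about how the reported numbers were aggregated. Random weights stand in for a trained SAE, so the numbers it prints are not the card's results; `sae_metrics` is an illustrative helper.

```python
import torch
import torch.nn.functional as F

def sae_metrics(activations, sae):
    """Return mean L0, reconstruction MSE, and mean cosine similarity over all tokens."""
    x = activations.float().reshape(-1, activations.shape[-1])  # [tokens, d_in]
    features = torch.relu(x @ sae['W_enc'] + sae['b_enc'])
    recon = features @ sae['W_dec'] + sae['b_dec']
    return {
        # L0: number of nonzero features per token, averaged over tokens
        'l0': (features > 0).float().sum(dim=-1).mean().item(),
        'mse': F.mse_loss(recon, x).item(),
        'cosine': F.cosine_similarity(recon, x, dim=-1).mean().item(),
    }

torch.manual_seed(0)
d_in, d_sae = 640, 4096
sae = {
    'W_enc': torch.randn(d_in, d_sae) * 0.02,
    'b_enc': torch.zeros(d_sae),
    'W_dec': torch.randn(d_sae, d_in) * 0.02,
    'b_dec': torch.zeros(d_in),
}
metrics = sae_metrics(torch.randn(2, 10, d_in), sae)
print(metrics)
```

For a trained SAE, `metrics['l0'] / d_sae` gives the fraction of active features; the card's ~1367/4096 corresponds to about 33%.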

License

MIT License

Credits

Trained by mindchain
Training Date: December 2025


Keywords: SAE, sparse autoencoder, T5, T5Gemma, interpretability, mechanistic interpretability, activation steering, feature visualization, Neuronpedia, Gemma Scope, explainable AI, XAI, model steering, feature engineering, representation learning, dictionary learning, layer interpretation, hidden state analysis, NLP, natural language processing, text-to-text, language model, LLM, large language model, LLM interpretability, encoder-decoder, transformer, neural network interpretation, sparse representations, features, embeddings
