# T5Gemma 2 SAE - Quick Start Guide

This notebook shows how to use the **T5Gemma 2 Sparse Autoencoders** from [mindchain/t5gemma2-sae-all-layers](https://huggingface.co/mindchain/t5gemma2-sae-all-layers).

## What are SAEs?

Sparse Autoencoders (SAEs) decompose a model's internal activations into a set of sparsely active features, which makes it easier to interpret what the network has learned. They can be used for:
- **Mechanistic Interpretability** - Understanding model internals
- **Activation Steering** - Modifying model behavior 
- **Feature Visualization** - Seeing what concepts each feature detects

## 1. Install Dependencies

First, install the required libraries:

In [None]:
# accelerate is needed for device_map="auto" when loading the base model in section 7
!pip install -q transformers torch huggingface_hub accelerate

## 2. Import Libraries

In [None]:
import torch
from huggingface_hub import hf_hub_download

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")

## 3. Load a Trained SAE

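Load one of the 36 trained SAEs (18 encoder + 18 decoder layers). Each checkpoint is a Python dict holding the weights (`W_enc`, `b_enc`, `W_dec`, `b_dec`) and metadata such as `model_name`, `layer_type`, `layer_idx`, `d_in`, and `d_sae`.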

In [None]:
from huggingface_hub import hf_hub_download

repo_id = "mindchain/t5gemma2-sae-all-layers"

# Load Encoder Layer 0 SAE
sae_path = hf_hub_download(
    repo_id=repo_id,
    filename="encoder/sae_encoder_00.pt"
)

sae = torch.load(sae_path, map_location="cpu")

print(f"SAE loaded from: {sae_path}")
print(f"Model: {sae['model_name']}")
print(f"Layer: {sae['layer_type']} {sae['layer_idx']}")
print(f"d_in: {sae['d_in']}, d_sae: {sae['d_sae']}")

# Show training history
if 'history' in sae:
    print(f"Final Loss: {sae['history']['loss'][-1]:.6f}")
    print(f"Final L0: {sae['history']['l0'][-1]:.1f}")
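
To see everything a checkpoint contains before relying on specific keys, a small helper can summarize each entry. This is a minimal sketch: `describe_checkpoint` is a hypothetical name, and the toy dict below only mimics the layout of a loaded checkpoint with made-up sizes.

```python
import torch

def describe_checkpoint(ckpt):
    """Return {key: short description} for every entry in a checkpoint dict."""
    summary = {}
    for key, value in ckpt.items():
        if torch.is_tensor(value):
            # Tensors: report shape and dtype
            summary[key] = f"tensor{tuple(value.shape)} {value.dtype}"
        else:
            # Everything else (ints, strings, nested dicts): report the type name
            summary[key] = type(value).__name__
    return summary

# Toy stand-in for a loaded checkpoint (sizes are made up for illustration)
toy_sae = {
    "W_enc": torch.zeros(8, 32),
    "b_enc": torch.zeros(32),
    "d_in": 8,
    "d_sae": 32,
}
for key, desc in describe_checkpoint(toy_sae).items():
    print(f"{key}: {desc}")
```

Calling `describe_checkpoint(sae)` on the real checkpoint should work the same way, since it only assumes a dict.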

## 4. SAE Forward Pass

Define functions to run activations through the SAE.
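
In equation form, the functions in the next cell compute the following. The shapes are inferred from the matrix products in the code rather than stated by the checkpoint: `W_enc` is assumed to be `d_in × d_sae` and `W_dec` to be `d_sae × d_in`.

$$
f = \mathrm{ReLU}\left(x\, W_{\text{enc}} + b_{\text{enc}}\right), \qquad \hat{x} = f\, W_{\text{dec}} + b_{\text{dec}}
$$

Here $x$ is one activation vector of size `d_in`, $f$ is the sparse feature vector of size `d_sae`, and $\hat{x}$ is the reconstruction of $x$.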

In [None]:
def sae_encode(activations, sae):
    """Map activations (..., d_in) to sparse features (..., d_sae)."""
    # Cast to float32 before the matrix product
    acts_f32 = activations.float()
    return torch.relu(acts_f32 @ sae['W_enc'] + sae['b_enc'])

def sae_decode(features, sae):
    """Map sparse features (..., d_sae) back to activations (..., d_in)."""
    return features @ sae['W_dec'] + sae['b_dec']

def sae_forward(activations, sae):
    """Full SAE forward pass; returns (reconstruction, features)."""
    features = sae_encode(activations, sae)
    recon = sae_decode(features, sae)
    return recon, features

print("SAE functions defined!")

## 5. Test the SAE

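Create dummy activations to check that the shapes line up and the code runs. Random inputs are not representative of real model activations, so treat the numbers below as a smoke test, not as a measure of SAE quality.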

In [None]:
import torch.nn.functional as F

# Create dummy activations with shape (batch, seq_len, d_in); 640 must match sae['d_in']
dummy_activations = torch.randn(1, 10, 640)

# Run through SAE
recon, features = sae_forward(dummy_activations, sae)

# Calculate metrics
mse = F.mse_loss(recon, dummy_activations).item()
cosine = F.cosine_similarity(
    dummy_activations.flatten(),
    recon.flatten(),
    dim=0
).item()
# L0 is the number of active features per token, averaged over all tokens
l0 = (features > 0).float().sum(dim=-1).mean().item()

print(f"Input shape: {dummy_activations.shape}")
print(f"Features shape: {features.shape}")
print("\nReconstruction Quality:")
print(f" MSE: {mse:.6f}")
print(f" Cosine Similarity: {cosine:.4f}")
print(f" L0 (mean active features per token): {l0:.1f} / {features.shape[-1]}")
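
A per-token view can be more informative than a single aggregate number. The sketch below adds a hypothetical helper, `per_token_metrics` (not part of the repository), that reports mean L0 per token, mean per-token cosine similarity, and the fraction of variance unexplained (FVU). It runs on toy tensors so it does not need a downloaded checkpoint.

```python
import torch
import torch.nn.functional as F

def per_token_metrics(x, recon, features):
    """Per-token metrics; x and recon are (..., d_in), features is (..., d_sae)."""
    x = x.float()
    recon = recon.float()
    # Active features counted separately for each token
    l0_per_token = (features > 0).float().sum(dim=-1)
    # Cosine similarity computed token by token
    cos_per_token = F.cosine_similarity(x, recon, dim=-1)
    # FVU: residual energy divided by the energy of the mean-centered input
    x_flat = x.reshape(-1, x.shape[-1])
    resid = (x - recon).pow(2).sum()
    total = (x_flat - x_flat.mean(dim=0)).pow(2).sum()
    return {
        "mean_l0": l0_per_token.mean().item(),
        "mean_cosine": cos_per_token.mean().item(),
        "fvu": (resid / total).item(),
    }

# Toy check: a perfect reconstruction with exactly three active features per token
x = torch.randn(2, 5, 8)
features = torch.zeros(2, 5, 32)
features[..., :3] = 1.0
print(per_token_metrics(x, x.clone(), features))
```

With the real SAE, the call would be `per_token_metrics(dummy_activations, recon, features)`. An FVU near 0 means the reconstruction captures most of the variance, and values near or above 1 mean it captures little.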

## 6. All Available SAEs

This repository contains **36 SAEs** in total:

| Layer Type | Range | Count |
|------------|-------|-------|
| Encoder | 0-17 | 18 SAEs |
| Decoder | 0-17 | 18 SAEs |
| **Total** | - | **36 SAEs** |

To load a different layer:
```python
# Encoder Layer 5
sae_path = hf_hub_download(
    repo_id="mindchain/t5gemma2-sae-all-layers",
    filename="encoder/sae_encoder_05.pt"
)

# Decoder Layer 10
sae_path = hf_hub_download(
    repo_id="mindchain/t5gemma2-sae-all-layers",
    filename="decoder/sae_decoder_10.pt"
)
```
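
The filenames above appear to follow one pattern: the layer type as a folder, then a zero-padded two-digit layer index. Assuming that pattern holds for all 36 files (it is inferred from the three examples in this notebook, not confirmed against the repository listing), a small hypothetical helper can build the path:

```python
def sae_filename(layer_type, layer_idx):
    """Build the repo-relative filename for an SAE, e.g. 'encoder/sae_encoder_05.pt'."""
    if layer_type not in ("encoder", "decoder"):
        raise ValueError(f"layer_type must be 'encoder' or 'decoder', got {layer_type!r}")
    if not 0 <= layer_idx <= 17:
        raise ValueError(f"layer_idx must be in 0-17, got {layer_idx}")
    # Assumed naming pattern: <type>/sae_<type>_<two-digit index>.pt
    return f"{layer_type}/sae_{layer_type}_{layer_idx:02d}.pt"

print(sae_filename("encoder", 5))
print(sae_filename("decoder", 10))
```

The result can be passed as the `filename` argument of `hf_hub_download`.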

## 7. Usage with T5Gemma 2 Model

To run the SAEs on real activations, first load the T5Gemma 2 base model and its tokenizer:

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load model
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/t5gemma-2-270m-270m",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-2-270m-270m")

print("Model loaded!")
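
Once the model is loaded, the SAE needs intermediate activations as input. A common way to capture them in PyTorch is a forward hook. The sketch below uses a toy stack of linear layers as a stand-in, because this notebook does not state which T5Gemma 2 module each SAE was trained on. The module choice, the toy sizes, and the random SAE weights are all assumptions for illustration.

```python
import torch
import torch.nn as nn

# Toy stand-in for a stack of transformer layers (hypothetical; not T5Gemma 2)
d_in, d_sae = 8, 32
toy_layers = nn.ModuleList([nn.Linear(d_in, d_in) for _ in range(3)])

captured = {}

def save_output(module, inputs, output):
    # Real transformer layers often return a tuple; keep the hidden states
    hidden = output[0] if isinstance(output, tuple) else output
    captured["acts"] = hidden.detach()

# Hook the layer whose output should be fed to the SAE (index 1 is arbitrary here)
handle = toy_layers[1].register_forward_hook(save_output)

x = torch.randn(1, 10, d_in)
with torch.no_grad():
    for layer in toy_layers:
        x = layer(x)
handle.remove()  # remove the hook once the activations are captured

# Toy SAE with the same dict layout as the checkpoints (random weights)
toy_sae = {
    "W_enc": torch.randn(d_in, d_sae),
    "b_enc": torch.zeros(d_sae),
}
features = torch.relu(captured["acts"].float() @ toy_sae["W_enc"] + toy_sae["b_enc"])
print(captured["acts"].shape, features.shape)
```

With the real model, the same pattern would apply after choosing the module that matches the SAE's `layer_type` and `layer_idx`. Printing `model` shows the available submodules.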

## Links

- **HuggingFace Model**: [mindchain/t5gemma2-sae-all-layers](https://huggingface.co/mindchain/t5gemma2-sae-all-layers)
- **Base Model**: [google/t5gemma-2-270m-270m](https://huggingface.co/google/t5gemma-2-270m-270m)
- **SAELens**: [github.com/decoderesearch/SAELens](https://github.com/decoderesearch/SAELens)
- **Neuronpedia**: [neuronpedia.org](https://neuronpedia.org)

---

Trained by [mindchain](https://huggingface.co/mindchain) | December 2025