{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# T5Gemma 2 SAE - Quick Start Guide\n", "\n", "This notebook shows how to use the **T5Gemma 2 Sparse Autoencoders** from [mindchain/t5gemma2-sae-all-layers](https://huggingface.co/mindchain/t5gemma2-sae-all-layers).\n", "\n", "## What are SAEs?\n", "\n", "Sparse Autoencoders (SAEs) help interpret what features a neural network has learned. They can be used for:\n", "- **Mechanistic Interpretability** - Understanding model internals\n", "- **Activation Steering** - Modifying model behavior \n", "- **Feature Visualization** - Seeing what concepts each feature detects\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/README.md)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Install Dependencies\n", "\n", "First, install the required libraries:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install -q transformers torch huggingface_hub" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Import Libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from huggingface_hub import hf_hub_download\n", "\n", "print(\"Libraries imported successfully!\")\n", "print(f\"PyTorch version: {torch.__version__}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Load a Trained SAE\n", "\n", "Load one of the 36 trained SAEs (18 encoder + 18 decoder layers)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import hf_hub_download\n", "\n", "repo_id = \"mindchain/t5gemma2-sae-all-layers\"\n", "\n", "# Load Encoder Layer 0 SAE\n", "sae_path = hf_hub_download(\n", " repo_id=repo_id,\n", " filename=\"encoder/sae_encoder_00.pt\"\n", ")\n", "\n", "sae = torch.load(sae_path, map_location=\"cpu\")\n", "\n", "print(f\"SAE loaded from: {sae_path}\")\n", "print(f\"Model: {sae['model_name']}\")\n", "print(f\"Layer: {sae['layer_type']} {sae['layer_idx']}\")\n", "print(f\"d_in: {sae['d_in']}, d_sae: {sae['d_sae']}\")\n", "\n", "# Show training history\n", "if 'history' in sae:\n", " print(f\"Final Loss: {sae['history']['loss'][-1]:.6f}\")\n", " print(f\"Final L0: {sae['history']['l0'][-1]:.1f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. SAE Forward Pass\n", "\n", "Define functions to run activations through the SAE." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def sae_encode(activations, sae):\n", " \"\"\"Activations to Sparse Features\"\"\"\n", " acts_f32 = activations.float()\n", " return torch.relu(acts_f32 @ sae['W_enc'] + sae['b_enc'])\n", "\n", "def sae_decode(features, sae):\n", " \"\"\"Sparse Features to Activations\"\"\"\n", " return features @ sae['W_dec'] + sae['b_dec']\n", "\n", "def sae_forward(activations, sae):\n", " \"\"\"Full SAE forward pass\"\"\"\n", " features = sae_encode(activations, sae)\n", " recon = sae_decode(features, sae)\n", " return recon, features\n", "\n", "print(\"SAE functions defined!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Test the SAE\n", "\n", "Create dummy activations and test reconstruction quality." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "# Create dummy activation\n", "dummy_activations = torch.randn(1, 10, 640)\n", "\n", "# Run through SAE\n", "recon, features = sae_forward(dummy_activations, sae)\n", "\n", "# Calculate metrics\n", "mse = F.mse_loss(recon, dummy_activations).item()\n", "cosine = F.cosine_similarity(\n", " dummy_activations.flatten(), \n", " recon.flatten(), \n", " dim=0\n", ").item()\n", "l0 = (features > 0).sum().item()\n", "\n", "print(f\"Input shape: {dummy_activations.shape}\")\n", "print(f\"Features shape: {features.shape}\")\n", "print(f\"\\nReconstruction Quality:\")\n", "print(f\" MSE: {mse:.6f}\")\n", "print(f\" Cosine Similarity: {cosine:.4f}\")\n", "print(f\" L0 (active features): {l0} / {features.shape[-1]}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. All Available SAEs\n", "\n", "This repository contains **36 SAEs** in total:\n", "\n", "| Layer Type | Range | Count |\n", "|------------|-------|-------|\n", "| Encoder | 0-17 | 18 SAEs |\n", "| Decoder | 0-17 | 18 SAEs |\n", "| **Total** | - | **36 SAEs** |\n", "\n", "To load a different layer:\n", "```python\n", "# Encoder Layer 5\n", "sae_path = hf_hub_download(\n", " repo_id=\"mindchain/t5gemma2-sae-all-layers\",\n", " filename=\"encoder/sae_encoder_05.pt\"\n", ")\n", "\n", "# Decoder Layer 10\n", "sae_path = hf_hub_download(\n", " repo_id=\"mindchain/t5gemma2-sae-all-layers\",\n", " filename=\"decoder/sae_decoder_10.pt\"\n", ")\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Usage with T5Gemma 2 Model\n", "\n", "To use SAEs with the actual T5Gemma 2 model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n", "\n", "# Load model\n", "model = AutoModelForSeq2SeqLM.from_pretrained(\n", " \"google/t5gemma-2-270m-270m\",\n", " device_map=\"auto\"\n", ")\n", "tokenizer = AutoTokenizer.from_pretrained(\"google/t5gemma-2-270m-270m\")\n", "\n", "print(\"Model loaded!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Links\n", "\n", "- **HuggingFace Model**: [mindchain/t5gemma2-sae-all-layers](https://huggingface.co/mindchain/t5gemma2-sae-all-layers)\n", "- **Base Model**: [google/t5gemma-2-270m-270m](https://huggingface.co/google/t5gemma-2-270m-270m)\n", "- **SAELens**: [github.com/decoderesearch/SAELens](https://github.com/decoderesearch/SAELens)\n", "- **Neuronpedia**: [neuronpedia.org](https://neuronpedia.org)\n", "\n", "---\n", "\n", "Trained by [mindchain](https://huggingface.co/mindchain) | December 2025" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 0 }