---
library_name: transformers
tags:
- protein
- biology
- flash-attention
- esm
- esm2
license: mit
---
# Flash Attention ESM2 (FAESM)
This is an efficient Flash Attention implementation of ESM2 (Evolutionary Scale Modeling) that delivers a nearly 50% speedup and memory reduction compared to the original implementation. All of the source code comes from [FAPLM](https://github.com/pengzhangzhi/faplm). Give us a star if you find it useful `:)`
## Key Features
- **Flash Attention by default**: uses FlashAttention kernels for up to 70% faster inference and 60% memory reduction when available
- **Smart Fallback**: automatically falls back to PyTorch SDPA if Flash Attention is not installed (see the availability check below)
- **Drop-in Replacement**: Same API as the original ESM2 models
- **Memory Efficient**: Removes padding tokens during computation for better efficiency
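Which attention path is used depends only on your environment: a CUDA GPU plus the `flash-attn` package enables the Flash Attention kernels, otherwise the SDPA fallback kicks in. A minimal sketch for checking this up front (a generic environment check, not part of FAESM itself):

```python
import importlib.util

import torch

# Flash Attention requires a CUDA GPU and the flash-attn package;
# otherwise the model falls back to PyTorch's scaled_dot_product_attention.
has_cuda = torch.cuda.is_available()
has_flash_attn = importlib.util.find_spec("flash_attn") is not None

if has_cuda and has_flash_attn:
    print("Flash Attention path available")
elif has_cuda:
    print("CUDA available, but flash-attn missing: SDPA fallback will be used")
else:
    print("No CUDA GPU detected: standard attention on CPU")
```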
## Installation Requirements
```bash
# Install PyTorch (if not already installed)
pip install torch
# Install Flash Attention (optional, for best performance)
pip install flash-attn --no-build-isolation --no-cache-dir
# Install Hugging Face Transformers
pip install transformers
```
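Because `flash-attn` builds CUDA extensions, its installation can fail in some environments. A quick sanity check along these lines (again, a generic check rather than an FAESM utility) confirms what you actually ended up with:

```python
import torch

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())

try:
    import flash_attn
    print("flash-attn:", flash_attn.__version__)
except ImportError:
    print("flash-attn not installed; the PyTorch SDPA fallback will be used")
```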
## Usage
Only one change is needed: point to this repo and set `trust_remote_code=True`.
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

# trust_remote_code=True loads the FAESM implementation shipped with this repo.
model = AutoModelForMaskedLM.from_pretrained("fredzzp/esm2_t33_650M_UR50D", trust_remote_code=True).to("cuda").eval().half()
tokenizer = AutoTokenizer.from_pretrained("fredzzp/esm2_t33_650M_UR50D")

# Tokenize a protein sequence and run a forward pass.
input_ids = tokenizer("AGC", return_tensors="pt").input_ids.to("cuda")
output = model(input_ids)

print(output['logits'].shape)             # per-token logits over the amino-acid vocabulary
print(output['last_hidden_state'].shape)  # per-token embeddings from the final layer
```
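Building on the example above, here is a minimal sketch of two common downstream uses: predicting a masked residue from the logits and mean-pooling the last hidden states into a per-sequence embedding. It assumes the output keys shown above (`logits`, `last_hidden_state`) and the standard ESM2 tokenizer's `<mask>` token; the sequence is arbitrary.

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

repo = "fredzzp/esm2_t33_650M_UR50D"
model = AutoModelForMaskedLM.from_pretrained(repo, trust_remote_code=True).to("cuda").eval().half()
tokenizer = AutoTokenizer.from_pretrained(repo)

# Mask one residue in an (arbitrary) sequence and predict the most likely amino acid there.
sequence = "MKTAYIAKQR<mask>ISFVKSHFSRQLEERLGLIEVQ"
inputs = tokenizer(sequence, return_tensors="pt").to("cuda")

with torch.no_grad():
    out = model(inputs.input_ids)

mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_id = out["logits"][0, mask_pos].argmax(dim=-1)
print("Predicted residue:", tokenizer.decode(predicted_id))

# Mean-pool the final hidden states (dropping the BOS/EOS tokens) into a sequence embedding.
embedding = out["last_hidden_state"][0, 1:-1].mean(dim=0)
print("Embedding shape:", embedding.shape)
```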
## Supported ESM Versions
| **Model** | **Num Layers** | **Num Parameters** |
| ------------------------------------------------------------------------------------ | -------------- | ------------------ |
| [fredzzp/esm2\_t36\_3B\_UR50D](https://huggingface.co/fredzzp/esm2_t36_3B_UR50D) | 36 | 3B |
| [fredzzp/esm2\_t33\_650M\_UR50D](https://huggingface.co/fredzzp/esm2_t33_650M_UR50D) | 33 | 650M |
| [fredzzp/esm2\_t30\_150M\_UR50D](https://huggingface.co/fredzzp/esm2_t30_150M_UR50D) | 30 | 150M |
| [fredzzp/esm2\_t12\_35M\_UR50D](https://huggingface.co/fredzzp/esm2_t12_35M_UR50D) | 12 | 35M |
| [fredzzp/esm2\_t6\_8M\_UR50D](https://huggingface.co/fredzzp/esm2_t6_8M_UR50D) | 6 | 8M |
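All checkpoints share the same remote code, so switching sizes is just a matter of swapping the repo id. A small, hypothetical helper (the `load_faesm` name is not part of the library) illustrates this; the smaller checkpoints are handy for quick smoke tests:

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

def load_faesm(repo_id: str, device: str = "cuda"):
    # Any repo id from the table above works here.
    model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    return model, tokenizer

model, tokenizer = load_faesm("fredzzp/esm2_t6_8M_UR50D")
```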
## Citation
If you use this implementation, please cite both the original ESM2 paper and this work:
```bibtex
@article{lin2023evolutionary,
title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
author={Lin, Zeming and Akin, Halil and Rao, Roshan and Hie, Brian and Zhu, Zihang and Lu, Wenting and Smetanin, Nikita and Verkuil, Robert and Kabeli, Ori and Shmueli, Yair and others},
journal={Science},
volume={379},
number={6637},
pages={1123--1130},
year={2023},
publisher={American Association for the Advancement of Science}
}
@misc{faesm2024,
  author = {Fred Zhangzhi Peng and Pranam Chatterjee and contributors},
title = {FAESM: An efficient PyTorch implementation of Evolutionary Scale Modeling (ESM)},
year = {2024},
howpublished = {\url{https://github.com/pengzhangzhi/faesm}},
note = {Efficient PyTorch implementation of ESM with FlashAttention and Scalar Dot-Product Attention (SDPA)},
abstract = {FAESM is a drop-in replacement for the official ESM implementation, designed to save up to 60% memory usage and 70% inference time, while maintaining compatibility with the ESM API.},
}
```
## License
This implementation is licensed under the MIT License. The ESM2 model weights maintain their original licensing terms.