---
library_name: transformers
tags:
- protein
- biology
- flash-attention
- esm
- esm2
license: mit
---

# Flash Attention ESM2 (FAESM)

This is an efficient Flash Attention implementation of ESM2 (Evolutionary Scale Modeling) that provides nearly a 50% speedup and memory reduction compared to the original implementation.

All the source code is from [FAPLM](https://github.com/pengzhangzhi/faplm). Give us a star if you find it useful `:)`

## Key Features

- **Automatic Flash Attention**: Automatically uses FlashAttention for up to 70% faster inference and 60% memory reduction when available
- **Smart Fallback**: Automatically falls back to PyTorch SDPA if Flash Attention is not installed
- **Drop-in Replacement**: Same API as the original ESM2 models
- **Memory Efficient**: Removes padding tokens during computation for better efficiency

## Installation Requirements

```bash
# Install PyTorch (if not already installed)
pip install torch

# Install Flash Attention (optional, for best performance)
pip install flash-attn --no-build-isolation --no-cache-dir

# Install Hugging Face Transformers
pip install transformers
```

## Usage

Simply change the repo name and set `trust_remote_code=True`.

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Load the model and tokenizer; trust_remote_code=True is required for the custom FAESM code
model = AutoModelForMaskedLM.from_pretrained(
    "fredzzp/esm2_t33_650M_UR50D", trust_remote_code=True
).to("cuda").eval().half()
tokenizer = AutoTokenizer.from_pretrained("fredzzp/esm2_t33_650M_UR50D")

# Tokenize a protein sequence and run a forward pass
input_ids = tokenizer("AGC", return_tensors="pt").input_ids.to("cuda")
output = model(input_ids)

print(output['logits'].shape)
print(output['last_hidden_state'].shape)
```

## Supported ESM Versions

| **Model** | **Num Layers** | **Num Parameters** |
| ------------------------------------------------------------------------------------ | -------------- | ------------------ |
| [fredzzp/esm2\_t36\_3B\_UR50D](https://huggingface.co/fredzzp/esm2_t36_3B_UR50D) | 36 | 3B |
| [fredzzp/esm2\_t33\_650M\_UR50D](https://huggingface.co/fredzzp/esm2_t33_650M_UR50D) | 33 | 650M |
| [fredzzp/esm2\_t30\_150M\_UR50D](https://huggingface.co/fredzzp/esm2_t30_150M_UR50D) | 30 | 150M |
| [fredzzp/esm2\_t12\_35M\_UR50D](https://huggingface.co/fredzzp/esm2_t12_35M_UR50D) | 12 | 35M |
| [fredzzp/esm2\_t6\_8M\_UR50D](https://huggingface.co/fredzzp/esm2_t6_8M_UR50D) | 6 | 8M |

## Citation

If you use this implementation, please cite both the original ESM2 paper and this work:

```bibtex
@article{lin2023evolutionary,
  title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
  author={Lin, Zeming and Akin, Halil and Rao, Roshan and Hie, Brian and Zhu, Zihang and Lu, Wenting and Smetanin, Nikita and Verkuil, Robert and Kabeli, Ori and Shmueli, Yair and others},
  journal={Science},
  volume={379},
  number={6637},
  pages={1123--1130},
  year={2023},
  publisher={American Association for the Advancement of Science}
}

@misc{faesm2024,
  author = {Fred Zhangzhi Peng and Pranam Chatterjee and contributors},
  title = {FAESM: An efficient PyTorch implementation of Evolutionary Scale Modeling (ESM)},
  year = {2024},
  howpublished = {\url{https://github.com/pengzhangzhi/faesm}},
  note = {Efficient PyTorch implementation of ESM with FlashAttention and Scaled Dot-Product Attention (SDPA)},
  abstract = {FAESM is a drop-in replacement for the official ESM implementation, designed to save up to 60% memory usage and 70% inference time, while maintaining compatibility with the ESM API.},
}
```

## License

This implementation is licensed under the MIT License.
The ESM2 model weights maintain their original licensing terms.
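
## Example: Batched Inference

In practice you will often run a padded batch of sequences rather than a single one. Below is a minimal sketch, assuming the model accepts an `attention_mask` the same way standard Transformers ESM models do; the example sequences are made up for illustration, and the printed shapes follow the single-sequence usage example above.

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_name = "fredzzp/esm2_t33_650M_UR50D"
model = AutoModelForMaskedLM.from_pretrained(
    model_name, trust_remote_code=True
).to("cuda").eval().half()
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Two illustrative protein sequences of different lengths
sequences = ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "AGC"]

# Pad to the longest sequence; the attention mask marks real tokens vs. padding
batch = tokenizer(sequences, return_tensors="pt", padding=True).to("cuda")

with torch.no_grad():
    # Assumes attention_mask is forwarded like in the standard ESM interface
    output = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
    )

print(output["logits"].shape)             # e.g. (batch, seq_len, vocab_size)
print(output["last_hidden_state"].shape)  # e.g. (batch, seq_len, hidden_size)
```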