---
library_name: transformers
tags:
- protein
- biology
- flash-attention
- esm
- esm2
license: mit
---

# Flash Attention ESM2 (FAESM)

This is an efficient Flash Attention implementation of ESM2 (Evolutionary Scale Modeling) that provides a speedup and memory reduction of nearly 50% compared to the original implementation. All the source code comes from [FAPLM](https://github.com/pengzhangzhi/faplm). If you find it useful, please give the repository a star :)

## Key Features

- **Automatic Flash Attention**: Automatically uses FlashAttention for up to 70% faster inference and 60% memory reduction when available
- **Smart Fallback**: Automatically falls back to PyTorch SDPA if Flash Attention is not installed
- **Drop-in Replacement**: Same API as the original ESM2 models
- **Memory Efficient**: Removes padding tokens during computation for better efficiency
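
FlashAttention and the PyTorch SDPA fallback compute the same function, scaled dot-product attention, softmax(QKᵀ/√d)·V. They differ in how the work is scheduled on the GPU, not in the result (up to floating-point rounding). To make the fallback concrete, here is a naive pure-Python reference for a single attention head with no masking or dropout. It is a teaching sketch with a hypothetical function name, not the kernel the model actually runs.

```python
import math


def scaled_dot_product_attention(q, k, v):
    """Naive single-head attention: softmax(q . k / sqrt(d)) applied to v.

    q, k, v are lists of vectors (lists of floats); returns one output
    vector per query. No masking, dropout, or batching.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    outputs = []
    for qi in q:
        # Dot product of this query with every key, scaled by 1/sqrt(d).
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        # Numerically stable softmax: subtract the max before exponentiating.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Output is the attention-weighted average of the value vectors.
        out = [sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))]
        outputs.append(out)
    return outputs
```

The optimized kernels produce the same numbers but avoid materializing the full score matrix, which is where the memory savings come from.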

## Installation Requirements

```bash
# Install PyTorch (if not already installed)
pip install torch

# Install Flash Attention (optional, for best performance)
pip install flash-attn --no-build-isolation --no-cache-dir

# Install Hugging Face Transformers
pip install transformers
```
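
Because `flash-attn` is optional, you may want to check which attention backend to expect before loading a model. A minimal sketch, assuming the model's fallback depends on whether the `flash_attn` package is present (the exact check in the remote model code may differ). The helper name is hypothetical, and `importlib.util.find_spec` only tests that the package is installed, not that it imports cleanly against your CUDA build.

```python
import importlib.util


def flash_attn_available() -> bool:
    """Return True if the optional `flash_attn` package is installed."""
    # find_spec locates a top-level package without importing it,
    # so this check is cheap and does not touch the GPU.
    return importlib.util.find_spec("flash_attn") is not None


if __name__ == "__main__":
    backend = "FlashAttention" if flash_attn_available() else "PyTorch SDPA"
    print(f"Expected attention backend: {backend}")
```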

## Usage
Usage is the same as for the original ESM2 models: just change the repository name and set `trust_remote_code=True`.

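```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

repo = "fredzzp/esm2_t33_650M_UR50D"
# trust_remote_code=True loads the FAESM model code shipped with the checkpoint.
# FlashAttention kernels require fp16/bf16 tensors, hence .half().
model = AutoModelForMaskedLM.from_pretrained(repo, trust_remote_code=True).to("cuda").eval().half()
tokenizer = AutoTokenizer.from_pretrained(repo)

input_ids = tokenizer("AGC", return_tensors="pt").input_ids.to("cuda")
with torch.no_grad():  # inference only: skip autograd bookkeeping
    output = model(input_ids)
print(output["logits"].shape)             # (batch, seq_len, vocab_size)
print(output["last_hidden_state"].shape)  # (batch, seq_len, hidden_size)
```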
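
When you batch sequences of different lengths, the tokenizer pads the shorter ones (for example `tokenizer(seqs, padding=True, return_tensors="pt")`). The "Memory Efficient" feature refers to dropping those padding positions before attention. A common way to do this, and the layout that FlashAttention's variable-length interface (`flash_attn_varlen_func`) consumes, is to pack the real tokens into one flat sequence and record cumulative sequence lengths (`cu_seqlens`). The sketch below shows only that bookkeeping on plain Python lists; `unpad` is a hypothetical name, the token ids are illustrative, and the model's remote code does this on GPU tensors, possibly differently.

```python
from itertools import accumulate


def unpad(batch, pad_id):
    """Pack a padded batch into (tokens, cu_seqlens, max_len).

    Sequence i occupies tokens[cu_seqlens[i]:cu_seqlens[i + 1]].
    """
    # Real length of each row = number of non-padding tokens.
    lengths = [sum(1 for t in row if t != pad_id) for row in batch]
    # Keep only real tokens, concatenated; padding never reaches attention.
    tokens = [t for row in batch for t in row if t != pad_id]
    # Start offset of every sequence in the packed list, beginning at 0.
    cu_seqlens = [0, *accumulate(lengths)]
    return tokens, cu_seqlens, max(lengths)


if __name__ == "__main__":
    # Two right-padded rows; 1 plays the role of the padding id here.
    batch = [[0, 5, 6, 2, 1, 1], [0, 7, 8, 9, 10, 2]]
    tokens, cu_seqlens, max_len = unpad(batch, pad_id=1)
    print(tokens)      # [0, 5, 6, 2, 0, 7, 8, 9, 10, 2]
    print(cu_seqlens)  # [0, 4, 10]
    print(max_len)     # 6
```

In this toy batch, 12 padded positions shrink to 10 real tokens; with many short sequences padded to one long one, the savings grow accordingly.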

## Supported ESM Versions

| **Model**                                                                            | **Num Layers** | **Num Parameters** |
| ------------------------------------------------------------------------------------ | -------------- | ------------------ |
| [fredzzp/esm2\_t36\_3B\_UR50D](https://huggingface.co/fredzzp/esm2_t36_3B_UR50D)     | 36             | 3B                 |
| [fredzzp/esm2\_t33\_650M\_UR50D](https://huggingface.co/fredzzp/esm2_t33_650M_UR50D) | 33             | 650M               |
| [fredzzp/esm2\_t30\_150M\_UR50D](https://huggingface.co/fredzzp/esm2_t30_150M_UR50D) | 30             | 150M               |
| [fredzzp/esm2\_t12\_35M\_UR50D](https://huggingface.co/fredzzp/esm2_t12_35M_UR50D)   | 12             | 35M                |
| [fredzzp/esm2\_t6\_8M\_UR50D](https://huggingface.co/fredzzp/esm2_t6_8M_UR50D)       | 6              | 8M                 |
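
To pick a checkpoint for a given GPU, it can help to encode the table as data. A small sketch, assuming the nominal parameter counts from the table (they are rounded, not exact) and counting only fp16 weights at 2 bytes per parameter; activations and framework overhead come on top. Both helper names are hypothetical.

```python
# Checkpoints from the table: repo id -> (num_layers, nominal num_parameters).
CHECKPOINTS = {
    "fredzzp/esm2_t36_3B_UR50D": (36, 3_000_000_000),
    "fredzzp/esm2_t33_650M_UR50D": (33, 650_000_000),
    "fredzzp/esm2_t30_150M_UR50D": (30, 150_000_000),
    "fredzzp/esm2_t12_35M_UR50D": (12, 35_000_000),
    "fredzzp/esm2_t6_8M_UR50D": (6, 8_000_000),
}


def largest_checkpoint(max_params: int) -> str:
    """Return the repo id of the largest listed checkpoint with <= max_params."""
    fitting = [(n, repo) for repo, (_, n) in CHECKPOINTS.items() if n <= max_params]
    if not fitting:
        raise ValueError(f"no listed checkpoint has <= {max_params} parameters")
    return max(fitting)[1]


def fp16_weight_gb(num_params: int) -> float:
    """Approximate size of the weights in fp16 (2 bytes per parameter), in GB."""
    return num_params * 2 / 1e9
```

For example, `largest_checkpoint(1_000_000_000)` returns `"fredzzp/esm2_t33_650M_UR50D"`, whose fp16 weights take about 1.3 GB; the returned id can be passed directly to `from_pretrained`.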


## Citation

If you use this implementation, please cite both the original ESM2 paper and this work:

```bibtex
@article{lin2023evolutionary,
  title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
  author={Lin, Zeming and Akin, Halil and Rao, Roshan and Hie, Brian and Zhu, Zihang and Lu, Wenting and Smetanin, Nikita and Verkuil, Robert and Kabeli, Ori and Shmueli, Yair and others},
  journal={Science},
  volume={379},
  number={6637},
  pages={1123--1130},
  year={2023},
  publisher={American Association for the Advancement of Science}
}

@misc{faesm2024,
  author       = {Peng, Fred Zhangzhi and Chatterjee, Pranam and others},
  title        = {FAESM: An efficient PyTorch implementation of Evolutionary Scale Modeling (ESM)},
  year         = {2024},
  howpublished = {\url{https://github.com/pengzhangzhi/faesm}},
  note         = {Efficient PyTorch implementation of ESM with FlashAttention and Scaled Dot-Product Attention (SDPA)},
  abstract     = {FAESM is a drop-in replacement for the official ESM implementation, designed to save up to 60% memory usage and 70% inference time, while maintaining compatibility with the ESM API.},
}
```

## License

This implementation is licensed under the MIT License. The ESM2 model weights maintain their original licensing terms.