fredzzp committed
Commit 90edf16 · verified · 1 Parent(s): 81d75ad

Upload folder using huggingface_hub

Files changed (3)
  1. config.json +7 -3
  2. modeling_faesm.py +702 -0
  3. pytorch_model.bin +2 -2
config.json CHANGED
@@ -1,7 +1,7 @@
 {
   "_name_or_path": "/scratch/atong01/path_learning/checkpoints/mdm_pl_wt_5/huggingface/step_100000",
   "architectures": [
-    "EsmModel"
+    "FAEsmForMaskedLM"
   ],
   "attention_probs_dropout_prob": 0.0,
   "classifier_dropout": null,
@@ -26,5 +26,9 @@
   "transformers_version": "4.26.0",
   "use_cache": true,
   "vocab_list": null,
-  "vocab_size": 33
-}
+  "vocab_size": 33,
+  "auto_map": {
+    "AutoModel": "modeling_faesm.FAEsmModel",
+    "AutoModelForMaskedLM": "modeling_faesm.FAEsmForMaskedLM"
+  }
+}
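
The added `auto_map` block is what lets this checkpoint be loaded through the `transformers` Auto classes with `trust_remote_code=True`, which fetches `modeling_faesm.py` from the repo. A minimal loading sketch (the repo id below is a placeholder, not something recorded in this commit):

from transformers import AutoModelForMaskedLM, AutoTokenizer

repo_id = "user/faesm-checkpoint"  # placeholder: substitute the actual Hub repo id

# trust_remote_code=True makes transformers import FAEsmForMaskedLM from
# modeling_faesm.py (per the auto_map above) instead of the built-in EsmForMaskedLM.
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)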
modeling_faesm.py ADDED
@@ -0,0 +1,702 @@
+# coding=utf-8
+# Copyright 2024 FAESM team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Flash Attention ESM2 model implementation for Hugging Face Hub."""
+
+import logging
+import math
+from typing import List, Optional, Tuple, Union
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer
+from transformers.models.esm.modeling_esm import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    EsmAttention,
+    EsmContactPredictionHead,
+    EsmEmbeddings,
+    EsmEncoder,
+    EsmForMaskedLM,
+    EsmIntermediate,
+    EsmLayer,
+    EsmLMHead,
+    EsmModel,
+    EsmOutput,
+    EsmPooler,
+    EsmPreTrainedModel,
+    EsmSelfAttention,
+    EsmSelfOutput,
+)
+
+logger = logging.getLogger(__name__)
+
+# Flash Attention check
+try:
+    from flash_attn import flash_attn_varlen_qkvpacked_func
+    from flash_attn.bert_padding import pad_input, unpad_input
+    from flash_attn.ops.triton.rotary import apply_rotary
+    disable_fa = os.getenv("DISABLE_FA", "0")
+    print(f"✅ Flash Attention detected - using optimized implementation, disable_fa: {disable_fa}")
+    flash_attn_installed = disable_fa != "1"
+except ImportError as e:
+    flash_attn_installed = False
+    print(
+        """
+        ⚠️ Flash Attention not available - using PyTorch SDPA fallback.
+        For optimal performance, install Flash Attention:
+        pip install flash-attn --no-build-isolation
+        """
+    )
+    import traceback
+    traceback.print_exc()
+
+
+# ============================================================================
+# Flash Attention Utilities (consolidated from fa_utils.py)
+# ============================================================================
+
+class ApplyRotaryEmbQKV_(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, qkv, cos, sin, cu_seqlens, max_seqlen):
+        q, k = qkv[:, 0], qkv[:, 1]
+        apply_rotary(q, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, inplace=True)
+        apply_rotary(k, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, inplace=True)
+        ctx.save_for_backward(cos, sin, cu_seqlens)
+        ctx.max_seqlen = max_seqlen
+        return qkv
+
+    @staticmethod
+    def backward(ctx, dqkv):
+        max_seqlen = ctx.max_seqlen
+        cos, sin, cu_seqlens = ctx.saved_tensors
+        dq, dk = dqkv[:, 0], dqkv[:, 1]
+        apply_rotary(
+            dq, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, inplace=True, conjugate=True
+        )
+        apply_rotary(
+            dk, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, inplace=True, conjugate=True
+        )
+        return dqkv, None, None, None, None
+
+
+def apply_rotary_emb_qkv_(qkv, cos, sin, cu_seqlens: torch.Tensor, max_seqlen: int) -> torch.Tensor:
+    """Apply rotary embedding *inplace* to the first rotary_dim of Q and K."""
+    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
+
+
+class RotaryEmbedding(torch.nn.Module):
+    """The rotary position embeddings from RoFormer."""
+
+    def __init__(self, dim: int, base=10000.0, pos_idx_in_fp32=True, device=None, persistent=True):
+        super().__init__()
+        self.dim = dim
+        self.base = float(base)
+        self.pos_idx_in_fp32 = pos_idx_in_fp32
+        inv_freq = self._compute_inv_freq(device)
+        self.register_buffer("inv_freq", inv_freq, persistent=persistent)
+        self._seq_len_cached = 0
+        self._cos_cached = None
+        self._sin_cached = None
+
+    def _compute_inv_freq(self, device=None):
+        return 1.0 / (
+            self.base
+            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
+        )
+
+    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached is None
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+            or (self.training and self._cos_cached.is_inference())
+        ):
+            self._seq_len_cached = seqlen
+            if self.pos_idx_in_fp32:
+                t = torch.arange(seqlen, device=device, dtype=torch.float32)
+                if self.inv_freq.dtype != torch.float32:
+                    inv_freq = self._compute_inv_freq(device=device)
+                else:
+                    inv_freq = self.inv_freq
+            else:
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                inv_freq = self.inv_freq
+
+            freqs = torch.outer(t, inv_freq)
+            self._cos_cached = torch.cos(freqs).to(dtype)
+            self._sin_cached = torch.sin(freqs).to(dtype)
+
+    def forward(
+        self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int, *args, **kwargs
+    ) -> torch.Tensor:
+        """Apply rotary embedding *inplace*."""
+        self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+        return apply_rotary_emb_qkv_(
+            qkv, self._cos_cached, self._sin_cached, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
+        )
+
+
+def unpad(input, padding_mask):
+    """
+    Arguments:
+        input: (batch, seqlen, ...)
+        padding_mask: (batch, seqlen), bool type, True means to keep, False means to remove
+    Return:
+        output: (total_nnz, ...), where total_nnz = number of tokens selected by padding_mask
+        cu_seqlens: (batch + 1,), the cumulative sequence lengths, used to index into output
+        max_seqlen: int, the maximum sequence length in the batch
+        indices: (total_nnz,), the indices of the kept tokens in the flattened input
+        output_pad_fn: function, to pad the output back to the original shape
+    """
+    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    output = input.flatten(0, 1)[indices]
+
+    def output_pad_fn(output):
+        return pad_input(output, indices, batch=input.shape[0], seqlen=input.shape[1])
+
+    return output, cu_seqlens, max_seqlen, indices, output_pad_fn
+
+
+# ============================================================================
+# Flash Attention ESM Model Implementation
+# ============================================================================
+
+class FAEsmSelfAttention(EsmSelfAttention):
+    def __init__(self, config, position_embedding_type=None):
+        super().__init__(config, position_embedding_type)
+        self.config = config
+        if flash_attn_installed:
+            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
+
+    def forward(self, **kwargs):
+        if flash_attn_installed:
+            return self.fa_forward(**kwargs)
+        else:
+            return self.sdpa_forward(**kwargs)
+
+    def sdpa_forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor]:
+        mixed_query_layer = self.query(hidden_states)
+
+        is_cross_attention = encoder_hidden_states is not None
+
+        if is_cross_attention and past_key_value is not None:
+            key_layer = past_key_value[0]
+            value_layer = past_key_value[1]
+            attention_mask = encoder_attention_mask
+        elif is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+        query_layer = query_layer * self.attention_head_size**-0.5
+
+        if self.is_decoder:
+            past_key_value = (key_layer, value_layer)
+
+        if self.position_embedding_type == "rotary":
+            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
+
+        if (
+            self.position_embedding_type == "relative_key"
+            or self.position_embedding_type == "relative_key_query"
+        ):
+            raise NotImplementedError
+
+        if head_mask is not None:
+            raise NotImplementedError
+
+        query_layer = query_layer.contiguous()
+        key_layer = key_layer.contiguous()
+        value_layer = value_layer.contiguous()
+
+        context_layer = F.scaled_dot_product_attention(
+            query_layer, key_layer, value_layer, attn_mask=attention_mask, scale=1.0
+        )
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (context_layer,)
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+    def fa_forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens,
+        max_seqlen,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor]:
+        assert cu_seqlens is not None, "cu_seqlens must be provided for FlashAttention"
+        assert max_seqlen is not None, "max_seqlen must be provided for FlashAttention"
+
+        q = self.query(hidden_states) * self.attention_head_size**-0.5
+        k = self.key(hidden_states)
+        v = self.value(hidden_states)
+        q, k, v = map(
+            lambda x: rearrange(x, "n (h d) -> n h d", h=self.num_attention_heads),
+            (q, k, v),
+        )
+        qkv = torch.stack((q, k, v), dim=1)  # (n, 3, h, d)
+        qkv = self.rotary_embeddings(qkv=qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
+
+        out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, softmax_scale=1.0)
+        out = rearrange(out, "n h d -> n (h d)")
+        outputs = (out,)
+        return outputs
+
+
+class FAEsmAttention(EsmAttention):
+    def __init__(self, config):
+        nn.Module.__init__(self)
+        self.self = FAEsmSelfAttention(config)
+        self.output = EsmSelfOutput(config)
+        self.pruned_heads = set()
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states,
+        cu_seqlens=None,
+        max_seqlen=None,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        hidden_states_ln = self.LayerNorm(hidden_states)
+        self_outputs = self.self(
+            hidden_states=hidden_states_ln,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+        )
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]
+        return outputs
+
+
+class FAEsmLayer(EsmLayer):
+    def __init__(self, config):
+        nn.Module.__init__(self)
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = FAEsmAttention(config)
+        self.is_decoder = config.is_decoder
+        self.add_cross_attention = config.add_cross_attention
+        if self.add_cross_attention:
+            if not self.is_decoder:
+                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
+            self.crossattention = FAEsmAttention(config)
+        self.intermediate = EsmIntermediate(config)
+        self.output = EsmOutput(config)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states,
+        cu_seqlens=None,
+        max_seqlen=None,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        self_attention_outputs = self.attention(
+            hidden_states=hidden_states,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
+        )
+        attention_output = self_attention_outputs[0]
+
+        if self.is_decoder:
+            outputs = self_attention_outputs[1:-1]
+            present_key_value = self_attention_outputs[-1]
+        else:
+            outputs = self_attention_outputs[1:]
+
+        cross_attn_present_key_value = None
+        if self.is_decoder and encoder_hidden_states is not None:
+            if not hasattr(self, "crossattention"):
+                raise AttributeError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
+                    " with cross-attention layers by setting `config.add_cross_attention=True`"
+                )
+
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            cross_attention_outputs = self.crossattention(
+                attention_output,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                cross_attn_past_key_value,
+                output_attentions,
+            )
+            attention_output = cross_attention_outputs[0]
+            outputs = outputs + cross_attention_outputs[1:-1]
+            cross_attn_present_key_value = cross_attention_outputs[-1]
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        layer_output = self.feed_forward_chunk(attention_output)
+        outputs = (layer_output,) + outputs
+
+        if self.is_decoder:
+            outputs = outputs + (present_key_value,)
+        return outputs
+
+
+class FAEsmEncoder(EsmEncoder):
+    def __init__(self, config):
+        nn.Module.__init__(self)
+        self.config = config
+        self.layer = nn.ModuleList([FAEsmLayer(config) for _ in range(config.num_hidden_layers)])
+        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states,
+        cu_seqlens=None,
+        max_seqlen=None,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+    ):
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                    "`use_cache=False`..."
+                )
+                use_cache = False
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+        next_decoder_cache = () if use_cache else None
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            past_key_value = past_key_values[i] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.__call__,
+                    hidden_states=hidden_states,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                    attention_mask=attention_mask,
+                    head_mask=layer_head_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states=hidden_states,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                    attention_mask=attention_mask,
+                    head_mask=layer_head_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+        if self.emb_layer_norm_after:
+            hidden_states = self.emb_layer_norm_after(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class FAEsmModel(EsmModel):
+    def __init__(self, config, add_pooling_layer=True):
+        EsmPreTrainedModel.__init__(self, config)
+        self.config = config
+
+        self.embeddings = EsmEmbeddings(config)
+        self.encoder = FAEsmEncoder(config)
+        self.pooler = EsmPooler(config) if add_pooling_layer else None
+        self.contact_head = EsmContactPredictionHead(
+            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
+        )
+
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        embed_cond: Optional[torch.Tensor] = None,  # [B, L, D]
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if self.config.is_decoder:
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+        else:
+            use_cache = False
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
+
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = encoder_attention_mask
+
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+        if embed_cond is not None:
+            # add embed_cond to the embedding_output
+            embedding_output = embedding_output + embed_cond
+
+        # Automatically use Flash Attention if available, otherwise use SDPA
+        use_fa = flash_attn_installed
+
+        if use_fa:
+            embedding_output, cu_seqlens, max_seqlen, _, output_pad_fn = unpad(embedding_output, attention_mask)
+        else:
+            cu_seqlens = None
+            max_seqlen = None
+            output_pad_fn = lambda x: x
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        sequence_output = output_pad_fn(sequence_output)
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+
+class FAEsmForMaskedLM(EsmForMaskedLM):
+    """Flash Attention ESM For Masked Language Modeling."""
+
+    def __init__(self, config, dropout=0.1):
+        config.hidden_dropout_prob = dropout
+        EsmPreTrainedModel.__init__(self, config)
+        self.esm = FAEsmModel(config, add_pooling_layer=False)
+        self.lm_head = EsmLMHead(config)
+        self.init_weights()
+
+        # Initialize tokenizer-related attributes if tokenizer is available
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
+            self.mask_id = tokenizer.mask_token_id
+            self.pad_id = tokenizer.pad_token_id
+            self.bos_id = tokenizer.cls_token_id
+            self.eos_id = tokenizer.eos_token_id
+            self.x_id = tokenizer._token_to_id.get("X", None)
+            self.tokenizer = tokenizer
+        except Exception:
+            # Set default values if tokenizer is not available
+            self.mask_id = 32
+            self.pad_id = 1
+            self.bos_id = 0
+            self.eos_id = 2
+            self.x_id = 24
+            self.tokenizer = None
+        self.contact_head = None
+
+    def forward(
+        self,
+        input_ids,
+        embed_cond=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        decoder_inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+    ):
+        if attention_mask is None:
+            attention_mask = input_ids.ne(self.pad_id)
+
+        outputs = self.esm(
+            input_ids,
+            attention_mask=attention_mask,
+            embed_cond=embed_cond,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_hidden_states=output_hidden_states,
+        )
+        sequence_output = outputs[0]
+        logits = self.lm_head(sequence_output)
+
+        result = {"logits": logits, "last_hidden_state": sequence_output}
+        if outputs.hidden_states is not None:
+            result["hidden_states"] = [x.unsqueeze(0) for x in outputs.hidden_states]
+
+        return result
+
+
+
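
For reference, a minimal usage sketch of the FAEsmForMaskedLM class added above. The repo id and protein sequence are illustrative placeholders; per the code, forward returns a plain dict with "logits" and "last_hidden_state", and the SDPA fallback is used automatically when flash-attn is not installed.

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

repo_id = "user/faesm-checkpoint"  # placeholder: substitute the actual Hub repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True).eval()

seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # example protein sequence
inputs = tokenizer(seq, return_tensors="pt")

with torch.no_grad():
    # forward returns a dict, not a ModelOutput
    out = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])

print(out["logits"].shape)             # (1, sequence length incl. special tokens, vocab_size=33)
print(out["last_hidden_state"].shape)  # (1, sequence length incl. special tokens, hidden_size)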
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c40ed59ae0da2c641d4754901d18c71898282b098c0e1da32418fe5d3573186d
-size 595360319
+oid sha256:530386bd8f7a4c009b9ab943d49afc8961e225a2fc753f303e318550f7337464
+size 595368651