yairschiff committed
Commit 9435876 · verified · 1 Parent(s): 7ceeeda

Update pytorch.bin; Add model and code

config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "T": 0,
+   "architectures": [
+     "E2D2"
+   ],
+   "attn_backend": "sdpa",
+   "auto_map": {
+     "AutoConfig": "diffusion.E2D2Config",
+     "AutoModel": "diffusion.E2D2",
+     "AutoModelForMaskedLM": "diffusion.E2D2"
+   },
+   "backbone_config": {
+     "_target_": "backbone_encoder_decoder.LLMasEncoderDecoderShareKV",
+     "attn_backend": "sdpa",
+     "freeze_encoder": false,
+     "keep_top_decoder_layers": true,
+     "keep_top_encoder_layers": false,
+     "max_length": 768,
+     "num_decoder_layers": 14,
+     "num_encoder_layers": 28,
+     "pretrained_model_name_or_path": "Qwen/Qwen3-1.7B-Base",
+     "reinit_decoder": false,
+     "reinit_encoder": false,
+     "tie_encoder_decoder_weights": true,
+     "use_encoder_causal_mask": false,
+     "use_gradient_checkpointing": false
+   },
+   "backbone_is_decoder_only": false,
+   "block_size": 4,
+   "bos_token_id": 151643,
+   "diffusion_type": "absorbing",
+   "eos_token_id": 151643,
+   "eval_block_size": 4,
+   "keep_clean_bos": true,
+   "length": 768,
+   "mask_token_id": 151660,
+   "model_type": "e2d2",
+   "noise_config": {
+     "_target_": "noise_schedule_noise_schedules.LinearNoise"
+   },
+   "pad_token_id": 151643,
+   "pad_vocab_size_multiple": 1,
+   "shift_logits": false,
+   "time_conditioned_backbone": false,
+   "tokenization_config": {
+     "bos_token_id": 151643,
+     "eos_token_id": 151643,
+     "mask_token_id": 151660,
+     "pad_token_id": 151643,
+     "pad_vocab_size_multiple": 1,
+     "vocab_size": 151669
+   },
+   "tokenizer_name": "Qwen/Qwen3-1.7B-Base",
+   "torch_dtype": "float32",
+   "train_on_context": false,
+   "transformers_version": "4.52.4",
+   "vocab_size": 151669
+ }
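
For reference, the `auto_map` above routes `AutoModel`/`AutoModelForMaskedLM` to the custom `E2D2` class defined in diffusion.py below, so loading requires `trust_remote_code=True`. A minimal sketch (the repo id `yairschiff/e2d2` is an assumption, not stated in this commit; substitute the actual Hub repo):

    from transformers import AutoModelForMaskedLM, AutoTokenizer

    # trust_remote_code=True lets transformers import diffusion.E2D2 from this
    # repository instead of a built-in architecture.
    model = AutoModelForMaskedLM.from_pretrained(
        "yairschiff/e2d2",  # hypothetical repo id
        trust_remote_code=True,
    )
    # "tokenizer_name" in config.json points at the Qwen3 base tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B-Base")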
diffusion.py ADDED
@@ -0,0 +1,1445 @@
+ from functools import partial
+ from typing import Any, Dict, Literal, Optional, Tuple, Union
+
+ import torch
+ from tqdm.auto import tqdm
+ from transformers import (
+     GenerationConfig,
+     LogitsProcessorList,
+     PreTrainedTokenizer,
+     StoppingCriteriaList,
+ )
+ from transformers.cache_utils import Cache, DynamicCache
+
+ try:
+     from torch.nn.attention.flex_attention import (
+         BlockMask,
+         and_masks,
+         create_block_mask,
+     )
+ except ImportError:
+     BlockMask, and_masks, create_block_mask = None, None, None
+
+
+ from src.denoiser.base import (
+     Denoiser,
+     DenoiserConfig,
+     DenoiserInput,
+     LossAndNllOutput,
+ )
+
+
+ def create_attn_mask(attn_mask):
+     # noinspection PyUnusedLocal
+     def padding(b, h, q_idx, kv_idx):
+         return attn_mask[b, q_idx] & attn_mask[b, kv_idx]
+
+     return padding
+
+
+
+ class DiffusionGenerationConfig(GenerationConfig):
+     def __init__(
+         self,
+         num_steps: int = 1000,
+         min_t: float = 1e-5,
+         block_size: Optional[int] = None,
+         first_hitting: bool = False,
+         sampling_strategy: Literal["posterior", "predict_then_noise"] = "posterior",
+         confidence_based_noising: bool = False,
+         confidence_margin_based_noising: bool = False,
+         confidence_threshold: float = 1e6,
+         use_model_output_cache: bool = True,
+         align_inputs_to_blocks: bool = True,
+         **kwargs,
+     ):
+         """Generation config with additional parameters relevant for diffusion model
+         sampling.
+
+         Args:
+             num_steps (int): Number of diffusion / iterative refinement steps.
+                 Defaults to 1000.
+             min_t (float): Minimum time to use.
+                 Diffusion models use t=1 for noise and t=0 for signal.
+                 Setting t=0 exactly can lead to certain numerical instabilities.
+                 Defaults to 1e-5.
+             block_size (int): Block size to use for semi-autoregressive decoding.
+                 Defaults to None (in which case block_size is set to max_new_tokens).
+             first_hitting (bool): Whether to use first hitting sampler.
+                 When set to true, rather than following the diffusion time and sampling
+                 from posterior, which can result in no tokens changing between steps,
+                 e.g., for masked diffusion, we explicitly determine the next time step
+                 at which a token will be decoded / generated.
+                 Note: this will negate the `num_steps` parameter, as we will decode one
+                 token at a time, hence, when True, num_steps = seq_length
+                 (or block_size, for semi-autoregressive).
+                 See https://arxiv.org/abs/2409.02908 for details.
+                 Defaults to False.
+             sampling_strategy (str): Method for transitioning between latents.
+                 Options:
+                 - "posterior" - Compute and sample from the posterior
+                     q(x_s | x_t, x_theta).
+                 - "predict_then_noise" - Sample from the denoising model x_theta,
+                     then add back noise to produce x_s.
+                     Only implemented for absorbing diffusion.
+                 Defaults to "posterior".
+             confidence_based_noising (bool): When using the "predict_then_noise"
+                 strategy, whether to add noise to random positions or to those that have
+                 the lowest probability under x_theta.
+                 Cannot be used in conjunction with confidence_margin_based_noising.
+                 Defaults to False.
+             confidence_margin_based_noising (bool): When using the "predict_then_noise"
+                 strategy, whether to add noise to random positions or to those that have
+                 the lowest probability margins under x_theta, where margin is defined as
+                 the absolute difference between the top two probabilities at a given
+                 position.
+                 See https://arxiv.org/abs/2502.06768 for details.
+                 Cannot be used in conjunction with confidence_based_noising.
+                 Defaults to False.
+             confidence_threshold (float): Confidence threshold to use for sampling.
+                 Any tokens that exceed threshold are decoded.
+                 See https://arxiv.org/abs/2505.22618 for details.
+                 Defaults to 1e6.
+             use_model_output_cache (bool): Whether to re-use model's output, if sequence
+                 is unchanged, because if xt == xs, we can simply re-use the denoising
+                 model's outputs and save a function evaluation.
+                 Relevant if model.backbone is not time/noise-conditioned.
+                 Defaults to True.
+             align_inputs_to_blocks (bool): Whether to align input tokens to block size,
+                 e.g., for an input of length C and block size S, context will be C // S,
+                 and generation will begin with a block whose first C % S tokens come
+                 from the input.
+             kwargs: Keyword arguments passed to `GenerationConfig`.
+         """
+         super().__init__(**kwargs)
+         self.num_steps = num_steps
+         self.min_t = min_t
+         # TODO: assumes we are setting max_new_tokens, which may not be the case!
+         self.block_size = block_size if block_size is not None else self.max_new_tokens
+         self.first_hitting = first_hitting
+         if self.first_hitting:
+             # TODO: log.warn that this is being overridden
+             self.num_steps = min(num_steps, self.block_size)
+         self.sampling_strategy = sampling_strategy
+         assert not confidence_based_noising or not confidence_margin_based_noising, (
+             "Cannot use both `confidence_based_noising` and"
+             " `confidence_margin_based_noising`."
+         )
+         self.confidence_based_noising = confidence_based_noising
+         self.confidence_margin_based_noising = confidence_margin_based_noising
+         self.confidence_threshold = confidence_threshold
+         self.use_model_output_cache = use_model_output_cache
+         self.align_inputs_to_blocks = align_inputs_to_blocks
+
+
+ class D3PMConfig(DenoiserConfig):
+     """Configuration class for D3PM models."""
+
+     model_type = "d3pm"
+     auto_map = {
+         "AutoConfig": "diffusion.D3PMConfig",
+         "AutoModel": "diffusion.D3PM",
+         "AutoModelForMaskedLM": "diffusion.D3PM",
+     }
+
+     def __init__(
+         self,
+         keep_clean_bos: Optional[bool] = None,  # Whether to enforce un-noised BOS token
+         T: int = 1000,
+         diffusion_type: Literal["absorbing", "uniform"] = "absorbing",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.keep_clean_bos = keep_clean_bos
+         self.diffusion_type = diffusion_type
+         self.T = T
+
+
+ class D3PM(Denoiser):
+     """Denoiser class for D3PM models.
+
+     This class implements the Denoiser interface for D3PM models.
+     """
+
+     config_class = D3PMConfig
+
+     def __init__(self, config: D3PMConfig, **kwargs):
+         super().__init__(config, **kwargs)
+         self.T = config.T
+         self.diffusion_type = config.diffusion_type
+         self._create_static_mask()
+
+     def _create_static_mask(self) -> None:
+         static_mask = torch.ones(
+             self.config.length, self.config.length, dtype=torch.bool
+         )
+         self.register_buffer(
+             "static_attention_mask",
+             static_mask,
+         )
+
+     def _sample_q_xt(
+         self,
+         x0: torch.LongTensor,
+         alpha_t: torch.FloatTensor,
+         context_mask: torch.FloatTensor,
+     ) -> torch.LongTensor:
+         """Sample from the pre-defined forward / noising process.
+
+         Parameters:
+             x0 (Tensor): Signal / data sample;
+                 can potentially include context tokens.
+             alpha_t (Tensor): Amount of signal to retain.
+             context_mask (Tensor): Indicator of context tokens (to remain
+                 unchanged).
+         """
+         move_indices = torch.rand(*x0.shape, device=x0.device) < (1.0 - alpha_t)
+         if self.diffusion_type == "absorbing":
+             xt = torch.where(
+                 (move_indices * (1 - context_mask)).bool(), self.mask_token_id, x0
+             )
+             if self.config.keep_clean_bos:
+                 xt[..., 0] = x0[..., 0]
+             return xt  # type: ignore
+         if self.diffusion_type == "uniform":
+             xt = torch.randint(0, self.vocab_size, x0.shape, device=x0.device)
+             xt = torch.where(context_mask.bool(), x0, xt)
+             if self.config.keep_clean_bos:
+                 xt[..., 0] = x0[..., 0]
+             return xt  # type: ignore
+         raise NotImplementedError(
+             f"Diffusion type '{self.diffusion_type}' not implemented."
+         )
+
+     def _prepare_inputs(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         context_mask: Optional[torch.FloatTensor] = None,
+         t: Optional[torch.FloatTensor] = None,
+         past_key_values: Optional[Cache] = None,
+     ):
+         # Prepare inputs for D3PM model
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids)
+         if context_mask is None:
+             context_mask = torch.zeros_like(attention_mask)
+
+         if torch.is_floating_point(attention_mask):
+             attention_mask = attention_mask.to(torch.int)
+             context_mask = context_mask.to(torch.int)
+
+         if t is None:
+             t = torch.rand(input_ids.shape[0], device=input_ids.device)
+         alpha_t, alpha_t_prime = self.noise_schedule(t)
+         while alpha_t.ndim < 2:
+             alpha_t = alpha_t[..., None]
+             alpha_t_prime = alpha_t_prime[..., None]
+         xt = self._sample_q_xt(
+             x0=input_ids,
+             alpha_t=alpha_t,
+             context_mask=context_mask,
+         )
+         if (context_mask is not None
+                 and context_mask.sum() == 0
+                 and (attention_mask == 1).all()):
+             processed_attention_mask = None
+         else:
+             processed_attention_mask = (
+                 self.static_attention_mask[None, ...]
+                 & attention_mask[:, None, :]
+                 & attention_mask[..., None]
+             )[:, None, ...]  # Make attention mask 4D
+             processed_attention_mask = self._preprocess_attention_mask(
+                 processed_attention_mask, dtype=torch.float
+             )
+         if self.training and self.config.train_on_context:
+             tokens_mask = attention_mask
+         else:
+             tokens_mask = attention_mask * (1 - context_mask)
+         return DenoiserInput(
+             xt=xt,
+             x0=input_ids,
+             attention_mask=processed_attention_mask,
+             context_mask=context_mask,
+             tokens_mask=tokens_mask,
+             t=t,
+             alpha_t=alpha_t,
+             alpha_t_prime=alpha_t_prime,
+         )
+
+     def _prepare_inputs_inference(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         context: Optional[torch.LongTensor] = None,
+         context_mask: Optional[torch.FloatTensor] = None,
+         cache: Optional[Dict[str, Any]] = None,
+         **backbone_kwargs: Any,
+     ) -> Tuple[DenoiserInput, Dict[str, Any]]:
+         assert input_ids is not None or context is not None, (
+             "Must provide either input_ids or context."
+         )
+         cache = cache if cache is not None else {}
+         past_key_values = cache.pop("past_key_values", DynamicCache())
+         if context is not None:
+             if input_ids is not None:
+                 if context_mask is None:
+                     context_mask = torch.cat(
+                         [torch.ones_like(context), torch.zeros_like(input_ids)], dim=-1
+                     )
+                 input_ids = torch.cat([context, input_ids], dim=-1)
+             else:
+                 input_ids = context
+                 context_mask = torch.ones_like(input_ids)
+         if attention_mask is None:
+             cache_length = self._get_past_key_values_seq_length(past_key_values)
+             full_seq_length = cache_length + input_ids.shape[-1]
+             attention_mask = torch.ones(
+                 (input_ids.shape[0], 1, input_ids.shape[1], full_seq_length),
+                 device=input_ids.device,
+             )  # Make attention mask 4D
+             attention_mask = self._preprocess_attention_mask(
+                 attention_mask, dtype=torch.float
+             )
+         return DenoiserInput(
+             xt=input_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             context_mask=context_mask,
+             backbone_kwargs=backbone_kwargs | {"use_cache": False},
+         ), cache
+
+     def _forward(
+         self,
+         backbone_output: torch.FloatTensor,
+         denoiser_inputs: DenoiserInput,
+         **kwargs,
+     ) -> torch.FloatTensor:
+         return torch.log_softmax(backbone_output, dim=-1)  # type: ignore
+
+     def _compute_loss(
+         self,
+         model_output: torch.FloatTensor,
+         denoiser_inputs: DenoiserInput,
+         **kwargs: Any,
+     ) -> LossAndNllOutput:
+         raise NotImplementedError
+
+     def _sample_prior(self, device, batch_size, length):
+         """Samples from prior / limiting distribution."""
+         if self.diffusion_type == "absorbing":
+             return self.mask_token_id * torch.ones(
+                 (batch_size, length), dtype=torch.int64, device=device
+             )
+         if self.diffusion_type == "uniform":
+             return torch.randint(
+                 0,
+                 self.vocab_size,
+                 (batch_size, length),
+                 device=device,
+                 dtype=torch.int64,
+             )
+         raise NotImplementedError(
+             f"Diffusion type '{self.diffusion_type}' not implemented."
+         )
+
+     def _compute_posterior(
+         self,
+         x: Union[torch.FloatTensor, torch.LongTensor],
+         xt: torch.LongTensor,
+         alpha_t: torch.FloatTensor,
+         alpha_s: torch.FloatTensor,
+     ) -> torch.FloatTensor:
+         """Computes posterior / approximate posterior q(x_s | x_t, x),
+         where x represents clean sequence (as one-hots) or the output of the
+         denoising model.
+
+         Args:
+             x (Tensor): True (one-hot) / predicted clean signal (B, L, V).
+             xt (Tensor): Noised signal at time t (B, L).
+             alpha_t (Tensor): Noise schedule parameter at time t (B, 1, 1).
+             alpha_s (Tensor): Noise schedule parameter at time s (B, 1, 1).
+         """
+         if self.diffusion_type == "absorbing":
+             q_xs = x * (alpha_s - alpha_t)
+             q_xs[..., self.mask_token_id] = 1 - alpha_s[..., 0]
+             q_xs /= 1 - alpha_t
+             return q_xs  # type: ignore
+
+         alpha_ts = alpha_t / alpha_s
+         d_alpha = alpha_s - alpha_t
+         # One-hot encode the noised latent xt (x may be a distribution over the vocab)
+         xt_one_hot = torch.nn.functional.one_hot(xt, self.vocab_size)
+         limiting_distribution = torch.ones_like(xt_one_hot) / self.vocab_size
+         if self.diffusion_type == "uniform":
+             return (
+                 alpha_t * self.vocab_size * x * xt_one_hot
+                 + (alpha_ts - alpha_t) * xt_one_hot
+                 + d_alpha * x
+                 + (1 - alpha_ts) * (1 - alpha_s) * limiting_distribution
+             ) / (
+                 alpha_t * self.vocab_size * torch.gather(x, -1, xt[..., None])
+                 + (1 - alpha_t)
+             )
+         raise NotImplementedError(
+             f"Diffusion type {self.diffusion_type} not implemented."
+         )
+
+     @staticmethod
+     def _sample_generation_timesteps(
+         generation_config: DiffusionGenerationConfig,
+         max_length: Optional[int] = None,
+         device: Optional[str] = None,
+     ) -> torch.FloatTensor:
+         """Sample timesteps for diffusion generation process."""
+         if device is None:
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+         if max_length is None:
+             max_length = generation_config.max_new_tokens
+
+         if (
+             generation_config.first_hitting
+             # TODO: first-hitting does not work with posterior
+             and generation_config.sampling_strategy == "posterior"
+         ):
+             timesteps = torch.FloatTensor([1.0])
+             for i in range(max_length, 0, -1):
+                 u = torch.rand(1)
+                 next_t = timesteps[-1] * u ** (1 / i)
+                 timesteps = torch.cat((timesteps, next_t), dim=0)
+             return timesteps[1:].to(device)  # type: ignore
+         return torch.linspace(  # type: ignore
+             1.0,
+             generation_config.min_t,
+             generation_config.num_steps + 1,
+             device=device,
+         )[:-1]
+
+     def _generate_unconditional(
+         self,
+         generation_config: DiffusionGenerationConfig,
+         alpha_t: torch.FloatTensor,
+         alpha_s: torch.FloatTensor,
+         denoiser_inputs: Optional[DenoiserInput] = None,
+         model_output_cache: Optional[Dict[str, torch.FloatTensor]] = None,
+         cache: Optional[Dict[str, Any]] = None,
+         running_generation: Optional[torch.LongTensor] = None,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         **kwargs: Any,
+     ) -> Tuple[torch.LongTensor, Dict[str, torch.FloatTensor], Dict[str, Any]]:
+         cache = cache if cache is not None else {}
+         if model_output_cache is None:  # execute function evaluation
+             backbone_output = self._backbone_forward(
+                 denoiser_inputs,
+                 fix_cache_length=True,  # Do not let kv cache grow on each forward call
+                 **cache,
+                 **kwargs,
+             )
+             backbone_output = {k: v for k, v in backbone_output.items()}
+             logits = backbone_output.pop("logits")
+             cache = cache | backbone_output
+             log_x_theta = self._forward(logits, denoiser_inputs, **kwargs)
+             if logits_processor is not None:
+                 for token_idx in range(log_x_theta.shape[1]):
+                     # TODO: Looping over token positions like this does not allow for
+                     #  some processors, e.g. length penalty which could be applied all
+                     #  at once to the entire block, to be applied in parallel.
+                     log_x_theta[:, token_idx] = logits_processor(
+                         input_ids=running_generation,
+                         scores=log_x_theta[:, token_idx],  # type: ignore
+                     )
+                 log_x_theta = torch.log_softmax(log_x_theta, dim=-1)  # re-normalize
+             x_theta = log_x_theta.exp()
+         else:
+             x_theta = model_output_cache["x_theta"]
+         model_output_cache = {"x_theta": x_theta}
+         prob_check_denom = denoiser_inputs.xt.numel()
+         if generation_config.sampling_strategy == "posterior":
+             q_xs = self._compute_posterior(
+                 x_theta, denoiser_inputs.xt, alpha_t, alpha_s
+             )
+
+             assert abs((q_xs.sum() / prob_check_denom).item() - 1.0) < 1e-6, (
+                 "Posterior probabilities not summing to 1."
+             )
+             assert q_xs.isnan().sum().item() == 0, "NaN found in the posterior."
+             xs = self._sample_categorical(q_xs, generation_config.do_sample)
+             output = torch.where(
+                 (denoiser_inputs.xt != self.mask_token_id).bool(),  # type: ignore
+                 denoiser_inputs.xt,
+                 xs,
+             )
+         elif generation_config.sampling_strategy == "predict_then_noise":
+             assert self.config.diffusion_type == "absorbing", (
+                 "predict_then_noise decoding strategy only supports absorbing diffusion."
+             )
+             # assert (
+             #     abs((x_theta.sum() / prob_check_denom).item() - 1.0) < 1e-6
+             # ), "Denoising output probabilities not summing to 1."
+             # assert x_theta.isnan().sum().item() == 0, (
+             #     "NaN found in the denoising output."
+             # )
+
+             # Predict
+             xs = self._sample_categorical(x_theta, generation_config.do_sample)
+             xs_probs = x_theta.gather(-1, xs[..., None]).squeeze(dim=-1)
+             output = xs.clone()
+
+             # Noise
+             num_noise_indices = torch.minimum(
+                 ((1 - alpha_s) * generation_config.block_size).to(torch.int),
+                 (denoiser_inputs.xt == self.mask_token_id).sum() - 1,  # type: ignore
+             )
+             if generation_config.confidence_based_noising:
+                 conf = x_theta.gather(-1, xs[..., None]).squeeze(-1)
+                 conf = torch.where(  # already decoded tokens have 'inf' confidence
+                     (denoiser_inputs.xt == self.mask_token_id).bool(),  # type: ignore
+                     conf,
+                     torch.inf,
+                 )
+                 noise_indices = conf.argsort(dim=-1)[..., :num_noise_indices]
+             elif generation_config.confidence_margin_based_noising:
+                 top2 = torch.topk(x_theta, k=2, dim=-1).values  # shape: (B, L, 2)
+                 conf = (top2[..., 0] - top2[..., 1]).abs()
+                 conf = torch.where(  # already decoded tokens have 'inf' confidence
+                     (denoiser_inputs.xt == self.mask_token_id).bool(),  # type: ignore
+                     conf,
+                     torch.inf,
+                 )
+                 noise_indices = conf.argsort(dim=-1)[..., :num_noise_indices]
+             else:
+                 # TODO: implement random noise indices selection
+                 raise NotImplementedError
+             output[..., noise_indices] = self.mask_token_id
+             output = torch.where(
+                 xs_probs >= generation_config.confidence_threshold, xs, output
+             )
+         else:
+             raise NotImplementedError(
+                 f"Sampling strategy {generation_config.sampling_strategy} not"
+                 " implemented."
+             )
+         return output, model_output_cache, cache  # type: ignore
+
+     @torch.no_grad()
+     def generate(
+         self,
+         inputs: Optional[torch.LongTensor] = None,
+         generation_config: Optional[DiffusionGenerationConfig] = None,
+         logits_processor: Optional[LogitsProcessorList] = None,
+         stopping_criteria: Optional[StoppingCriteriaList] = None,
+         max_length: Optional[int] = None,
+         max_new_tokens: Optional[int] = None,
+         batch_size: Optional[int] = None,
+         device: Optional[str] = None,
+         tokenizer: Optional[PreTrainedTokenizer] = None,
+         disable_pbar: bool = False,
+         **kwargs: Any,
+     ) -> torch.LongTensor:
+         # Setup sampling variables
+         if generation_config is None:
+             assert getattr(self, "generation_config", None) is not None, (
+                 "Generation config must be provided if not present in the model."
+             )
+             generation_config = self.generation_config
+         if inputs is None:
+             # BOS prompt must be an integer tensor so it can be embedded/concatenated
+             inputs = torch.full(
+                 (batch_size, 1), self.bos_token_id, dtype=torch.long, device=device
+             )
+         if max_length is None:
+             if hasattr(generation_config, "max_length"):
+                 max_length = generation_config.max_length
+             else:
+                 max_length = self.max_length
+         if max_new_tokens is None:
+             if hasattr(generation_config, "max_new_tokens"):
+                 max_new_tokens = generation_config.max_new_tokens
+             else:
+                 max_new_tokens = max_length - inputs.shape[-1]
+         batch_size = batch_size if batch_size is not None else inputs.shape[0]
+         assert batch_size == 1, "Batched sampling not supported yet"
+         if device is None:
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+         block_size = generation_config.block_size
+         max_blocks = max_new_tokens // block_size
+
+         # Sample max generation length tensor from prior
+         accumulated_samples = self._sample_prior(
+             device=device,
+             batch_size=batch_size,
+             length=max_blocks * block_size,
+         )
+         accumulated_samples = torch.cat([inputs, accumulated_samples], dim=-1)
+         if generation_config.use_cache and inputs.numel() > 0:
+             cache = self.update_cache(
+                 inputs=inputs[:, : block_size * (inputs.shape[-1] // block_size)]
+                 if generation_config.align_inputs_to_blocks
+                 else inputs,
+                 cache={},
+             )
+         else:
+             cache = None
+
+         if generation_config.align_inputs_to_blocks:
+             inputs_offset = (
+                 block_size * (inputs.shape[-1] // block_size)
+                 if inputs.numel() > 0
+                 else 0
+             )
+         else:
+             inputs_offset = inputs.shape[-1] if inputs.numel() > 0 else 0
+
+         total_NFEs = 0
+         timesteps = self._sample_generation_timesteps(  # Re-use in every block
+             generation_config, max_length=block_size, device=device
+         )
+         dt = (1 - generation_config.min_t) / len(timesteps)
+         block_pbar = tqdm(
+             range(max_blocks),
+             desc="Blocks",
+             leave=True,
+             disable=disable_pbar,
+         )
+         for block_id in block_pbar:
+             block_NFEs = 0
+             xt = accumulated_samples[
+                 :,
+                 inputs_offset + (block_id * block_size) : inputs_offset
+                 + ((block_id + 1) * block_size),
+             ]
+             if self.mask_token_id not in xt:
+                 continue
+             step_pbar = tqdm(
+                 timesteps,
+                 desc="T",
+                 total=timesteps.shape[0],
+                 leave=False,
+                 disable=disable_pbar,
+             )
+             model_output_cache = None
+             context = (
+                 accumulated_samples[:, : (block_id * block_size) + inputs_offset]
+                 if not generation_config.use_cache
+                 else None
+             )
+             # Used for logit processing
+             running_generation = accumulated_samples[
+                 :,
+                 inputs_offset : inputs_offset + (block_id * block_size),
+             ]
+             for t in step_pbar:
+                 if model_output_cache is None:
+                     block_NFEs += 1
+                     total_NFEs += 1
+                 # t is 0-dim tensor, reshape to (1, 1, 1) for broadcasting
+                 alpha_t, _ = self.noise_schedule(t)
+                 alpha_s, _ = self.noise_schedule(t - dt)
+                 alpha_t = alpha_t[None, None, None]
+                 alpha_s = alpha_s[None, None, None]
+                 denoiser_inputs, cache = self._prepare_inputs_inference(
+                     input_ids=xt,
+                     context=context,
+                     cache=cache if generation_config.use_cache else None,
+                 )
+                 xs, model_output_cache, cache = self._generate_unconditional(
+                     generation_config=generation_config,
+                     alpha_t=alpha_t,
+                     alpha_s=alpha_s,
+                     denoiser_inputs=denoiser_inputs,
+                     model_output_cache=model_output_cache,
+                     cache=cache,
+                     running_generation=running_generation,  # type: ignore
+                     logits_processor=logits_processor,
+                     tokenizer=tokenizer,
+                     **kwargs,
+                 )
+                 block_pbar.set_postfix(
+                     NFEs=total_NFEs,
+                     block_NFEs=block_NFEs,
+                 )
+
+                 if (
+                     not torch.allclose(xs, denoiser_inputs.xt)
+                     or not generation_config.use_model_output_cache
+                 ):
+                     model_output_cache = None
+                 if not generation_config.use_cache:
+                     xt[..., -block_size:] = xs[..., -block_size:]
+                 else:
+                     xt = xs
+                 if (
+                     xt == self.mask_token_id
+                 ).sum().item() == 0 and self.config.diffusion_type == "absorbing":
+                     break
+             accumulated_samples[
+                 :,
+                 inputs_offset + (block_id * block_size) : inputs_offset
+                 + ((block_id + 1) * block_size),
+             ] = xt
+             if tokenizer is not None:  # Useful for debugging
+                 print(tokenizer.batch_decode(accumulated_samples))
+             if stopping_criteria is not None:
+                 is_done = stopping_criteria(
+                     input_ids=accumulated_samples[  # type: ignore
+                         :,
+                         inputs_offset : inputs_offset + ((block_id + 1) * block_size),
+                     ],
+                     scores=None,  # type: ignore
+                 )
+                 if torch.any(is_done):
+                     accumulated_samples = accumulated_samples[
+                         :,
+                         : inputs_offset + ((block_id + 1) * block_size),
+                     ]
+                     break
+             if generation_config.use_cache:
+                 cache = self.update_cache(
+                     inputs=xt,
+                     cache=cache,
+                 )
+         return accumulated_samples  # type: ignore
+
+
+ class MDLMConfig(D3PMConfig):
+     """Configuration class for MDLM models."""
+
+     model_type = "mdlm"
+     auto_map = {
+         "AutoConfig": "diffusion.MDLMConfig",
+         "AutoModel": "diffusion.MDLM",
+         "AutoModelForMaskedLM": "diffusion.MDLM",
+     }
+
+
+ class MDLM(D3PM):
+     """Denoiser class for MDLM models."""
+
+     config_class = MDLMConfig
+
+     def __init__(self, config: MDLMConfig, **kwargs):
+         super().__init__(config, **kwargs)
+         self.neg_infinity = -1e12
+
+     def _forward(
+         self,
+         backbone_output: torch.FloatTensor,
+         denoiser_inputs: DenoiserInput,
+         **kwargs,
+     ) -> torch.FloatTensor:
+         # Zero-mask probability
+         backbone_output[..., self.mask_token_id] = self.neg_infinity
+         log_probs = backbone_output - torch.logsumexp(
+             backbone_output, dim=-1, keepdim=True
+         )
+         # Copy-over unmasked: For the log_probs of the unmasked tokens, set all values
+         # to -infinity except for the indices corresponding to
+         # the unmasked tokens.
+         xt = denoiser_inputs.xt
+         unmasked_indices = xt != self.mask_token_id
+         log_probs[unmasked_indices] = self.neg_infinity
+         log_probs[unmasked_indices, xt[unmasked_indices]] = 0
+         return log_probs  # type: ignore
+
+     def _compute_loss(
+         self,
+         model_output: torch.FloatTensor,
+         denoiser_inputs: DenoiserInput,
+         **kwargs: Any,
+     ) -> LossAndNllOutput:
+         log_p_theta = torch.gather(
+             input=model_output, dim=-1, index=denoiser_inputs.x0[:, :, None]
+         ).squeeze(-1)
+         nlls = (
+             log_p_theta
+             * denoiser_inputs.alpha_t_prime
+             / (1 - denoiser_inputs.alpha_t)
+             * denoiser_inputs.tokens_mask
+         )
+         if self.training:
+             batch_nll = -(log_p_theta * denoiser_inputs.tokens_mask).sum(dim=-1)
+         else:
+             batch_nll = nlls.sum(dim=-1)
+         count = denoiser_inputs.tokens_mask.sum(dim=-1)
+         token_nll = (batch_nll / count).mean()
+         return LossAndNllOutput(
+             loss=token_nll,  # type: ignore
+             nlls=nlls,
+             other_loss_terms={
+                 "masked_tokens": (denoiser_inputs.xt == self.mask_token_id).int()
+             },
+         )
+
+
+ class BD3LMConfig(MDLMConfig):
+     """Configuration class for BD3LM models."""
+
+     model_type = "bd3lm"
+     auto_map = {
+         "AutoConfig": "diffusion.BD3LMConfig",
+         "AutoModel": "diffusion.BD3LM",
+         "AutoModelForMaskedLM": "diffusion.BD3LM",
+     }
+
+     def __init__(
+         self,
+         block_size: Optional[int] = None,
+         eval_block_size: Optional[int] = None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.block_size = block_size
+         self.eval_block_size = (
+             eval_block_size if eval_block_size is not None else block_size
+         )
+
+
+ class BD3LM(MDLM):
+     """Denoiser class for BD3LM models."""
+
+     config_class = BD3LMConfig
+
+     def __init__(self, config: BD3LMConfig, **kwargs):
+         super().__init__(config, **kwargs)
+
+     # noinspection PyUnusedLocal
+     @staticmethod
+     def _block_mask(
+         b,
+         h,
+         q_idx,
+         kv_idx,
+         block_size: Optional[int] = None,
+         seq_length: Optional[int] = None,
+     ) -> torch.Tensor:
+         del b, h
+
+         # Indicate whether token belongs to xt or x0:
+         xt_flag_q = (q_idx >= seq_length).bool()
+         xt_flag_kv = (kv_idx >= seq_length).bool()
+
+         # Compute block indices
+         block_q = torch.where(
+             xt_flag_q, (q_idx - seq_length) // block_size, q_idx // block_size
+         )
+         block_kv = torch.where(
+             xt_flag_kv, (kv_idx - seq_length) // block_size, kv_idx // block_size
+         )
+         # **1. Offset Block-Causal Mask (M_OBC) **
+         offset_block_causal = (block_q > block_kv) & ~xt_flag_kv & xt_flag_q
+
+         # **2. Block Diagonal Mask (M_BD) **
+         block_diagonal = (block_q == block_kv) & (xt_flag_q == xt_flag_kv)
+
+         # **3. Block-Causal Mask (M_BC) **
+         block_causal = (block_q >= block_kv) & ~xt_flag_kv & ~xt_flag_q
+
+         # **4. Combine Masks **
+         return block_diagonal | offset_block_causal | block_causal
+
+     def _create_static_mask(self) -> None:
+         if self.config.attn_backend == "sdpa":
+             static_mask = self._block_mask(
+                 b=None,
+                 h=None,
+                 q_idx=torch.arange(self.config.length * 2)[:, None],
+                 kv_idx=torch.arange(self.config.length * 2)[None, :],
+                 block_size=self.config.block_size
+                 if self.training
+                 else self.config.eval_block_size,
+                 seq_length=self.config.length,
+             )
+             self.register_buffer(
+                 "static_attention_mask",
+                 static_mask,
+             )
+         elif self.config.attn_backend == "flex_attention":
+             mask = partial(
+                 self._block_mask,
+                 block_size=self.config.block_size
+                 if self.training
+                 else self.config.eval_block_size,
+                 seq_length=self.config.length,
+             )
+             self.static_attention_mask = create_block_mask(
+                 mask,
+                 B=None,
+                 H=None,
+                 Q_LEN=self.config.length * 2,
+                 KV_LEN=self.config.length * 2,
+             )
+
+     def _ensure_no_unmasked_blocks(
+         self,
+         input_ids: torch.LongTensor,
+         xt: torch.LongTensor,
+         context_mask: Optional[torch.FloatTensor] = None,
+     ) -> torch.Tensor:
+         n_blocks = xt.shape[1] // self.config.block_size
+         # If context overlaps w/block, ignore it
+         blocks_without_masks = ((xt == self.mask_token_id) + context_mask).reshape(
+             -1, n_blocks, self.config.block_size
+         ).sum(dim=-1) == 0
+         if blocks_without_masks.sum() > 0:
+             num_remasks_per_block = torch.randint(
+                 0,
+                 self.config.block_size,
+                 blocks_without_masks.shape,
+                 device=xt.device,
+             )
+             rand = torch.rand(xt.shape[0], xt.shape[1], device=xt.device)
+             perm_indices = torch.argsort(
+                 rand.view(xt.shape[0], n_blocks, self.config.block_size),
+                 stable=True,
+                 dim=-1,
+             )
+             remask_indices = perm_indices <= num_remasks_per_block[..., None]
+             xt = torch.where(
+                 remask_indices.view(xt.shape[0], xt.shape[1])
+                 * blocks_without_masks.repeat_interleave(self.config.block_size, dim=1),
+                 self.mask_token_id,
+                 xt,
+             )
+         if self.config.keep_clean_bos:
+             xt[..., 0] = input_ids[..., 0]
+         return xt
+
+     def _prepare_inputs(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         context_mask: Optional[torch.FloatTensor] = None,
+         t: Optional[torch.FloatTensor] = None,
+         past_key_values: Optional[Cache] = None,
+     ):
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids)
+         if context_mask is None:
+             context_mask = torch.zeros_like(attention_mask)
+
+         if torch.is_floating_point(attention_mask):
+             attention_mask = attention_mask.to(torch.int)
+             context_mask = context_mask.to(torch.int)
+
+         if t is None:
+             # Sample one noise level per block (seq_length // block_size blocks)
+             t = torch.rand(
+                 input_ids.shape[0],
+                 input_ids.shape[1]
+                 // (
+                     self.config.block_size
+                     if self.training
+                     else self.config.eval_block_size
+                 ),
+                 device=input_ids.device,
+             ).repeat_interleave(
+                 self.config.block_size
+                 if self.training
+                 else self.config.eval_block_size,
+                 dim=-1,
+             )
+         alpha_t, alpha_t_prime = self.noise_schedule(t)
+         while alpha_t.ndim < 2:
+             alpha_t = alpha_t[..., None]
+             alpha_t_prime = alpha_t_prime[..., None]
+         xt = self._sample_q_xt(x0=input_ids, alpha_t=alpha_t, context_mask=context_mask)
+         # Ensure each block has at least 1 masked token
+         if self.training:
+             xt = self._ensure_no_unmasked_blocks(
+                 input_ids,
+                 xt,
+                 context_mask,
+             )
+         if self.config.attn_backend == "sdpa":
+             decoder_attention_mask = (
+                 self.static_attention_mask[None, ...]
+                 & attention_mask.repeat(1, 2)[:, None, :]
+                 & attention_mask.repeat(1, 2)[..., None]
+             )[:, None, ...]  # Make attention mask 4D
+             decoder_attention_mask = self._preprocess_attention_mask(
+                 decoder_attention_mask, dtype=torch.float
+             )
+         elif self.config.attn_backend == "flex_attention":
+             if context_mask.any():
+                 raise NotImplementedError(
+                     "flex_attention with context_mask not implemented yet."
+                 )
+             elif attention_mask is not None and (attention_mask != 1).any():
+                 # Duplicate along the sequence dim: keys/values cover (x0, xt)
+                 padding_mask = create_attn_mask(
+                     attention_mask.repeat(1, 2).bool()
+                 )
+                 dec_masks = [
+                     partial(
+                         self._block_mask,
+                         block_size=self.config.block_size
+                         if self.training
+                         else self.config.eval_block_size,
+                         seq_length=self.config.length,
+                     ),
+                     padding_mask,
+                 ]
+                 decoder_attention_mask = create_block_mask(
+                     and_masks(*dec_masks),
+                     B=input_ids.shape[0],
+                     H=None,
+                     Q_LEN=input_ids.shape[1] * 2,
+                     KV_LEN=input_ids.shape[1] * 2,
+                 )
+             else:
+                 decoder_attention_mask = self.static_attention_mask
+         else:
+             raise ValueError("Unknown backbone backend")
+         backbone_input_ids = torch.cat((input_ids, xt), dim=-1)
+         position_ids = (
+             torch.arange(input_ids.shape[1]).repeat(2).to(input_ids.device)[None, :]
+         )
+         if self.training and self.config.train_on_context:
+             tokens_mask = attention_mask
+         else:
+             tokens_mask = attention_mask * (1 - context_mask)
+         return DenoiserInput(
+             xt=backbone_input_ids,  # type: ignore
+             x0=input_ids,
+             attention_mask=decoder_attention_mask,  # type: ignore
+             tokens_mask=tokens_mask,
+             t=t,
+             alpha_t=alpha_t,
+             alpha_t_prime=alpha_t_prime,
+             backbone_kwargs={
+                 "cache_position": position_ids[0],
+                 "position_ids": position_ids,
+             },
+         )
+
+     def _prepare_inputs_inference(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         context: Optional[torch.LongTensor] = None,
+         context_mask: Optional[torch.FloatTensor] = None,
+         cache: Optional[Dict[str, Any]] = None,
+         return_updated_cache: bool = False,
+         **backbone_kwargs: Dict[str, Any],
+     ) -> Tuple[DenoiserInput, Union[Dict[str, Any], None]]:
+         device = input_ids.device if input_ids is not None else context.device
+         assert input_ids is not None or context is not None, (
+             "Must provide either input_ids or context."
+         )
+         cache = cache if cache is not None else {}
+         past_key_values = cache.pop("past_key_values", DynamicCache())
+         if context is not None:
+             if input_ids is not None:
+                 input_ids = torch.cat([context, input_ids], dim=-1)
+             else:
+                 input_ids = context
+         cache_length = self._get_past_key_values_seq_length(past_key_values)
+         full_seq_length = cache_length + input_ids.shape[-1]
+         decoder_attention_mask = self.static_attention_mask[
+             None,
+             None,
+             cache_length:full_seq_length,
+             :full_seq_length,
+         ]  # Make attention mask 4D
+         decoder_attention_mask = self._preprocess_attention_mask(
+             decoder_attention_mask, dtype=torch.float
+         )
+         position_ids = torch.arange(cache_length, full_seq_length).to(device)[None, :]
+         return DenoiserInput(
+             xt=input_ids,
+             attention_mask=decoder_attention_mask,
+             context_mask=context_mask,
+             past_key_values=past_key_values,
+             backbone_kwargs={
+                 "position_ids": position_ids,
+             }
+             | backbone_kwargs,
+         ), cache
+
+     def _compute_loss(
+         self,
+         model_output: torch.FloatTensor,
+         denoiser_inputs: DenoiserInput,
+         **kwargs: Any,
+     ) -> LossAndNllOutput:
+         if self.config.backbone_is_decoder_only:
+             input_length = denoiser_inputs.xt.shape[1] // 2
+             model_output = model_output[:, input_length:, ...]
+         return super()._compute_loss(
+             model_output=model_output,
+             denoiser_inputs=denoiser_inputs,
+             **kwargs,
+         )
+
+
+ class E2D2Config(BD3LMConfig):
+     """Configuration class for E2D2 models."""
+
+     model_type = "e2d2"
+     auto_map = {
+         "AutoConfig": "diffusion.E2D2Config",
+         "AutoModel": "diffusion.E2D2",
+         "AutoModelForMaskedLM": "diffusion.E2D2",
+     }
+
+     def __init__(
+         self,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+
+ class E2D2(BD3LM):
+     """Denoiser class for E2D2 models."""
+
+     config_class = E2D2Config
+
+     def __init__(self, config: E2D2Config, **kwargs):
+         super().__init__(config, **kwargs)
+
+     # noinspection PyUnusedLocal
+     @staticmethod
+     def _encoder_block_mask(
+         b,
+         h,
+         q_idx,
+         kv_idx,
+         block_size: Optional[int] = None,
+     ) -> torch.Tensor:
+         """
+         Args:
+             q_idx (Tensor): Query indices.
+             kv_idx (Tensor): Key indices.
+             b (Optional: int): Batch size.
+             h (Optional: int): Number of heads.
+             block_size (Optional: int): Defines the block structure.
+
+         Returns:
+             Encoder block-causal attention mask.
+         """
+
+         # Compute block indices
+         block_q = q_idx // block_size
+         block_kv = kv_idx // block_size
+
+         # ** Block-Causal Mask **
+         return block_q >= block_kv
+
+     # noinspection PyUnusedLocal
+     @staticmethod
+     def _decoder_block_mask(
+         b,
+         h,
+         q_idx,
+         kv_idx,
+         block_size: Optional[int] = None,
+         seq_length: Optional[int] = None,
+     ) -> torch.Tensor:
+         # Indicate whether token belongs to xt or x0:
+         xt_flag_kv = (kv_idx >= seq_length).bool()
+
+         # Compute block indices
+         block_q = q_idx // block_size
+         block_kv = torch.where(
+             xt_flag_kv, (kv_idx - seq_length) // block_size, kv_idx // block_size
+         )
+         # **1. Offset Block-Causal Mask (M_OBC) **
+         offset_block_causal = (block_q > block_kv) & ~xt_flag_kv
+
+         # **2. Block Diagonal Mask (M_BD) **
+         block_diagonal = (block_q == block_kv) & xt_flag_kv
+
+         # **3. Combine Masks **
+         return block_diagonal | offset_block_causal
+
+     def _create_static_mask(self) -> None:
+         if self.config.attn_backend == "flex_attention":
+             enc_mask = partial(
+                 self._encoder_block_mask,
+                 block_size=self.config.block_size
+                 if self.training
+                 else self.config.eval_block_size,
+             )
+             encoder_attention_mask = create_block_mask(
+                 enc_mask,
+                 B=None,
+                 H=None,
+                 Q_LEN=self.config.length,
+                 KV_LEN=self.config.length,
+             )
+             dec_mask = partial(
+                 self._decoder_block_mask,
+                 block_size=self.config.block_size
+                 if self.training
+                 else self.config.eval_block_size,
+                 seq_length=self.config.length,
+             )
+             decoder_attention_mask = create_block_mask(
+                 dec_mask,
+                 B=None,
+                 H=None,
+                 Q_LEN=self.config.length,
+                 KV_LEN=self.config.length * 2,
+             )
+             self.encoder_static_attention_mask = encoder_attention_mask
+             self.static_attention_mask = decoder_attention_mask
+         else:
+             encoder_static_mask = self._encoder_block_mask(
+                 b=None,  # type: ignore
+                 h=None,  # type: ignore
+                 q_idx=torch.arange(self.config.length)[:, None],
+                 kv_idx=torch.arange(self.config.length)[None, :],
+                 block_size=self.config.block_size
+                 if self.training
+                 else self.config.eval_block_size,
+             )
+             decoder_static_mask = self._decoder_block_mask(
+                 b=None,
+                 h=None,
+                 q_idx=torch.arange(self.config.length)[:, None],
+                 kv_idx=torch.arange(self.config.length * 2)[None, :],
+                 block_size=self.config.block_size
+                 if self.training
+                 else self.config.eval_block_size,
+                 seq_length=self.config.length,
+             )
+             self.register_buffer(
+                 "encoder_static_attention_mask",
+                 encoder_static_mask,
+             )
+             self.register_buffer(
+                 "static_attention_mask",
+                 decoder_static_mask,
+             )
+
+     def _prepare_inputs(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         context_mask: Optional[torch.FloatTensor] = None,
+         t: Optional[torch.FloatTensor] = None,
+         past_key_values: Optional[Cache] = None,
+     ):
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids)
+         if context_mask is None:
+             context_mask = torch.zeros_like(attention_mask)
+
+         if torch.is_floating_point(attention_mask):
+             attention_mask = attention_mask.to(torch.int)
+             context_mask = context_mask.to(torch.int)
+
+         if t is None:
1226
+ input_ids.shape[0],
1227
+ input_ids.shape[1] // self.config.block_size
1228
+ if self.training
1229
+ else self.config.eval_block_size,
1230
+ device=input_ids.device,
1231
+ ).repeat_interleave(
1232
+ self.config.block_size
1233
+ if self.training
1234
+ else self.config.eval_block_size,
1235
+ dim=-1,
1236
+ )
1237
+ alpha_t, alpha_t_prime = self.noise_schedule(t)
1238
+ while alpha_t.ndim < 2:
1239
+ alpha_t = alpha_t[..., None]
1240
+ alpha_t_prime = alpha_t_prime[..., None]
1241
+ xt = self._sample_q_xt(x0=input_ids, alpha_t=alpha_t, context_mask=context_mask)
1242
+ # Ensure each block has at least 1 masked token
1243
+ if self.training:
1244
+ xt = self._ensure_no_unmasked_blocks(
1245
+ input_ids,
1246
+ xt,
1247
+ context_mask,
1248
+ )
1249
+ if self.config.attn_backend == "sdpa":
1250
+ decoder_attention_mask = (
1251
+ self.static_attention_mask[None, ...]
1252
+ & attention_mask.repeat(1, 2)[:, None, :]
1253
+ & attention_mask[..., None]
1254
+ )[:, None, ...] # Make attention mask 4D
1255
+ encoder_attention_mask = (
1256
+ (
1257
+ self.encoder_static_attention_mask[None, ...]
1258
+ | context_mask[:, None, :]
1259
+ )
1260
+ & attention_mask[:, None, :]
1261
+ & attention_mask[..., None]
1262
+ )[:, None, ...] # Make attention mask 4D
1263
+ encoder_attention_mask = self._preprocess_attention_mask(
1264
+ encoder_attention_mask, dtype=torch.float
1265
+ )
1266
+ decoder_attention_mask = self._preprocess_attention_mask(
1267
+ decoder_attention_mask, dtype=torch.float
1268
+ )
1269
+ elif self.config.attn_backend == "flex_attention":
1270
+ # TODO enable bidirectional attention on context for seq2seq tasks
1271
+ if context_mask.any():
1272
+ raise NotImplementedError(
1273
+ "flex_attention with context_mask not implemented yet."
1274
+ )
1275
+ elif attention_mask is not None and (attention_mask != 1).any():
1276
+ padding_mask = create_attn_mask(attention_mask.bool())
1277
+ dec_padding_mask = create_attn_mask(attention_mask.repeat(1, 2).bool())
1278
+ enc_masks = [
1279
+ partial(
1280
+ self._encoder_block_mask,
1281
+ block_size=self.config.block_size
1282
+ if self.training
1283
+ else self.config.eval_block_size,
1284
+ ),
1285
+ padding_mask,
1286
+ ]
1287
+ encoder_attention_mask = create_block_mask(
1288
+ and_masks(*enc_masks),
1289
+ B=input_ids.shape[0],
1290
+ H=None,
1291
+ Q_LEN=input_ids.shape[1],
1292
+ KV_LEN=input_ids.shape[1],
1293
+ )
1294
+ dec_masks = [
1295
+ partial(
1296
+ self._decoder_block_mask,
1297
+ block_size=self.config.block_size
1298
+ if self.training
1299
+ else self.config.eval_block_size,
1300
+ seq_length=input_ids.shape[1],
1301
+ ),
1302
+ dec_padding_mask,
1303
+ ]
1304
+ decoder_attention_mask = create_block_mask(
1305
+ and_masks(*dec_masks),
1306
+ B=input_ids.shape[0],
1307
+ H=None,
1308
+ Q_LEN=input_ids.shape[1],
1309
+ KV_LEN=input_ids.shape[1] * 2,
1310
+ )
1311
+ else:
1312
+ encoder_attention_mask = self.encoder_static_attention_mask
1313
+ decoder_attention_mask = self.static_attention_mask
1314
+ else:
1315
+ raise ValueError("Unknown backbone backend")
1316
+ position_ids = torch.arange(input_ids.shape[1]).to(input_ids.device)[None, :]
1317
+ if self.training and self.config.train_on_context:
1318
+ tokens_mask = attention_mask
1319
+ else:
1320
+ tokens_mask = attention_mask * (1 - context_mask)
1321
+ return DenoiserInput(
1322
+ xt=xt,
1323
+ x0=input_ids,
1324
+ attention_mask=decoder_attention_mask,
1325
+ tokens_mask=tokens_mask,
1326
+ t=t,
1327
+ alpha_t=alpha_t,
1328
+ alpha_t_prime=alpha_t_prime,
1329
+ backbone_kwargs={
1330
+ "encoder_input_ids": input_ids,
1331
+ "encoder_attention_mask": encoder_attention_mask,
1332
+ "encoder_position_ids": position_ids,
1333
+ "encoder_cache_position": position_ids[0],
1334
+ },
1335
+ )
1336
+
1337
+ def _prepare_inputs_inference(
1338
+ self,
1339
+ input_ids: Optional[torch.LongTensor] = None,
1340
+ attention_mask: Optional[torch.FloatTensor] = None,
1341
+ context: Optional[torch.LongTensor] = None,
1342
+ context_mask: Optional[torch.FloatTensor] = None,
1343
+ cache: Optional[Dict[str, Any]] = None,
1344
+ return_updated_cache: bool = False,
1345
+ **backbone_kwargs: Dict[str, Any],
1346
+ ) -> Tuple[DenoiserInput, Union[Dict[str, Any], None]]:
1347
+ device = input_ids.device if input_ids is not None else context.device
1348
+ batch_size = input_ids.shape[0] if input_ids is not None else context.shape[0]
1349
+ assert input_ids is not None or context is not None, (
1350
+ "Must provide either input_ids or context."
1351
+ )
1352
+ if return_updated_cache: # Indicates this is a cache update step
1353
+ context = input_ids
1354
+ input_ids = None
1355
+ position_ids, encoder_position_ids = None, None
1356
+ if cache is not None:
1357
+ past_key_values = cache.pop("past_key_values", DynamicCache())
1358
+ encoder_past_key_values = cache.pop(
1359
+ "encoder_past_key_values", DynamicCache()
1360
+ )
1361
+ encoder_last_hidden_state = cache.pop("encoder_last_hidden_state", None)
1362
+ if input_ids is not None: # Skip enc: nothing new to cache
1363
+ cache_length = self._get_past_key_values_seq_length(past_key_values)
1364
+ if encoder_last_hidden_state is not None:
1365
+ full_seq_length = (
1366
+ cache_length
1367
+ + encoder_last_hidden_state.shape[1] # type: ignore
1368
+ + input_ids.shape[-1]
1369
+ )
1370
+ else:
1371
+ full_seq_length = cache_length + input_ids.shape[-1]
1372
+ encoder_attention_mask = None
1373
+ position_ids = torch.arange(
1374
+ cache_length, full_seq_length, device=device
1375
+ )[None, :]
1376
+ else: # Caching new tokens in the enc
1377
+ encoder_cache_length = self._get_past_key_values_seq_length(
1378
+ encoder_past_key_values
1379
+ if len(encoder_past_key_values) > 0
1380
+ else past_key_values
1381
+ )
1382
+ encoder_full_seq_length = encoder_cache_length + context.shape[-1]
1383
+ encoder_attention_mask = torch.ones(
1384
+ (
1385
+ 1,
1386
+ 1,
1387
+ encoder_full_seq_length - encoder_cache_length,
1388
+ encoder_full_seq_length,
1389
+ ),
1390
+ device=context.device,
1391
+ )
1392
+ encoder_position_ids = torch.arange(
1393
+ encoder_cache_length, encoder_full_seq_length
1394
+ ).to(device)[None, :]
1395
+ encoder_attention_mask = self._preprocess_attention_mask(
1396
+ encoder_attention_mask, dtype=torch.float
1397
+ )
1398
+ full_seq_length = -1 # Not used
1399
+ else: # Not using kv-cache
1400
+ past_key_values = None
1401
+ encoder_past_key_values, encoder_last_hidden_state = None, None
1402
+ if context is not None:
1403
+ context_len = context.shape[1]
1404
+ encoder_attention_mask = torch.ones(
1405
+ (1, 1, context_len, context_len), device=context.device
1406
+ )
1407
+ encoder_attention_mask = self._preprocess_attention_mask(
1408
+ encoder_attention_mask, dtype=torch.float
1409
+ )
1410
+ encoder_position_ids = torch.arange(context_len).to(device)[None, :]
1411
+ else:
1412
+ context_len = 0
1413
+ encoder_attention_mask = None
1414
+ if input_ids is not None:
1415
+ full_seq_length = context_len + input_ids.shape[1]
1416
+ else:
1417
+ full_seq_length = context_len
1418
+ position_ids = torch.arange(context_len, full_seq_length).to(device)[
1419
+ None, :
1420
+ ]
1421
+ if input_ids is not None:
1422
+ decoder_attention_mask = torch.ones(
1423
+ (batch_size, 1, input_ids.shape[1], full_seq_length),
1424
+ device=device,
1425
+ ) # Make attention mask 4D
1426
+ decoder_attention_mask = self._preprocess_attention_mask(
1427
+ decoder_attention_mask, dtype=torch.float
1428
+ )
1429
+ else:
1430
+ decoder_attention_mask = None
1431
+ return DenoiserInput(
1432
+ xt=input_ids,
1433
+ attention_mask=decoder_attention_mask,
1434
+ context_mask=context_mask,
1435
+ past_key_values=past_key_values,
1436
+ backbone_kwargs={
1437
+ "position_ids": position_ids,
1438
+ "encoder_input_ids": context,
1439
+ "encoder_position_ids": encoder_position_ids,
1440
+ "encoder_attention_mask": encoder_attention_mask,
1441
+ "encoder_past_key_values": encoder_past_key_values,
1442
+ "encoder_last_hidden_state": encoder_last_hidden_state,
1443
+ }
1444
+ | backbone_kwargs,
1445
+ ), cache # TODO: potentially returning cache None, violates return type
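
The semi-autoregressive sampler above decodes `max_new_tokens // block_size` blocks, running up to `num_steps` denoising steps per block. A minimal usage sketch, assuming `model` and `tokenizer` were loaded as in the note after config.json and that `DiffusionGenerationConfig` is imported from this diffusion.py; the argument values are illustrative, not taken from this commit:

    from diffusion import DiffusionGenerationConfig  # the file in this commit

    prompt_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids

    gen_config = DiffusionGenerationConfig(
        max_new_tokens=64,           # illustrative value
        block_size=4,                # matches "block_size" in config.json
        num_steps=128,               # denoising steps per block
        sampling_strategy="posterior",
        do_sample=True,
    )
    # generate() is already wrapped in @torch.no_grad(); batch size must be 1
    # (asserted inside generate).
    output_ids = model.generate(inputs=prompt_ids, generation_config=gen_config)
    print(tokenizer.batch_decode(output_ids))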
pytorch_model-00001-of-00002.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e2ec1e125c60bdd45b3cfabf5607e7f78309167237b87ae933a979356f7025e9
+ size 4971388425
pytorch_model-00002-of-00002.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a42802ca08b0866382f6cde873eaeccb3dc59555aca81b107c18290a04f96fd2
+ size 1912835445
pytorch_model.bin.index.json ADDED
@@ -0,0 +1,631 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 6884069376
4
+ },
5
+ "weight_map": {
6
+ "backbone.decoder.lm_head.weight": "pytorch_model-00001-of-00002.bin",
7
+ "backbone.decoder.model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
8
+ "backbone.decoder.model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
9
+ "backbone.decoder.model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
10
+ "backbone.decoder.model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
11
+ "backbone.decoder.model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
12
+ "backbone.decoder.model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
13
+ "backbone.decoder.model.layers.0.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
14
+ "backbone.decoder.model.layers.0.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
15
+ "backbone.decoder.model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
16
+ "backbone.decoder.model.layers.0.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
17
+ "backbone.decoder.model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
18
+ "backbone.decoder.model.layers.0.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
19
+ "backbone.decoder.model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
20
+ "backbone.decoder.model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
21
+ "backbone.decoder.model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
22
+ "backbone.decoder.model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
23
+ "backbone.decoder.model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
24
+ "backbone.decoder.model.layers.1.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
25
+ "backbone.decoder.model.layers.1.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
26
+ "backbone.decoder.model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
27
+ "backbone.decoder.model.layers.1.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
28
+ "backbone.decoder.model.layers.1.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
29
+ "backbone.decoder.model.layers.1.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
30
+ "backbone.decoder.model.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
31
+ "backbone.decoder.model.layers.10.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
32
+ "backbone.decoder.model.layers.10.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
33
+ "backbone.decoder.model.layers.10.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
34
+ "backbone.decoder.model.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
35
+ "backbone.decoder.model.layers.10.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
36
+ "backbone.decoder.model.layers.10.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
37
+ "backbone.decoder.model.layers.10.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
38
+ "backbone.decoder.model.layers.10.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
39
+ "backbone.decoder.model.layers.10.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
40
+ "backbone.decoder.model.layers.10.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
41
+ "backbone.decoder.model.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
42
+ "backbone.decoder.model.layers.11.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
43
+ "backbone.decoder.model.layers.11.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
44
+ "backbone.decoder.model.layers.11.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
45
+ "backbone.decoder.model.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
46
+ "backbone.decoder.model.layers.11.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
47
+ "backbone.decoder.model.layers.11.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
48
+ "backbone.decoder.model.layers.11.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
49
+ "backbone.decoder.model.layers.11.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
50
+ "backbone.decoder.model.layers.11.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
51
+ "backbone.decoder.model.layers.11.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
52
+ "backbone.decoder.model.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
53
+ "backbone.decoder.model.layers.12.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
54
+ "backbone.decoder.model.layers.12.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
55
+ "backbone.decoder.model.layers.12.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
56
+ "backbone.decoder.model.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
57
+ "backbone.decoder.model.layers.12.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
58
+ "backbone.decoder.model.layers.12.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
59
+ "backbone.decoder.model.layers.12.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
60
+ "backbone.decoder.model.layers.12.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
61
+ "backbone.decoder.model.layers.12.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
62
+ "backbone.decoder.model.layers.12.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
63
+ "backbone.decoder.model.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
64
+ "backbone.decoder.model.layers.13.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
65
+ "backbone.decoder.model.layers.13.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
66
+ "backbone.decoder.model.layers.13.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
67
+ "backbone.decoder.model.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
68
+ "backbone.decoder.model.layers.13.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
69
+ "backbone.decoder.model.layers.13.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
70
+ "backbone.decoder.model.layers.13.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
71
+ "backbone.decoder.model.layers.13.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
72
+ "backbone.decoder.model.layers.13.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
73
+ "backbone.decoder.model.layers.13.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
74
+ "backbone.decoder.model.layers.14.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
75
+ "backbone.decoder.model.layers.14.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
76
+ "backbone.decoder.model.layers.14.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
77
+ "backbone.decoder.model.layers.14.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
78
+ "backbone.decoder.model.layers.14.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
79
+ "backbone.decoder.model.layers.14.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
80
+ "backbone.decoder.model.layers.14.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
81
+ "backbone.decoder.model.layers.14.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
82
+ "backbone.decoder.model.layers.14.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
83
+ "backbone.decoder.model.layers.14.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
84
+ "backbone.decoder.model.layers.14.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
85
+ "backbone.decoder.model.layers.15.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
86
+ "backbone.decoder.model.layers.15.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
87
+ "backbone.decoder.model.layers.15.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
88
+ "backbone.decoder.model.layers.15.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
89
+ "backbone.decoder.model.layers.15.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
90
+ "backbone.decoder.model.layers.15.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
91
+ "backbone.decoder.model.layers.15.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
92
+ "backbone.decoder.model.layers.15.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
93
+ "backbone.decoder.model.layers.15.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
94
+ "backbone.decoder.model.layers.15.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
95
+ "backbone.decoder.model.layers.15.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
96
+ "backbone.decoder.model.layers.16.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
97
+ "backbone.decoder.model.layers.16.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
98
+ "backbone.decoder.model.layers.16.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
99
+ "backbone.decoder.model.layers.16.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
100
+ "backbone.decoder.model.layers.16.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
101
+ "backbone.decoder.model.layers.16.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
102
+ "backbone.decoder.model.layers.16.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
103
+ "backbone.decoder.model.layers.16.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
104
+ "backbone.decoder.model.layers.16.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
105
+ "backbone.decoder.model.layers.16.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
106
+ "backbone.decoder.model.layers.16.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
107
+ "backbone.decoder.model.layers.17.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
108
+ "backbone.decoder.model.layers.17.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
109
+ "backbone.decoder.model.layers.17.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
110
+ "backbone.decoder.model.layers.17.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
111
+ "backbone.decoder.model.layers.17.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
112
+ "backbone.decoder.model.layers.17.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
113
+ "backbone.decoder.model.layers.17.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
114
+ "backbone.decoder.model.layers.17.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
115
+ "backbone.decoder.model.layers.17.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
116
+ "backbone.decoder.model.layers.17.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
117
+ "backbone.decoder.model.layers.17.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
118
+ "backbone.decoder.model.layers.18.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
119
+ "backbone.decoder.model.layers.18.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
120
+ "backbone.decoder.model.layers.18.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
121
+ "backbone.decoder.model.layers.18.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
122
+ "backbone.decoder.model.layers.18.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
123
+ "backbone.decoder.model.layers.18.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
124
+ "backbone.decoder.model.layers.18.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
125
+ "backbone.decoder.model.layers.18.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
126
+ "backbone.decoder.model.layers.18.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
127
+ "backbone.decoder.model.layers.18.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
128
+ "backbone.decoder.model.layers.18.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
129
+ "backbone.decoder.model.layers.19.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
130
+ "backbone.decoder.model.layers.19.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
131
+ "backbone.decoder.model.layers.19.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
132
+ "backbone.decoder.model.layers.19.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
133
+ "backbone.decoder.model.layers.19.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
134
+ "backbone.decoder.model.layers.19.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
135
+ "backbone.decoder.model.layers.19.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
136
+ "backbone.decoder.model.layers.19.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
137
+ "backbone.decoder.model.layers.19.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
138
+ "backbone.decoder.model.layers.19.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
139
+ "backbone.decoder.model.layers.19.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
140
+ "backbone.decoder.model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
141
+ "backbone.decoder.model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
142
+ "backbone.decoder.model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
143
+ "backbone.decoder.model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
144
+ "backbone.decoder.model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
145
+ "backbone.decoder.model.layers.2.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
146
+ "backbone.decoder.model.layers.2.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
147
+ "backbone.decoder.model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
148
+ "backbone.decoder.model.layers.2.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
149
+ "backbone.decoder.model.layers.2.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
150
+ "backbone.decoder.model.layers.2.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
151
+ "backbone.decoder.model.layers.20.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
152
+ "backbone.decoder.model.layers.20.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
153
+ "backbone.decoder.model.layers.20.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
154
+ "backbone.decoder.model.layers.20.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
155
+ "backbone.decoder.model.layers.20.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
156
+ "backbone.decoder.model.layers.20.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
157
+ "backbone.decoder.model.layers.20.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
158
+ "backbone.decoder.model.layers.20.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
159
+ "backbone.decoder.model.layers.20.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
160
+ "backbone.decoder.model.layers.20.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
161
+ "backbone.decoder.model.layers.20.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
162
+ "backbone.decoder.model.layers.21.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
163
+ "backbone.decoder.model.layers.21.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
164
+ "backbone.decoder.model.layers.21.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
165
+ "backbone.decoder.model.layers.21.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
166
+ "backbone.decoder.model.layers.21.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
167
+ "backbone.decoder.model.layers.21.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
168
+ "backbone.decoder.model.layers.21.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
169
+ "backbone.decoder.model.layers.21.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
170
+ "backbone.decoder.model.layers.21.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
171
+ "backbone.decoder.model.layers.21.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
172
+ "backbone.decoder.model.layers.21.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
173
+ "backbone.decoder.model.layers.22.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
174
+ "backbone.decoder.model.layers.22.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
175
+ "backbone.decoder.model.layers.22.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
176
+ "backbone.decoder.model.layers.22.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
177
+ "backbone.decoder.model.layers.22.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
178
+ "backbone.decoder.model.layers.22.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
179
+ "backbone.decoder.model.layers.22.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
180
+ "backbone.decoder.model.layers.22.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
181
+ "backbone.decoder.model.layers.22.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
182
+ "backbone.decoder.model.layers.22.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
183
+ "backbone.decoder.model.layers.22.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
184
+ "backbone.decoder.model.layers.23.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
185
+ "backbone.decoder.model.layers.23.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
186
+ "backbone.decoder.model.layers.23.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
187
+ "backbone.decoder.model.layers.23.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
188
+ "backbone.decoder.model.layers.23.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
189
+ "backbone.decoder.model.layers.23.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
190
+ "backbone.decoder.model.layers.23.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
191
+ "backbone.decoder.model.layers.23.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
192
+ "backbone.decoder.model.layers.23.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
193
+ "backbone.decoder.model.layers.23.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
194
+ "backbone.decoder.model.layers.23.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
195
+ "backbone.decoder.model.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
196
+ "backbone.decoder.model.layers.24.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
197
+ "backbone.decoder.model.layers.24.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
198
+ "backbone.decoder.model.layers.24.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
199
+ "backbone.decoder.model.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
200
+ "backbone.decoder.model.layers.24.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
201
+ "backbone.decoder.model.layers.24.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
202
+ "backbone.decoder.model.layers.24.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
203
+ "backbone.decoder.model.layers.24.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
204
+ "backbone.decoder.model.layers.24.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
205
+ "backbone.decoder.model.layers.24.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
206
+ "backbone.decoder.model.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
207
+ "backbone.decoder.model.layers.25.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
208
+ "backbone.decoder.model.layers.25.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
209
+ "backbone.decoder.model.layers.25.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
210
+ "backbone.decoder.model.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
211
+ "backbone.decoder.model.layers.25.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
212
+ "backbone.decoder.model.layers.25.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
213
+ "backbone.decoder.model.layers.25.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
214
+ "backbone.decoder.model.layers.25.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
215
+ "backbone.decoder.model.layers.25.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
216
+ "backbone.decoder.model.layers.25.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
217
+ "backbone.decoder.model.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
218
+ "backbone.decoder.model.layers.26.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
219
+ "backbone.decoder.model.layers.26.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
220
+ "backbone.decoder.model.layers.26.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
221
+ "backbone.decoder.model.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
222
+ "backbone.decoder.model.layers.26.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
223
+ "backbone.decoder.model.layers.26.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
224
+ "backbone.decoder.model.layers.26.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
225
+ "backbone.decoder.model.layers.26.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
226
+ "backbone.decoder.model.layers.26.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
227
+ "backbone.decoder.model.layers.26.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
228
+ "backbone.decoder.model.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
229
+ "backbone.decoder.model.layers.27.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
230
+ "backbone.decoder.model.layers.27.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
231
+ "backbone.decoder.model.layers.27.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
232
+ "backbone.decoder.model.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
233
+ "backbone.decoder.model.layers.27.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
234
+ "backbone.decoder.model.layers.27.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
235
+ "backbone.decoder.model.layers.27.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
236
+ "backbone.decoder.model.layers.27.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
237
+ "backbone.decoder.model.layers.27.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
238
+ "backbone.decoder.model.layers.27.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
239
+ "backbone.decoder.model.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
240
+ "backbone.decoder.model.layers.3.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
241
+ "backbone.decoder.model.layers.3.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
242
+ "backbone.decoder.model.layers.3.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
243
+ "backbone.decoder.model.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
244
+ "backbone.decoder.model.layers.3.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
245
+ "backbone.decoder.model.layers.3.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
246
+ "backbone.decoder.model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
247
+ "backbone.decoder.model.layers.3.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
248
+ "backbone.decoder.model.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
249
+ "backbone.decoder.model.layers.3.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
250
+ "backbone.decoder.model.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
251
+ "backbone.decoder.model.layers.4.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
252
+ "backbone.decoder.model.layers.4.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
253
+ "backbone.decoder.model.layers.4.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
254
+ "backbone.decoder.model.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
255
+ "backbone.decoder.model.layers.4.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
256
+ "backbone.decoder.model.layers.4.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
257
+ "backbone.decoder.model.layers.4.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
258
+ "backbone.decoder.model.layers.4.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
259
+ "backbone.decoder.model.layers.4.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
260
+ "backbone.decoder.model.layers.4.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
261
+ "backbone.decoder.model.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
262
+ "backbone.decoder.model.layers.5.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
263
+ "backbone.decoder.model.layers.5.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
264
+ "backbone.decoder.model.layers.5.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
265
+ "backbone.decoder.model.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
266
+ "backbone.decoder.model.layers.5.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
267
+ "backbone.decoder.model.layers.5.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
268
+ "backbone.decoder.model.layers.5.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
269
+ "backbone.decoder.model.layers.5.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
270
+ "backbone.decoder.model.layers.5.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
271
+ "backbone.decoder.model.layers.5.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
272
+ "backbone.decoder.model.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
273
+ "backbone.decoder.model.layers.6.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
274
+ "backbone.decoder.model.layers.6.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
275
+ "backbone.decoder.model.layers.6.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
276
+ "backbone.decoder.model.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
277
+ "backbone.decoder.model.layers.6.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
278
+ "backbone.decoder.model.layers.6.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
279
+ "backbone.decoder.model.layers.6.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
280
+ "backbone.decoder.model.layers.6.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
281
+ "backbone.decoder.model.layers.6.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
282
+ "backbone.decoder.model.layers.6.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
283
+ "backbone.decoder.model.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
284
+ "backbone.decoder.model.layers.7.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
285
+ "backbone.decoder.model.layers.7.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
286
+ "backbone.decoder.model.layers.7.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
287
+ "backbone.decoder.model.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
288
+ "backbone.decoder.model.layers.7.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
289
+ "backbone.decoder.model.layers.7.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
290
+ "backbone.decoder.model.layers.7.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
291
+ "backbone.decoder.model.layers.7.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
292
+ "backbone.decoder.model.layers.7.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
293
+ "backbone.decoder.model.layers.7.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
294
+ "backbone.decoder.model.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
295
+ "backbone.decoder.model.layers.8.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
296
+ "backbone.decoder.model.layers.8.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
297
+ "backbone.decoder.model.layers.8.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
298
+ "backbone.decoder.model.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
299
+ "backbone.decoder.model.layers.8.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
300
+ "backbone.decoder.model.layers.8.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
301
+ "backbone.decoder.model.layers.8.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
302
+ "backbone.decoder.model.layers.8.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
303
+ "backbone.decoder.model.layers.8.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
304
+ "backbone.decoder.model.layers.8.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
305
+ "backbone.decoder.model.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
306
+ "backbone.decoder.model.layers.9.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
307
+ "backbone.decoder.model.layers.9.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
308
+ "backbone.decoder.model.layers.9.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
309
+ "backbone.decoder.model.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
310
+ "backbone.decoder.model.layers.9.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
311
+ "backbone.decoder.model.layers.9.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
312
+ "backbone.decoder.model.layers.9.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
313
+ "backbone.decoder.model.layers.9.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
314
+ "backbone.decoder.model.layers.9.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
315
+ "backbone.decoder.model.layers.9.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
316
+ "backbone.decoder.model.norm.weight": "pytorch_model-00002-of-00002.bin",
317
+ "backbone.encoder.lm_head.weight": "pytorch_model-00001-of-00002.bin",
318
+ "backbone.encoder.model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
319
+ "backbone.encoder.model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
320
+ "backbone.encoder.model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
321
+ "backbone.encoder.model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
322
+ "backbone.encoder.model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
323
+ "backbone.encoder.model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
324
+ "backbone.encoder.model.layers.0.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
325
+ "backbone.encoder.model.layers.0.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
326
+ "backbone.encoder.model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
327
+ "backbone.encoder.model.layers.0.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
328
+ "backbone.encoder.model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
329
+ "backbone.encoder.model.layers.0.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
330
+ "backbone.encoder.model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
331
+ "backbone.encoder.model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
332
+ "backbone.encoder.model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
333
+ "backbone.encoder.model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
334
+ "backbone.encoder.model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
335
+ "backbone.encoder.model.layers.1.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
336
+ "backbone.encoder.model.layers.1.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
337
+ "backbone.encoder.model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
338
+ "backbone.encoder.model.layers.1.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
339
+ "backbone.encoder.model.layers.1.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
340
+ "backbone.encoder.model.layers.1.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
341
+ "backbone.encoder.model.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
342
+ "backbone.encoder.model.layers.10.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
343
+ "backbone.encoder.model.layers.10.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
344
+ "backbone.encoder.model.layers.10.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
345
+ "backbone.encoder.model.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
346
+ "backbone.encoder.model.layers.10.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
347
+ "backbone.encoder.model.layers.10.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
348
+ "backbone.encoder.model.layers.10.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
349
+ "backbone.encoder.model.layers.10.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
350
+ "backbone.encoder.model.layers.10.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
351
+ "backbone.encoder.model.layers.10.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
352
+ "backbone.encoder.model.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
353
+ "backbone.encoder.model.layers.11.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
354
+ "backbone.encoder.model.layers.11.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
355
+ "backbone.encoder.model.layers.11.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
356
+ "backbone.encoder.model.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
357
+ "backbone.encoder.model.layers.11.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
358
+ "backbone.encoder.model.layers.11.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
359
+ "backbone.encoder.model.layers.11.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
360
+ "backbone.encoder.model.layers.11.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
361
+ "backbone.encoder.model.layers.11.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
362
+ "backbone.encoder.model.layers.11.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
363
+ "backbone.encoder.model.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
364
+ "backbone.encoder.model.layers.12.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
365
+ "backbone.encoder.model.layers.12.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
366
+ "backbone.encoder.model.layers.12.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
367
+ "backbone.encoder.model.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
368
+ "backbone.encoder.model.layers.12.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
369
+ "backbone.encoder.model.layers.12.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
370
+ "backbone.encoder.model.layers.12.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
371
+ "backbone.encoder.model.layers.12.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
372
+ "backbone.encoder.model.layers.12.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
373
+ "backbone.encoder.model.layers.12.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
374
+ "backbone.encoder.model.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
375
+ "backbone.encoder.model.layers.13.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
376
+ "backbone.encoder.model.layers.13.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
377
+ "backbone.encoder.model.layers.13.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
378
+ "backbone.encoder.model.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
379
+ "backbone.encoder.model.layers.13.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
380
+ "backbone.encoder.model.layers.13.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
381
+ "backbone.encoder.model.layers.13.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
382
+ "backbone.encoder.model.layers.13.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
383
+ "backbone.encoder.model.layers.13.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
384
+ "backbone.encoder.model.layers.13.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
385
+ "backbone.encoder.model.layers.14.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
386
+ "backbone.encoder.model.layers.14.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
387
+ "backbone.encoder.model.layers.14.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
388
+ "backbone.encoder.model.layers.14.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
389
+ "backbone.encoder.model.layers.14.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
390
+ "backbone.encoder.model.layers.14.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
391
+ "backbone.encoder.model.layers.14.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
392
+ "backbone.encoder.model.layers.14.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
393
+ "backbone.encoder.model.layers.14.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
394
+ "backbone.encoder.model.layers.14.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
395
+ "backbone.encoder.model.layers.14.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
396
+ "backbone.encoder.model.layers.15.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
397
+ "backbone.encoder.model.layers.15.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
398
+ "backbone.encoder.model.layers.15.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
399
+ "backbone.encoder.model.layers.15.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
400
+ "backbone.encoder.model.layers.15.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
401
+ "backbone.encoder.model.layers.15.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
402
+ "backbone.encoder.model.layers.15.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
403
+ "backbone.encoder.model.layers.15.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
404
+ "backbone.encoder.model.layers.15.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
405
+ "backbone.encoder.model.layers.15.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
406
+ "backbone.encoder.model.layers.15.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
407
+ "backbone.encoder.model.layers.16.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
408
+ "backbone.encoder.model.layers.16.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
409
+ "backbone.encoder.model.layers.16.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
410
+ "backbone.encoder.model.layers.16.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
411
+ "backbone.encoder.model.layers.16.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
412
+ "backbone.encoder.model.layers.16.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
413
+ "backbone.encoder.model.layers.16.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
414
+ "backbone.encoder.model.layers.16.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
415
+ "backbone.encoder.model.layers.16.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
416
+ "backbone.encoder.model.layers.16.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
417
+ "backbone.encoder.model.layers.16.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
418
+ "backbone.encoder.model.layers.17.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
419
+ "backbone.encoder.model.layers.17.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
420
+ "backbone.encoder.model.layers.17.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
421
+ "backbone.encoder.model.layers.17.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
422
+ "backbone.encoder.model.layers.17.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
423
+ "backbone.encoder.model.layers.17.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
424
+ "backbone.encoder.model.layers.17.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
425
+ "backbone.encoder.model.layers.17.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
426
+ "backbone.encoder.model.layers.17.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
427
+ "backbone.encoder.model.layers.17.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
428
+ "backbone.encoder.model.layers.17.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
429
+ "backbone.encoder.model.layers.18.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
430
+ "backbone.encoder.model.layers.18.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
431
+ "backbone.encoder.model.layers.18.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
432
+ "backbone.encoder.model.layers.18.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
433
+ "backbone.encoder.model.layers.18.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
434
+ "backbone.encoder.model.layers.18.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
435
+ "backbone.encoder.model.layers.18.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
436
+ "backbone.encoder.model.layers.18.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
437
+ "backbone.encoder.model.layers.18.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
438
+ "backbone.encoder.model.layers.18.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
439
+ "backbone.encoder.model.layers.18.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
440
+ "backbone.encoder.model.layers.19.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
441
+ "backbone.encoder.model.layers.19.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
442
+ "backbone.encoder.model.layers.19.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
443
+ "backbone.encoder.model.layers.19.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
444
+ "backbone.encoder.model.layers.19.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
445
+ "backbone.encoder.model.layers.19.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
446
+ "backbone.encoder.model.layers.19.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
447
+ "backbone.encoder.model.layers.19.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
448
+ "backbone.encoder.model.layers.19.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
449
+ "backbone.encoder.model.layers.19.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
450
+ "backbone.encoder.model.layers.19.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
451
+ "backbone.encoder.model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
452
+ "backbone.encoder.model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
453
+ "backbone.encoder.model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
454
+ "backbone.encoder.model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
455
+ "backbone.encoder.model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
456
+ "backbone.encoder.model.layers.2.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
457
+ "backbone.encoder.model.layers.2.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
458
+ "backbone.encoder.model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
459
+ "backbone.encoder.model.layers.2.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
460
+ "backbone.encoder.model.layers.2.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
461
+ "backbone.encoder.model.layers.2.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
462
+ "backbone.encoder.model.layers.20.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
463
+ "backbone.encoder.model.layers.20.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
464
+ "backbone.encoder.model.layers.20.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
465
+ "backbone.encoder.model.layers.20.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
466
+ "backbone.encoder.model.layers.20.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
467
+ "backbone.encoder.model.layers.20.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
468
+ "backbone.encoder.model.layers.20.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
469
+ "backbone.encoder.model.layers.20.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
470
+ "backbone.encoder.model.layers.20.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
471
+ "backbone.encoder.model.layers.20.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
472
+ "backbone.encoder.model.layers.20.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
473
+ "backbone.encoder.model.layers.21.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
474
+ "backbone.encoder.model.layers.21.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
475
+ "backbone.encoder.model.layers.21.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
476
+ "backbone.encoder.model.layers.21.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
477
+ "backbone.encoder.model.layers.21.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
478
+ "backbone.encoder.model.layers.21.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
479
+ "backbone.encoder.model.layers.21.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
480
+ "backbone.encoder.model.layers.21.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
481
+ "backbone.encoder.model.layers.21.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
482
+ "backbone.encoder.model.layers.21.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
483
+ "backbone.encoder.model.layers.21.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
484
+ "backbone.encoder.model.layers.22.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
485
+ "backbone.encoder.model.layers.22.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
486
+ "backbone.encoder.model.layers.22.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
487
+ "backbone.encoder.model.layers.22.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
488
+ "backbone.encoder.model.layers.22.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
489
+ "backbone.encoder.model.layers.22.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
490
+ "backbone.encoder.model.layers.22.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
491
+ "backbone.encoder.model.layers.22.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
492
+ "backbone.encoder.model.layers.22.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
493
+ "backbone.encoder.model.layers.22.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
494
+ "backbone.encoder.model.layers.22.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
495
+ "backbone.encoder.model.layers.23.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
496
+ "backbone.encoder.model.layers.23.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
497
+ "backbone.encoder.model.layers.23.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
498
+ "backbone.encoder.model.layers.23.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
499
+ "backbone.encoder.model.layers.23.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
500
+ "backbone.encoder.model.layers.23.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
501
+ "backbone.encoder.model.layers.23.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
502
+ "backbone.encoder.model.layers.23.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
503
+ "backbone.encoder.model.layers.23.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
504
+ "backbone.encoder.model.layers.23.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
505
+ "backbone.encoder.model.layers.23.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
506
+ "backbone.encoder.model.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
507
+ "backbone.encoder.model.layers.24.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
508
+ "backbone.encoder.model.layers.24.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
509
+ "backbone.encoder.model.layers.24.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
510
+ "backbone.encoder.model.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
511
+ "backbone.encoder.model.layers.24.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
512
+ "backbone.encoder.model.layers.24.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
513
+ "backbone.encoder.model.layers.24.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
514
+ "backbone.encoder.model.layers.24.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
515
+ "backbone.encoder.model.layers.24.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
516
+ "backbone.encoder.model.layers.24.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
517
+ "backbone.encoder.model.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
518
+ "backbone.encoder.model.layers.25.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
519
+ "backbone.encoder.model.layers.25.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
520
+ "backbone.encoder.model.layers.25.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
521
+ "backbone.encoder.model.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
522
+ "backbone.encoder.model.layers.25.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
523
+ "backbone.encoder.model.layers.25.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
524
+ "backbone.encoder.model.layers.25.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
525
+ "backbone.encoder.model.layers.25.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
526
+ "backbone.encoder.model.layers.25.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
527
+ "backbone.encoder.model.layers.25.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
528
+ "backbone.encoder.model.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
529
+ "backbone.encoder.model.layers.26.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
530
+ "backbone.encoder.model.layers.26.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
531
+ "backbone.encoder.model.layers.26.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
532
+ "backbone.encoder.model.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
533
+ "backbone.encoder.model.layers.26.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
534
+ "backbone.encoder.model.layers.26.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
535
+ "backbone.encoder.model.layers.26.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
536
+ "backbone.encoder.model.layers.26.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
537
+ "backbone.encoder.model.layers.26.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
538
+ "backbone.encoder.model.layers.26.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
539
+ "backbone.encoder.model.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
540
+ "backbone.encoder.model.layers.27.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
541
+ "backbone.encoder.model.layers.27.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
542
+ "backbone.encoder.model.layers.27.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
543
+ "backbone.encoder.model.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
544
+ "backbone.encoder.model.layers.27.self_attn.k_norm.weight": "pytorch_model-00002-of-00002.bin",
545
+ "backbone.encoder.model.layers.27.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
546
+ "backbone.encoder.model.layers.27.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
547
+ "backbone.encoder.model.layers.27.self_attn.q_norm.weight": "pytorch_model-00002-of-00002.bin",
548
+ "backbone.encoder.model.layers.27.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
549
+ "backbone.encoder.model.layers.27.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
550
+ "backbone.encoder.model.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
551
+ "backbone.encoder.model.layers.3.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
552
+ "backbone.encoder.model.layers.3.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
553
+ "backbone.encoder.model.layers.3.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
554
+ "backbone.encoder.model.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
555
+ "backbone.encoder.model.layers.3.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
556
+ "backbone.encoder.model.layers.3.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
557
+ "backbone.encoder.model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
558
+ "backbone.encoder.model.layers.3.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
559
+ "backbone.encoder.model.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
560
+ "backbone.encoder.model.layers.3.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
561
+ "backbone.encoder.model.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
562
+ "backbone.encoder.model.layers.4.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
563
+ "backbone.encoder.model.layers.4.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
564
+ "backbone.encoder.model.layers.4.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
565
+ "backbone.encoder.model.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
566
+ "backbone.encoder.model.layers.4.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
567
+ "backbone.encoder.model.layers.4.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
568
+ "backbone.encoder.model.layers.4.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
569
+ "backbone.encoder.model.layers.4.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
570
+ "backbone.encoder.model.layers.4.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
571
+ "backbone.encoder.model.layers.4.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
572
+ "backbone.encoder.model.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
573
+ "backbone.encoder.model.layers.5.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
574
+ "backbone.encoder.model.layers.5.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
575
+ "backbone.encoder.model.layers.5.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
576
+ "backbone.encoder.model.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
577
+ "backbone.encoder.model.layers.5.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
578
+ "backbone.encoder.model.layers.5.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
579
+ "backbone.encoder.model.layers.5.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
580
+ "backbone.encoder.model.layers.5.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
581
+ "backbone.encoder.model.layers.5.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
582
+ "backbone.encoder.model.layers.5.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
583
+ "backbone.encoder.model.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
584
+ "backbone.encoder.model.layers.6.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
585
+ "backbone.encoder.model.layers.6.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
586
+ "backbone.encoder.model.layers.6.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
587
+ "backbone.encoder.model.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
588
+ "backbone.encoder.model.layers.6.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
589
+ "backbone.encoder.model.layers.6.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
590
+ "backbone.encoder.model.layers.6.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
591
+ "backbone.encoder.model.layers.6.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
592
+ "backbone.encoder.model.layers.6.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
593
+ "backbone.encoder.model.layers.6.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
594
+ "backbone.encoder.model.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
595
+ "backbone.encoder.model.layers.7.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
596
+ "backbone.encoder.model.layers.7.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
597
+ "backbone.encoder.model.layers.7.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
598
+ "backbone.encoder.model.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
599
+ "backbone.encoder.model.layers.7.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
600
+ "backbone.encoder.model.layers.7.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
601
+ "backbone.encoder.model.layers.7.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
602
+ "backbone.encoder.model.layers.7.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
603
+ "backbone.encoder.model.layers.7.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
604
+ "backbone.encoder.model.layers.7.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
605
+ "backbone.encoder.model.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
606
+ "backbone.encoder.model.layers.8.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
607
+ "backbone.encoder.model.layers.8.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
608
+ "backbone.encoder.model.layers.8.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
609
+ "backbone.encoder.model.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
610
+ "backbone.encoder.model.layers.8.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
611
+ "backbone.encoder.model.layers.8.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
612
+ "backbone.encoder.model.layers.8.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
613
+ "backbone.encoder.model.layers.8.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
614
+ "backbone.encoder.model.layers.8.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
615
+ "backbone.encoder.model.layers.8.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
616
+ "backbone.encoder.model.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
617
+ "backbone.encoder.model.layers.9.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
618
+ "backbone.encoder.model.layers.9.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
619
+ "backbone.encoder.model.layers.9.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
620
+ "backbone.encoder.model.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
621
+ "backbone.encoder.model.layers.9.self_attn.k_norm.weight": "pytorch_model-00001-of-00002.bin",
622
+ "backbone.encoder.model.layers.9.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
623
+ "backbone.encoder.model.layers.9.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
624
+ "backbone.encoder.model.layers.9.self_attn.q_norm.weight": "pytorch_model-00001-of-00002.bin",
625
+ "backbone.encoder.model.layers.9.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
626
+ "backbone.encoder.model.layers.9.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
627
+ "backbone.encoder.model.norm.weight": "pytorch_model-00002-of-00002.bin",
628
+ "encoder_static_attention_mask": "pytorch_model-00001-of-00002.bin",
629
+ "static_attention_mask": "pytorch_model-00001-of-00002.bin"
630
+ }
631
+ }
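For context (not part of the committed files): the `weight_map` above follows the standard `pytorch_model.bin.index.json` layout for sharded PyTorch checkpoints, mapping each parameter name to the `.bin` shard that stores it. A minimal sketch of resolving one tensor through the index, assuming the index and shard files sit in the working directory and using an illustrative parameter name from the map:

```python
import json
import torch

# Load the shard index ("pytorch_model.bin.index.json" is the standard
# name for a sharded PyTorch checkpoint index; path assumed here).
with open("pytorch_model.bin.index.json") as f:
    index = json.load(f)

# Pick one parameter from the weight_map (illustrative choice).
param_name = "backbone.encoder.model.layers.3.self_attn.q_proj.weight"

# The weight_map tells us which shard file holds this tensor.
shard_file = index["weight_map"][param_name]  # -> "pytorch_model-00001-of-00002.bin"

# Load only that shard and pull out the tensor.
state_dict = torch.load(shard_file, map_location="cpu")
tensor = state_dict[param_name]
print(param_name, tuple(tensor.shape))
```

In practice, `transformers` performs this shard resolution automatically when loading a model with `from_pretrained`; the sketch only makes explicit what the index file encodes.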