Upload 2 files
- yuan_hf_model.py +61 -13
- yuan_hf_model_cpu.py +60 -12
yuan_hf_model.py
CHANGED
@@ -25,7 +25,6 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.models.llama.modeling_llama import LlamaRMSNorm,LlamaRotaryEmbedding
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
@@ -58,9 +57,7 @@ class LocalizedFiltering(torch.nn.Module):

         self.conv1 = torch.nn.Conv2d(self.embed_dim, self.embed_dim // 2, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
         self.conv2 = torch.nn.Conv2d(self.embed_dim // 2, self.embed_dim, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
-
-        #Use the same RMSNorm as llama
-        self.output_layernorm = LlamaRMSNorm(self.embed_dim)
+        self.output_layernorm = YuanRMSNorm(self.embed_dim)

     def _train_forward(self, inputs):
         inputs = inputs.transpose(0,1)
@@ -197,7 +194,61 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed

+class YuanRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        YuanRMSNorm is equivalent to LlamaRMSNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+class YuanRotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+
+        """
+        YuanRotaryEmbedding is equivalent to LlamaRotaryEmbedding in transformers v4.36
+        """
+
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )

+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )

 class YuanMLP(nn.Module):
     def __init__(
@@ -240,8 +291,7 @@ class YuanAttention(nn.Module):
         )
         self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-
-        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+        self.rotary_emb = YuanRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
         if self.use_shareqk:
             self.qk_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
             self.qk_weight = nn.Parameter(torch.Tensor(2, self.hidden_size))
@@ -268,8 +318,8 @@ class YuanAttention(nn.Module):
         is_first_step = False
         if use_cache:
             if past_key_value is None:
-                inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype ,device=torch.cuda.current_device())
-
+                #inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype ,device=torch.cuda.current_device())
+                inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype)
                 is_first_step = True
             else:
                 before_hidden_states = past_key_value[2]
@@ -393,9 +443,8 @@ class YuanDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
         )
-
-        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
         self,
@@ -583,8 +632,7 @@ class YuanModel(YuanPreTrainedModel):
         self.reset_position_ids = config.reset_position_ids
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList([YuanDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
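The docstring of the new YuanRMSNorm states that it is equivalent to LlamaRMSNorm, whose import this commit removes. A minimal sketch of how that claim could be spot-checked, assuming a transformers version that still exposes LlamaRMSNorm; the hidden size, input shape, and tolerance below are illustrative, not part of the commit:

import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

hidden_size = 2048                      # illustrative value
yuan_norm = YuanRMSNorm(hidden_size)    # class added in the hunk above
llama_norm = LlamaRMSNorm(hidden_size)  # both default to eps=1e-6
llama_norm.load_state_dict(yuan_norm.state_dict())  # each holds a single `weight` parameter

x = torch.randn(1, 16, hidden_size)
assert torch.allclose(yuan_norm(x), llama_norm(x), atol=1e-6)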
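The rotary embedding added above keeps its cos/sin tables in non-persistent buffers and regrows them when a longer sequence is requested. A short usage sketch, assuming YuanRotaryEmbedding from the hunk above is in scope; the head dimension and sequence lengths are made up for illustration:

import torch

rope = YuanRotaryEmbedding(dim=128, max_position_embeddings=2048)

# Queries/keys are laid out as [bs, num_attention_heads, seq_len, head_size].
q = torch.randn(1, 32, 16, 128)
cos, sin = rope(q, seq_len=16)
print(cos.shape, sin.shape)      # both [1, 1, 16, 128], ready for apply_rotary_pos_emb

# Requesting more positions than max_position_embeddings rebuilds the cache in place.
cos, sin = rope(q, seq_len=4096)
print(rope.max_seq_len_cached)   # 4096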
yuan_hf_model_cpu.py
CHANGED
@@ -25,7 +25,6 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.models.llama.modeling_llama import LlamaRMSNorm,LlamaRotaryEmbedding
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
@@ -58,9 +57,7 @@ class LocalizedFiltering(torch.nn.Module):

         self.conv1 = torch.nn.Conv2d(self.embed_dim, self.embed_dim // 2, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
         self.conv2 = torch.nn.Conv2d(self.embed_dim // 2, self.embed_dim, (2, 1), stride=(1, 1), padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
-
-        #Use the same RMSNorm as llama
-        self.output_layernorm = LlamaRMSNorm(self.embed_dim)
+        self.output_layernorm = YuanRMSNorm(self.embed_dim)

     def _train_forward(self, inputs):
         inputs = inputs.transpose(0,1)
@@ -197,7 +194,61 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed

+class YuanRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        YuanRMSNorm is equivalent to LlamaRMSNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+class YuanRotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+
+        """
+        YuanRotaryEmbedding is equivalent to LlamaRotaryEmbedding in transformers v4.36
+        """
+
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )

+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )

 class YuanMLP(nn.Module):
     def __init__(
@@ -240,8 +291,7 @@ class YuanAttention(nn.Module):
         )
         self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-
-        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+        self.rotary_emb = YuanRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
         if self.use_shareqk:
             self.qk_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
             self.qk_weight = nn.Parameter(torch.Tensor(2, self.hidden_size))
@@ -268,7 +318,7 @@ class YuanAttention(nn.Module):
         is_first_step = False
         if use_cache:
             if past_key_value is None:
-                #
+                #inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype ,device=torch.cuda.current_device())
                 inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype)
                 is_first_step = True
             else:
@@ -393,9 +443,8 @@ class YuanDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
         )
-
-        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
         self,
@@ -583,8 +632,7 @@ class YuanModel(YuanPreTrainedModel):
         self.reset_position_ids = config.reset_position_ids
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = nn.ModuleList([YuanDecoderLayer(config) for _ in range(config.num_hidden_layers)])
-
-        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = YuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
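Besides sharing the YuanRMSNorm/YuanRotaryEmbedding refactor, the CPU variant drops the explicit device=torch.cuda.current_device() when allocating the inference memory buffer, so the tensor follows PyTorch's default device instead of requiring a CUDA context. A small sketch of the difference; the shapes are illustrative, and the hidden_states.device variant is a common device-agnostic alternative, not something this commit introduces:

import torch

bsz, seq_len, hidden_size = 2, 5, 2048
hidden_states = torch.randn(bsz, seq_len, hidden_size)  # a CPU tensor in this sketch

# As committed: no explicit device, so the buffer lands on the default device (CPU here).
buf = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype)
print(buf.device)  # cpu

# Hypothetical alternative that pins the buffer to wherever the activations live:
buf = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype, device=hidden_states.device)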