""" MLP module w/ dropout and configurable activation layer
Hacked together by / Copyright 2020 Ross Wightman
Modified by Ensemble AI to use NdLinear instead of Linear. Copyright 2025
"""
from functools import partial
from torch import nn as nn
from ndlinear import NdLinear
from timm.layers.grn import GlobalResponseNorm
from timm.layers.helpers import to_2tuple


class NdMlp(nn.Module):
    """ MLP using NdLinear layers in place of nn.Linear.

    NOTE: both NdLinear layers map (in_features, 1) -> (hidden_features // 4, 1), so the reshape
    back to the input shape in forward() relies on hidden_features == 4 * in_features
    (the usual mlp_ratio=4 setting).
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_variant=4,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        self.use_variant = use_variant
        drop_probs = to_2tuple(drop)

        self.fc1 = NdLinear((in_features, 1), (hidden_features // 4, 1))  # e.g. (384, 1) -> (384, 1)
        self.fc2 = NdLinear((in_features, 1), (hidden_features // 4, 1))
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x_dim0, x_dim1, x_dim2 = x.shape
        # flatten (B, N, C) -> (B * N, C, 1) so each token is transformed independently
        x = x.reshape(x_dim0 * x_dim1, x_dim2, 1) if self.use_variant != 9 else x
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        # x = self.norm(x)
        x = self.fc2(x)
        x = x.reshape(x_dim0, x_dim1, x_dim2) if self.use_variant != 9 else x
        x = self.drop2(x)
        return x
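
# The sketch below is a minimal, illustrative usage of NdMlp (not part of the original module).
# It assumes hidden_features == 4 * in_features, and that NdLinear maps a tensor of shape
# (batch, *input_dims) to (batch, *hidden_size), which is how it is used in NdMlp.forward above.
def _example_nd_mlp():  # hypothetical helper name, for illustration only
    import torch
    mlp = NdMlp(in_features=384, hidden_features=1536)  # 1536 // 4 == 384, so shapes round-trip
    x = torch.randn(2, 196, 384)  # (batch, tokens, channels), e.g. a ViT-S token sequence
    y = mlp(x)  # internally reshaped to (2 * 196, 384, 1) and back when use_variant != 9
    assert y.shape == x.shape
    return y
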

class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202

    NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.Sigmoid,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
            gate_last=True,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert hidden_features % 2 == 0
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
        self.chunk_dim = 1 if use_conv else -1
        self.gate_last = gate_last  # use second half of width for gate

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features // 2) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        if self.fc1.bias is not None:
            nn.init.ones_(self.fc1.bias[self.fc1.bias.shape[0] // 2:])
        nn.init.normal_(self.fc1.weight[self.fc1.weight.shape[0] // 2:], std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=self.chunk_dim)
        x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


SwiGLUPacked = partial(GluMlp, act_layer=nn.SiLU, gate_last=False)
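
# Minimal usage sketch for GluMlp and the SwiGLUPacked alias (sizes are illustrative assumptions).
# hidden_features must be even because fc1's output is chunked into value and gate halves.
def _example_glu_mlp():  # hypothetical helper name, for illustration only
    import torch
    x = torch.randn(2, 196, 256)  # (B, N, C) channels-last tokens
    glu = GluMlp(in_features=256, hidden_features=512)           # sigmoid gate on the second half
    swiglu = SwiGLUPacked(in_features=256, hidden_features=512)  # SiLU gate on the first half
    y1, y2 = glu(x), swiglu(x)
    assert y1.shape == x.shape and y2.shape == x.shape
    return y1, y2
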

class SwiGLU(nn.Module):
    """ SwiGLU
    NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and
    better matches some other common impl which makes mapping checkpoints simpler.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.SiLU,
            norm_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        if self.fc1_g.bias is not None:
            nn.init.ones_(self.fc1_g.bias)
        nn.init.normal_(self.fc1_g.weight, std=1e-6)

    def forward(self, x):
        x_gate = self.fc1_g(x)
        x = self.fc1_x(x)
        x = self.act(x_gate) * x
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
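
# Minimal usage sketch for SwiGLU with its split gate/value projections (sizes are illustrative).
def _example_swiglu():  # hypothetical helper name, for illustration only
    import torch
    mlp = SwiGLU(in_features=256, hidden_features=512)
    x = torch.randn(2, 196, 256)
    y = mlp(x)  # computes fc2(act(fc1_g(x)) * fc1_x(x)) with dropout around the projections
    assert y.shape == x.shape
    return y
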

class GatedMlp(nn.Module):
    """ MLP as used in gMLP
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            gate_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        if gate_layer is not None:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
        else:
            self.gate = nn.Identity()
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.gate(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
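
# Minimal usage sketch for GatedMlp (sizes are illustrative). With gate_layer=None the gate is an
# identity and the module reduces to a plain MLP; gMLP passes a spatial gating layer here, which
# halves the channel dimension before fc2 (hence the hidden_features // 2 in __init__).
def _example_gated_mlp():  # hypothetical helper name, for illustration only
    import torch
    mlp = GatedMlp(in_features=256, hidden_features=512)  # gate_layer=None -> nn.Identity gate
    x = torch.randn(2, 196, 256)
    y = mlp(x)
    assert y.shape == x.shape
    return y
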

class ConvMlp(nn.Module):
    """ MLP using 1x1 convs that keep spatial dims (for 2D NCHW tensors)
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.ReLU,
            norm_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)

        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])
        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
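
# Minimal usage sketch for ConvMlp on an NCHW feature map (sizes and norm choice are illustrative).
def _example_conv_mlp():  # hypothetical helper name, for illustration only
    import torch
    mlp = ConvMlp(in_features=256, hidden_features=512, norm_layer=nn.BatchNorm2d)
    x = torch.randn(2, 256, 14, 14)  # (B, C, H, W)
    y = mlp(x)  # 1x1 convs preserve the spatial dims
    assert y.shape == x.shape
    return y
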

class GlobalResponseNormMlp(nn.Module):
    """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
    NOTE: Intended for '2D' NCHW (use_conv=True) or NHWC (use_conv=False, channels-last) tensor layouts
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv)
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.grn(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
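
# Minimal usage sketch for GlobalResponseNormMlp in both supported layouts (sizes are illustrative).
def _example_grn_mlp():  # hypothetical helper name, for illustration only
    import torch
    nhwc_mlp = GlobalResponseNormMlp(in_features=256, hidden_features=512, use_conv=False)
    nchw_mlp = GlobalResponseNormMlp(in_features=256, hidden_features=512, use_conv=True)
    x_nhwc = torch.randn(2, 14, 14, 256)  # channels-last, matches use_conv=False
    x_nchw = torch.randn(2, 256, 14, 14)  # channels-first, matches use_conv=True
    y1, y2 = nhwc_mlp(x_nhwc), nchw_mlp(x_nchw)
    assert y1.shape == x_nhwc.shape and y2.shape == x_nchw.shape
    return y1, y2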