""" MLP module w/ dropout and configurable activation layer Hacked together by / Copyright 2020 Ross Wightman Modified by Ensemble AI to use NdLinear instead of Linear. Copyright 2025 """ from functools import partial from torch import nn as nn from ndlinear import NdLinear from timm.layers.grn import GlobalResponseNorm from timm.layers.helpers import to_2tuple class NdMlp(nn.Module): def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, norm_layer=None, bias=True, drop=0., use_variant=4 ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) self.use_variant = use_variant drop_probs = to_2tuple(drop) self.fc1 = NdLinear((in_features, 1), (hidden_features // 4, 1)) # (384, 1), (384, 1) self.fc2 = NdLinear((in_features, 1), (hidden_features // 4, 1)) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): x_dim0, x_dim1, x_dim2 = x.shape # print(f"x.shape: {x.shape}") x = x.reshape(x_dim0 * x_dim1, x_dim2, 1) if self.use_variant != 9 else x x = self.fc1(x) x = self.act(x) x = self.drop1(x) # x = self.norm(x) # x = self.fc2(x) x = x.reshape(x_dim0, x_dim1, x_dim2) if self.use_variant != 9 else x x = self.drop2(x) return x class GluMlp(nn.Module): """ MLP w/ GLU style gating See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected. """ def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, norm_layer=None, bias=True, drop=0., use_conv=False, gate_last=True, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features assert hidden_features % 2 == 0 bias = to_2tuple(bias) drop_probs = to_2tuple(drop) linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear self.chunk_dim = 1 if use_conv else -1 self.gate_last = gate_last # use second half of width for gate self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.norm = norm_layer(hidden_features // 2) if norm_layer is not None else nn.Identity() self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def init_weights(self): # override init of fc1 w/ gate portion set to weight near zero, bias=1 if self.fc1.bias is not None: nn.init.ones_(self.fc1.bias[self.fc1.bias.shape[0] // 2:]) nn.init.normal_(self.fc1.weight[self.fc1.weight.shape[0] // 2:], std=1e-6) def forward(self, x): x = self.fc1(x) x1, x2 = x.chunk(2, dim=self.chunk_dim) x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2 x = self.drop1(x) x = self.norm(x) x = self.fc2(x) x = self.drop2(x) return x SwiGLUPacked = partial(GluMlp, act_layer=nn.SiLU, gate_last=False) class SwiGLU(nn.Module): """ SwiGLU NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and better matches some other common impl which makes mapping checkpoints simpler. 
""" def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, norm_layer=None, bias=True, drop=0., ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0]) self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def init_weights(self): # override init of fc1 w/ gate portion set to weight near zero, bias=1 if self.fc1_g.bias is not None: nn.init.ones_(self.fc1_g.bias) nn.init.normal_(self.fc1_g.weight, std=1e-6) def forward(self, x): x_gate = self.fc1_g(x) x = self.fc1_x(x) x = self.act(x_gate) * x x = self.drop1(x) x = self.norm(x) x = self.fc2(x) x = self.drop2(x) return x class GatedMlp(nn.Module): """ MLP as used in gMLP """ def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, norm_layer=None, gate_layer=None, bias=True, drop=0., ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) if gate_layer is not None: assert hidden_features % 2 == 0 self.gate = gate_layer(hidden_features) hidden_features = hidden_features // 2 # FIXME base reduction on gate property? else: self.gate = nn.Identity() self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.gate(x) x = self.norm(x) x = self.fc2(x) x = self.drop2(x) return x class ConvMlp(nn.Module): """ MLP using 1x1 convs that keeps spatial dims (for 2D NCHW tensors) """ def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, bias=True, drop=0., ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() self.act = act_layer() self.drop = nn.Dropout(drop) self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) def forward(self, x): x = self.fc1(x) x = self.norm(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) return x class GlobalResponseNormMlp(nn.Module): """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d NOTE: Intended for '2D' NCHW (use_conv=True) or NHWC (use_conv=False, channels-last) tensor layouts """ def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0., use_conv=False, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features bias = to_2tuple(bias) drop_probs = to_2tuple(drop) linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = 
class GlobalResponseNormMlp(nn.Module):
    """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
    NOTE: Intended for '2D' NCHW (use_conv=True) or NHWC (use_conv=False, channels-last) tensor layouts
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv)
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.grn(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
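

if __name__ == "__main__":
    # Minimal smoke test: a sketch with assumed shapes, not part of the original module.
    # NdMlp reshapes back to the input width after fc1, so as written it assumes
    # hidden_features // 4 == in_features (i.e. the usual 4x MLP ratio).
    import torch

    tokens = torch.randn(2, 196, 384)  # (batch, tokens, channels)

    nd_mlp = NdMlp(in_features=384, hidden_features=4 * 384)
    print("NdMlp:", nd_mlp(tokens).shape)  # expected (2, 196, 384)

    swiglu = SwiGLU(in_features=384, hidden_features=1536)
    print("SwiGLU:", swiglu(tokens).shape)  # expected (2, 196, 384)

    # GlobalResponseNormMlp with use_conv=False expects channels-last NHWC maps.
    nhwc = torch.randn(2, 14, 14, 384)
    grn_mlp = GlobalResponseNormMlp(in_features=384, hidden_features=1536)
    print("GlobalResponseNormMlp:", grn_mlp(nhwc).shape)  # expected (2, 14, 14, 384)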