# MCDP_MAMBA / mamba.py
# Uploaded by Abdullah-Nazhat — "Update mamba.py" (commit e787de7, verified).
import math
from dataclasses import dataclass
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from pscan import pscan
@dataclass
class MambaConfig:
    """Hyper-parameters for a Mamba model.

    ``d_inner`` and a concrete integer ``dt_rank`` are derived in
    ``__post_init__``.
    """

    d_model: int                        # model (embedding) width
    n_layers: int                       # number of residual Mamba blocks
    dt_rank: Union[int, str] = 'auto'   # rank of the dt projection; 'auto' -> ceil(d_model / 16)
    d_state: int = 16                   # SSM state dimension per channel
    expand_factor: int = 2              # d_inner = expand_factor * d_model
    d_conv: int = 4                     # depthwise conv kernel size
    dt_min: float = 0.001               # lower bound for dt sampling at init
    dt_max: float = 0.1                 # upper bound for dt sampling at init
    dt_init: str = "random"             # "random" or "constant" init for dt_proj.weight
    dt_scale: float = 1.0               # scale on the dt_proj.weight init std
    bias: bool = False                  # bias on in_proj / out_proj
    conv_bias: bool = True              # bias on the depthwise conv
    pscan: bool = True                  # use the parallel scan instead of the sequential loop
    # FIX: previously written without a type annotation, which made this a
    # plain class attribute instead of a dataclass field, so it could not be
    # set via the constructor. Appended last so existing positional-argument
    # call sites keep their meaning.
    dt_init_floor: float = 1e-4         # clamp floor for the sampled dt at init

    def __post_init__(self):
        # Derived sizes used throughout the model.
        self.d_inner = self.expand_factor * self.d_model
        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)
class Mamba(nn.Module):
    """A stack of ``n_layers`` residual Mamba blocks applied in sequence."""

    def __init__(self, config: MambaConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(ResidualBlock(config) for _ in range(config.n_layers))

    def forward(self, x):
        # Full-sequence pass: feed the output of each block into the next.
        for block in self.layers:
            x = block(x)
        return x

    def step(self, x, caches):
        # Single-token pass: advance every layer by one step, updating its
        # per-layer cache in place.
        for idx, block in enumerate(self.layers):
            x, caches[idx] = block.step(x, caches[idx])
        return x, caches
class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper: ``x + mixer(norm(x))``."""

    def __init__(self, config: MambaConfig):
        super().__init__()
        self.mixer = MambaBlock(config)
        self.norm = RMSNorm(config.d_model)

    def forward(self, x):
        # Normalize, mix, then add the residual branch.
        return x + self.mixer(self.norm(x))

    def step(self, x, cache):
        # Single-token variant; threads the mixer's recurrent cache through.
        mixed, cache = self.mixer.step(self.norm(x), cache)
        return x + mixed, cache
class MambaBlock(nn.Module):
    """Core Mamba mixer: causal depthwise conv + input-dependent selective SSM.

    Projects the input up to d_inner, runs a causal depthwise Conv1d followed
    by SiLU, applies the selective state-space scan, gates the result with a
    parallel SiLU branch, and projects back down to d_model.
    """

    def __init__(self, config: MambaConfig):
        super().__init__()
        self.config = config
        # One projection produces both the SSM branch and the gate branch.
        self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)
        # Depthwise (groups == channels) conv; padding=d_conv-1 pads both
        # sides, and forward() trims the right overhang ([:, :, :L]) so the
        # convolution is causal.
        self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
                                kernel_size=config.d_conv, bias=config.conv_bias,
                                groups=config.d_inner,
                                padding=config.d_conv - 1)
        # Maps x to the concatenated input-dependent SSM parameters
        # (delta: dt_rank, B: d_state, C: d_state), split apart in ssm().
        self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)
        # Low-rank up-projection of delta: dt_rank -> d_inner.
        self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)

        # Initialize dt_proj.weight with std scaled by dt_rank**-0.5.
        dt_init_std = config.dt_rank**-0.5 * config.dt_scale
        if config.dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif config.dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Sample dt log-uniformly in [dt_min, dt_max], clamp to the floor,
        # then store softplus^-1(dt) in the bias so that
        # F.softplus(dt_proj(...)) ~ dt at initialization.
        dt = torch.exp(
            torch.rand(config.d_inner) * (math.log(config.dt_max) - math.log(config.dt_min)) + math.log(config.dt_min)
        ).clamp(min=config.dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))  # inverse of softplus
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)

        # A rows are [1, ..., d_state] for every channel; stored as log so the
        # learned parameter stays positive, and negated in ssm() (A = -exp(A_log)).
        A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        # Per-channel scale for the direct skip connection y += D * x.
        self.D = nn.Parameter(torch.ones(config.d_inner))
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)

    def forward(self, x):
        # x: (batch, L, d_model) — assumed from in_proj; TODO confirm callers.
        _, L, _ = x.shape
        xz = self.in_proj(x)            # (batch, L, 2*d_inner)
        x, z = xz.chunk(2, dim=-1)      # SSM branch x, gate branch z
        x = x.transpose(1, 2)           # (batch, d_inner, L) for Conv1d
        x = self.conv1d(x)[:, :, :L]    # trim right padding -> causal
        x = x.transpose(1, 2)           # back to (batch, L, d_inner)
        x = F.silu(x)
        y = self.ssm(x)
        z = F.silu(z)
        output = y * z                  # SiLU-gated output
        output = self.out_proj(output)  # (batch, L, d_model)
        return output

    def ssm(self, x):
        # x: (batch, L, d_inner); returns y of the same shape.
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state), strictly negative
        D = self.D.float()
        deltaBC = self.x_proj(x)
        # Split the joint projection into the three SSM parameters.
        delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1)
        delta = F.softplus(self.dt_proj(delta))  # positive step sizes, (batch, L, d_inner)
        if self.config.pscan:
            y = self.selective_scan(x, delta, A, B, C, D)
        else:
            y = self.selective_scan_seq(x, delta, A, B, C, D)
        return y

    def selective_scan(self, x, delta, A, B, C, D):
        # Discretize the continuous parameters per time step:
        # deltaA = exp(delta * A), deltaB = delta * B (simplified Euler form).
        deltaA = torch.exp(delta.unsqueeze(-1) * A)    # (batch, L, d_inner, d_state)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (batch, L, d_inner, d_state)
        BX = deltaB * (x.unsqueeze(-1))
        # NOTE(review): pscan is project-local; presumably it computes the same
        # recurrence h_t = deltaA_t * h_{t-1} + BX_t in parallel as
        # selective_scan_seq does sequentially — confirm against pscan.py.
        hs = pscan(deltaA, BX)
        # Readout: y_t = C_t . h_t, contracting the d_state dimension.
        y = (hs @ C.unsqueeze(-1)).squeeze(3)
        y = y + D * x  # direct skip connection
        return y

    def selective_scan_seq(self, x, delta, A, B, C, D):
        # Sequential reference implementation of the scan (pscan disabled).
        _, L, _ = x.shape
        deltaA = torch.exp(delta.unsqueeze(-1) * A)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)
        BX = deltaB * (x.unsqueeze(-1))
        # Zero initial state: (batch, d_inner, d_state).
        h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device)
        hs = []
        for t in range(0, L):
            # Linear recurrence with time-varying coefficients.
            h = deltaA[:, t] * h + BX[:, t]
            hs.append(h)
        hs = torch.stack(hs, dim=1)  # (batch, L, d_inner, d_state)
        y = (hs @ C.unsqueeze(-1)).squeeze(3)
        y = y + D * x
        return y

    def step(self, x, cache):
        """Single-token inference step.

        x: (batch, d_model) — one token (chunk on dim=1 below implies 2-D input).
        cache: (h, inputs) where h is the SSM hidden state (None is handled by
        ssm_step) and inputs holds the last d_conv-1 conv inputs,
        shape (batch, d_inner, d_conv-1).
        """
        h, inputs = cache
        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=1)   # (batch, d_inner) each
        x_cache = x.unsqueeze(2)    # (batch, d_inner, 1)
        # Slide the new token into the conv window and keep only the causal
        # output position (index d_conv-1) of the padded convolution.
        x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1]
        x = F.silu(x)
        y, h = self.ssm_step(x, h)
        z = F.silu(z)
        output = y * z
        output = self.out_proj(output)
        # Roll the conv input window forward by one token.
        inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2)
        cache = (h, inputs)
        return output, cache

    def ssm_step(self, x, h):
        # x: (batch, d_inner); h: (batch, d_inner, d_state) or None (first step).
        A = -torch.exp(self.A_log.float())
        D = self.D.float()
        deltaBC = self.x_proj(x)
        delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1)
        delta = F.softplus(self.dt_proj(delta))
        deltaA = torch.exp(delta.unsqueeze(-1) * A)    # (batch, d_inner, d_state)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(1)
        BX = deltaB * (x.unsqueeze(-1))
        if h is None:
            # Lazily create the zero state on the right device.
            h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device)
        h = deltaA * h + BX  # one step of the recurrence
        y = (h @ C.unsqueeze(-1)).squeeze(2)
        y = y + D * x
        # NOTE(review): squeeze(1) has no effect unless d_inner == 1, since h
        # is (batch, d_inner, d_state); kept as-is — verify intent.
        return y, h.squeeze(1)
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: x / rms(x), scaled by a learned weight."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        # eps keeps rsqrt finite for an all-zero feature vector.
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # Normalize by the RMS over the last (feature) dimension.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return self.weight * (x * inv_rms)