import torch
from torch import nn


class MappingUnit(nn.Module):
    """Gated feed-forward unit: LayerNorm, then two parallel linear
    projections combined by an element-wise (GLU-style) gate, followed
    by an output projection."""

    def __init__(self, dim):
        super().__init__()
        self.norm_token = nn.LayerNorm(dim)
        self.proj_1 = nn.Linear(dim, dim, bias=False)
        self.proj_2 = nn.Linear(dim, dim, bias=False)
        self.proj_3 = nn.Linear(dim, dim, bias=False)
        self.gelu = nn.GELU()

    def forward(self, x):
        x = self.norm_token(x)
        u = self.gelu(self.proj_1(x))  # gate branch
        v = self.proj_2(x)             # value branch
        return self.proj_3(u * v)      # gated product, then output projection


class InteractionUnit(nn.Module):
    """Normalises the input, flattens the token and feature dimensions,
    applies a GELU non-linearity, and restores the original shape."""

    def __init__(self, dim):
        super().__init__()
        self.norm_token = nn.LayerNorm(dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        x = self.norm_token(x)
        batch, tokens, features = x.shape
        # GELU acts element-wise, so the round-trip reshape preserves the
        # layout of the activations.
        x = x.reshape(batch, tokens * features)
        x = self.gelu(x)
        return x.reshape(batch, tokens, features)


class InteractorBlock(nn.Module):
    """One block: an InteractionUnit followed by a MappingUnit, each
    wrapped in a residual (skip) connection."""

    def __init__(self, d_model):
        super().__init__()
        self.mapping = MappingUnit(d_model)
        self.interaction = InteractionUnit(d_model)

    def forward(self, x):
        x = x + self.interaction(x)  # residual around the interaction unit
        return x + self.mapping(x)   # residual around the mapping unit


class Interactor(nn.Module):
    """A stack of num_layers InteractorBlocks applied in sequence."""

    def __init__(self, d_model, num_layers):
        super().__init__()
        self.model = nn.Sequential(
            *[InteractorBlock(d_model) for _ in range(num_layers)]
        )

    def forward(self, x):
        return self.model(x)
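

# A minimal usage sketch (assumed shapes, not part of the original file):
# every block is shape-preserving, so the output matches the input shape.
if __name__ == "__main__":
    model = Interactor(d_model=64, num_layers=2)
    x = torch.randn(8, 16, 64)  # (batch, tokens, d_model), example values
    out = model(x)
    print(out.shape)  # torch.Size([8, 16, 64])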