File size: 1,919 Bytes
c6195c6
 
 
 
 
 
 
5c2da1c
c6195c6
 
 
e59cbdd
 
d3cadd1
 
 
075a642
c6195c6
 
 
e59cbdd
 
 
075a642
e59cbdd
 
 
075a642
c6195c6
 
 
 
ae032db
c6195c6
075a642
c6195c6
075a642
c6195c6
 
 
 
075a642
 
 
 
 
 
 
 
c6195c6
 
 
 
4b6d4a9
c6195c6
 
 
5c2da1c
205ae8e
c6195c6
 
 
 
 
 
 
 
 
 
 
5c2da1c
c6195c6
 
 
 
 
 
 
 
 
 
4b6d4a9
c6195c6
 
 
4b6d4a9
c6195c6
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from torch import nn


       
 

class MappingUnit(nn.Module):
    """Gated feed-forward unit applied independently at every position.

    Computes ``proj_3(GELU(proj_1(h)) * proj_2(h))`` where ``h = LayerNorm(x)``.
    All three projections are bias-free ``dim -> dim`` linear maps, so the
    output has the same shape as the input.

    Args:
        dim: size of the last (feature) dimension of the input.
    """

    def __init__(self, dim):
        super().__init__()
        # Registration order determines the state_dict key order and the
        # order in which the Linear weights are drawn from the global RNG
        # at construction time, so it is kept exactly as before.
        self.norm_token = nn.LayerNorm(dim)
        self.proj_1 = nn.Linear(dim, dim, bias=False)
        self.proj_2 = nn.Linear(dim, dim, bias=False)
        self.proj_3 = nn.Linear(dim, dim, bias=False)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Return the gated projection of ``x``; the shape is preserved."""
        normed = self.norm_token(x)
        # Two parallel branches over the same normalised input: one passes
        # through GELU and acts as a gate on the other, linear branch.
        gate = self.gelu(self.proj_1(normed))
        value = self.proj_2(normed)
        return self.proj_3(gate * value)

class InteractionUnit(nn.Module):
    """LayerNorm over the last dimension followed by an element-wise GELU.

    Accepts inputs of any rank whose last dimension has size ``dim``; the
    output has the same shape as the input.

    NOTE(review): despite its name, this unit performs no cross-token mixing.
    ``nn.LayerNorm(dim)`` normalises each position over its last dimension
    and GELU acts element-wise. If token interaction was intended, a learned
    projection appears to be missing; confirm with the author.

    Args:
        dim: size of the last (feature) dimension of the input.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm_token = nn.LayerNorm(dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Normalise ``x`` over its last dimension, then apply GELU."""
        # No reshape is needed around GELU: it is element-wise, so
        # flattening trailing axes before it (and restoring them after)
        # cannot change the result. Skipping it lets the unit work on
        # inputs of any rank instead of exactly three dimensions.
        return self.gelu(self.norm_token(x))

class InteractorBlock(nn.Module):
    """One Interactor layer: an interaction sub-unit, then a mapping sub-unit,
    each wrapped in a residual connection.

    Both sub-units apply their own LayerNorm to their input, so the block has
    a pre-norm residual layout::

        y   = interaction(x) + x
        out = mapping(y) + y

    Args:
        d_model: size of the last (feature) dimension of the input.
    """

    def __init__(self, d_model):
        super().__init__()
        # Registration order (mapping first, then interaction) fixes the
        # state_dict key order; keep it stable.
        self.mapping = MappingUnit(d_model)
        self.interaction = InteractionUnit(d_model)

    def forward(self, x):
        """Apply both residual sub-units; the output has the shape of ``x``."""
        mixed = self.interaction(x) + x
        return self.mapping(mixed) + mixed



class Interactor(nn.Module):
    """Stack of ``num_layers`` InteractorBlock layers applied in sequence.

    Args:
        d_model: feature size passed to every block.
        num_layers: number of stacked blocks. With 0 the stack is an empty
            ``nn.Sequential``, which returns its input unchanged.
    """

    def __init__(self, d_model, num_layers):
        super().__init__()
        blocks = [InteractorBlock(d_model) for _ in range(num_layers)]
        # Stored under the attribute name ``model`` so state_dict keys
        # ("model.<i>.…") stay unchanged.
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every block in order and return the result."""
        output = self.model(x)
        return output