Update interactor.py
Browse files- interactor.py +3 -3
interactor.py
CHANGED
|
@@ -5,7 +5,7 @@ from torch import nn
|
|
| 5 |
|
| 6 |
|
| 7 |
|
| 8 |
-
class
|
| 9 |
def __init__(self,dim):
|
| 10 |
super().__init__()
|
| 11 |
|
|
@@ -55,7 +55,7 @@ class InteractorBlock(nn.Module):
|
|
| 55 |
super().__init__()
|
| 56 |
|
| 57 |
|
| 58 |
-
self.
|
| 59 |
self.interaction = InteractionUnit(d_model)
|
| 60 |
|
| 61 |
def forward(self, x):
|
|
@@ -68,7 +68,7 @@ class InteractorBlock(nn.Module):
|
|
| 68 |
|
| 69 |
residual = x
|
| 70 |
|
| 71 |
-
x = self.
|
| 72 |
|
| 73 |
|
| 74 |
out = x + residual
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
|
| 8 |
+
class MappingUnit(nn.Module):
|
| 9 |
def __init__(self,dim):
|
| 10 |
super().__init__()
|
| 11 |
|
|
|
|
| 55 |
super().__init__()
|
| 56 |
|
| 57 |
|
| 58 |
+
self.mapping = MappingUnit(d_model)
|
| 59 |
self.interaction = InteractionUnit(d_model)
|
| 60 |
|
| 61 |
def forward(self, x):
|
|
|
|
| 68 |
|
| 69 |
residual = x
|
| 70 |
|
| 71 |
+
x = self.mapping(x)
|
| 72 |
|
| 73 |
|
| 74 |
out = x + residual
|