import math

import torch
import torch.nn as nn
from einops import repeat

class BridgeAttentionPolicy(nn.Module):
    """Fuses vision, text, and state features in a shared policy space via learned
    query tokens and a Transformer encoder, then maps the pooled queries to actions."""

    def __init__(self, v_hidden, t_hidden, state_dim, policy_dim, n_heads, n_layers, n_queries, action_dim, dropout=0.1):
        super().__init__()
        self.n_queries = n_queries
        # Learned query tokens that attend over the fused context and are pooled for the action head.
        self.query = nn.Parameter(torch.randn(n_queries, policy_dim) / math.sqrt(policy_dim))
        # Per-modality projections into the shared policy dimension.
        self.v_proj = nn.Linear(v_hidden, policy_dim)
        self.t_proj = nn.Linear(t_hidden, policy_dim)
        self.s_proj = nn.Linear(state_dim, policy_dim)
        # Learnable gates (passed through a sigmoid in forward) scaling each modality's contribution.
        self.alpha_v = nn.Parameter(torch.tensor(0.7))
        self.alpha_t = nn.Parameter(torch.tensor(0.7))
        self.alpha_s = nn.Parameter(torch.tensor(0.7))
        enc = nn.TransformerEncoderLayer(d_model=policy_dim, nhead=n_heads, dim_feedforward=policy_dim * 4,
                                         dropout=dropout, activation="gelu", batch_first=True, norm_first=True)
        self.blocks = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.norm = nn.LayerNorm(policy_dim)
        self.head = nn.Sequential(nn.Linear(policy_dim, policy_dim), nn.GELU(), nn.Linear(policy_dim, action_dim))

    def forward(self, v_feats_layers, t_feats_layers, state_vec):
        # v_feats_layers / t_feats_layers: lists of (B, seq, hidden) feature maps from the
        # vision and text encoders; state_vec: (B, state_dim) state vector.
        B = state_vec.size(0)
        v_cat = torch.cat(v_feats_layers, dim=1) if v_feats_layers else None
        t_cat = torch.cat(t_feats_layers, dim=1)
        # Project each modality into the policy dimension and scale it by its sigmoid gate.
        s_tok = self.s_proj(state_vec).unsqueeze(1) * torch.sigmoid(self.alpha_s)
        toks = [s_tok]
        if v_cat is not None:
            toks.append(self.v_proj(v_cat) * torch.sigmoid(self.alpha_v))
        toks.append(self.t_proj(t_cat) * torch.sigmoid(self.alpha_t))
        ctx = torch.cat(toks, dim=1)
        # Prepend the learned queries so they can attend over the fused context in the encoder.
        q = repeat(self.query, 'Q D -> B Q D', B=B)
        tokens = torch.cat([q, ctx], dim=1)
        tokens = self.blocks(tokens)
        # Mean-pool the query tokens and map to the action space.
        pooled = self.norm(tokens[:, :self.n_queries].mean(dim=1))
        return self.head(pooled)
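

# Minimal usage sketch (illustrative only): the dimensions, layer counts, and dummy
# feature tensors below are assumptions for demonstration, not values from the original source.
if __name__ == "__main__":
    policy = BridgeAttentionPolicy(
        v_hidden=768, t_hidden=512, state_dim=32, policy_dim=256,
        n_heads=8, n_layers=4, n_queries=8, action_dim=7,
    )
    B = 2
    v_feats = [torch.randn(B, 196, 768) for _ in range(2)]  # e.g. two vision-encoder layer outputs
    t_feats = [torch.randn(B, 24, 512)]                     # e.g. one text-encoder layer output
    state = torch.randn(B, 32)
    actions = policy(v_feats, t_feats, state)
    print(actions.shape)  # torch.Size([2, 7])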