# Dual cross-attention fusion head for emotion regression
# (best MAE 2.594; commit 2f3d362).
import torch
import torch.nn as nn
import torch.nn.functional as F
class DualCrossAttentionFusion(nn.Module):
    """Fuses a semantic stream S and an acoustic stream A with symmetric
    cross-attention: S attends to A and A attends to S."""

    def __init__(self, d_sem: int, d_ac: int, d: int, n_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.sem_proj = nn.Linear(d_sem, d)
        self.ac_proj = nn.Linear(d_ac, d)
        self.ln_s = nn.LayerNorm(d)
        self.ln_a = nn.LayerNorm(d)
        self.att_sa = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
        self.att_as = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
        self.out_proj = nn.Linear(d * 4, d)

    def forward(self, S: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        # Project both streams into the shared fusion space and normalize.
        S_ = self.ln_s(self.sem_proj(S))
        A_ = self.ln_a(self.ac_proj(A))
        # Zero-pad the shorter stream along time so both have length T.
        T = max(S_.size(1), A_.size(1))
        if S_.size(1) != T:
            S_ = F.pad(S_, (0, 0, 0, T - S_.size(1)))
        if A_.size(1) != T:
            A_ = F.pad(A_, (0, 0, 0, T - A_.size(1)))
        # Symmetric cross-attention: each stream queries the other.
        att_s, _ = self.att_sa(query=S_, key=A_, value=A_)
        att_a, _ = self.att_as(query=A_, key=S_, value=S_)
        # Concatenate both streams and both attention outputs, project back to d.
        fused = torch.cat([S_, A_, att_s, att_a], dim=-1)
        return self.out_proj(fused)
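
# A minimal shape check for the fusion block (illustrative sketch, not part of
# the original module; the feature dims and batch/sequence sizes below are
# assumptions chosen only to exercise the padding path).
def _demo_fusion() -> None:
    fusion = DualCrossAttentionFusion(d_sem=768, d_ac=1024, d=512)
    S = torch.randn(2, 50, 768)   # semantic stream: (B, T_s, d_sem)
    A = torch.randn(2, 40, 1024)  # acoustic stream, shorter on purpose
    out = fusion(S, A)
    # The shorter stream is zero-padded to T = max(50, 40) = 50.
    assert out.shape == (2, 50, 512)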
class FusedToLLMTokens(nn.Module):
    """Compresses a fused (B, T, D) sequence into a fixed number of tokens by
    mean-pooling T into num_tokens contiguous segments, then projecting each
    pooled segment into the LLM's hidden size."""

    def __init__(self, d_in: int, d_llm: int, num_tokens: int = 8):
        super().__init__()
        self.num_tokens = num_tokens
        self.proj = nn.Linear(d_in, d_llm)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        B, T, D = feats.shape
        # Segment boundaries: num_tokens roughly equal slices of the time axis.
        splits = torch.linspace(0, T, steps=self.num_tokens + 1, device=feats.device).long()
        toks = []
        for i in range(self.num_tokens):
            s, e = splits[i].item(), splits[i + 1].item()
            # Fall back to the last frame if a segment is empty (T < num_tokens).
            seg = feats[:, s:e, :] if e > s else feats[:, -1:, :]
            toks.append(seg.mean(dim=1))
        X = torch.stack(toks, dim=1)  # (B, num_tokens, D)
        return self.proj(X)           # (B, num_tokens, d_llm)
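
# Quick illustration of the pooling behavior (sketch; dims are assumptions):
# T fused frames are split into num_tokens roughly equal segments and
# mean-pooled, so the output length is fixed regardless of the input length.
def _demo_f2tok() -> None:
    f2tok = FusedToLLMTokens(d_in=512, d_llm=4096, num_tokens=8)
    fused = torch.randn(2, 50, 512)
    toks = f2tok(fused)
    assert toks.shape == (2, 8, 4096)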
class EAAEmotionRegressor(nn.Module):
    """Full head: fuse semantic/acoustic features, map them to a handful of
    LLM-space tokens, prepend a learned regression token, run the LLM, and
    regress a scalar from the regression token's final hidden state."""

    def __init__(self, d_sem: int, d_ac: int, llm_hidden: int,
                 fusion_dim: int = 512, n_heads: int = 4, num_audio_tokens: int = 8):
        super().__init__()
        self.fusion = DualCrossAttentionFusion(d_sem, d_ac, d=fusion_dim, n_heads=n_heads)
        self.f2tok = FusedToLLMTokens(d_in=fusion_dim, d_llm=llm_hidden, num_tokens=num_audio_tokens)
        self.reg_token = nn.Parameter(torch.randn(1, llm_hidden) * 0.02)
        self.reg_head = nn.Sequential(
            nn.Linear(llm_hidden, llm_hidden // 2),
            nn.ReLU(),
            nn.Linear(llm_hidden // 2, 1),
        )

    def forward(self, sem_feats: torch.Tensor, ac_feats: torch.Tensor, llm,
                input_ids=None, inputs_embeds=None) -> torch.Tensor:
        fused = self.fusion(sem_feats, ac_feats)  # (B, T, fusion_dim)
        audio_tokens = self.f2tok(fused)          # (B, N, llm_hidden)
        B = audio_tokens.size(0)
        reg = self.reg_token.unsqueeze(0).expand(B, -1, -1)  # (B, 1, llm_hidden)
        if inputs_embeds is None and input_ids is not None:
            txt_embeds = llm.get_input_embeddings()(input_ids)
        else:
            txt_embeds = inputs_embeds
        if txt_embeds is None:
            raise ValueError("Provide either input_ids or inputs_embeds.")
        # Sequence layout fed to the LLM: [REG] + text tokens + audio tokens.
        llm_in = torch.cat([reg, txt_embeds, audio_tokens], dim=1)
        out = llm(inputs_embeds=llm_in, output_hidden_states=True)
        # The regression token sits at position 0; read its last-layer state.
        reg_state = out.hidden_states[-1][:, 0, :]
        return self.reg_head(reg_state).squeeze(-1)  # (B,)
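
# End-to-end smoke test with a stand-in LLM (a sketch under assumptions: a real
# run would pass a HuggingFace causal LM exposing get_input_embeddings() and
# hidden_states; the tiny module below only mimics that surface, and all dims
# are made up for the test).
if __name__ == "__main__":
    from types import SimpleNamespace

    class _TinyMockLLM(nn.Module):
        def __init__(self, vocab_size: int = 100, hidden: int = 64):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, hidden)
            self.body = nn.Linear(hidden, hidden)

        def get_input_embeddings(self):
            return self.emb

        def forward(self, inputs_embeds, output_hidden_states=False):
            h = torch.tanh(self.body(inputs_embeds))
            # Mirror the HF output contract just enough for the regressor.
            return SimpleNamespace(hidden_states=(inputs_embeds, h))

    llm = _TinyMockLLM(hidden=64)
    model = EAAEmotionRegressor(d_sem=768, d_ac=1024, llm_hidden=64,
                                fusion_dim=128, num_audio_tokens=4)
    sem = torch.randn(2, 50, 768)
    ac = torch.randn(2, 40, 1024)
    ids = torch.randint(0, 100, (2, 12))
    print(model(sem, ac, llm=llm, input_ids=ids).shape)  # torch.Size([2])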