import torch
import torch.nn as nn
import torch.nn.functional as F


class DualCrossAttentionFusion(nn.Module):
    """Fuses a semantic feature sequence S and an acoustic feature sequence A
    with two cross-attention streams: S attending to A, and A attending to S."""

    def __init__(self, d_sem: int, d_ac: int, d: int, n_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        # Project both modalities into a shared d-dimensional space.
        self.sem_proj = nn.Linear(d_sem, d)
        self.ac_proj = nn.Linear(d_ac, d)
        self.ln_s = nn.LayerNorm(d)
        self.ln_a = nn.LayerNorm(d)
        # Cross-attention in both directions.
        self.att_sa = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
        self.att_as = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
        # The fused representation concatenates [S, A, S->A, A->S], hence 4*d inputs.
        self.out_proj = nn.Linear(d * 4, d)

    def forward(self, S, A):
        # S: (B, T_s, d_sem), A: (B, T_a, d_ac)
        S_ = self.ln_s(self.sem_proj(S))
        A_ = self.ln_a(self.ac_proj(A))
        # Zero-pad the shorter sequence along time so both streams share length T.
        # Padded positions are not masked in the attention calls below.
        T = max(S_.size(1), A_.size(1))
        if S_.size(1) != T:
            S_ = F.pad(S_, (0, 0, 0, T - S_.size(1)))
        if A_.size(1) != T:
            A_ = F.pad(A_, (0, 0, 0, T - A_.size(1)))
        att_s, _ = self.att_sa(query=S_, key=A_, value=A_)  # semantic queries acoustic
        att_a, _ = self.att_as(query=A_, key=S_, value=S_)  # acoustic queries semantic
        fused = torch.cat([S_, A_, att_s, att_a], dim=-1)   # (B, T, 4*d)
        return self.out_proj(fused)                         # (B, T, d)


class FusedToLLMTokens(nn.Module):
    """Pools a fused feature sequence into a fixed number of tokens and
    projects them into the LLM embedding space."""

    def __init__(self, d_in: int, d_llm: int, num_tokens: int = 8):
        super().__init__()
        self.num_tokens = num_tokens
        self.proj = nn.Linear(d_in, d_llm)

    def forward(self, fused):
        # fused: (B, T, d_in)
        B, T, D = fused.shape
        # Split the time axis into num_tokens roughly equal segments.
        splits = torch.linspace(0, T, steps=self.num_tokens + 1, device=fused.device).long()
        toks = []
        for i in range(self.num_tokens):
            s, e = splits[i].item(), splits[i + 1].item()
            # Mean-pool each segment; fall back to the last frame if a segment is empty (T < num_tokens).
            seg = fused[:, s:e, :] if e > s else fused[:, -1:, :]
            toks.append(seg.mean(dim=1))
        X = torch.stack(toks, dim=1)  # (B, num_tokens, d_in)
        return self.proj(X)           # (B, num_tokens, d_llm)


class EAAEmotionRegressor(nn.Module):
    """Fuses semantic and acoustic features, maps them to a small set of audio
    tokens in the LLM embedding space, and regresses a scalar emotion score
    from the LLM hidden state at a learned regression token."""

    def __init__(self, d_sem: int, d_ac: int, llm_hidden: int,
                 fusion_dim: int = 512, n_heads: int = 4, num_audio_tokens: int = 8):
        super().__init__()
        self.fusion = DualCrossAttentionFusion(d_sem, d_ac, d=fusion_dim, n_heads=n_heads)
        self.f2tok = FusedToLLMTokens(d_in=fusion_dim, d_llm=llm_hidden, num_tokens=num_audio_tokens)
        # Learned regression token, prepended to the LLM input sequence.
        self.reg_token = nn.Parameter(torch.randn(1, llm_hidden) * 0.02)
        self.reg_head = nn.Sequential(
            nn.Linear(llm_hidden, llm_hidden // 2),
            nn.ReLU(),
            nn.Linear(llm_hidden // 2, 1),
        )

    def forward(self, sem_feats: torch.Tensor, ac_feats: torch.Tensor, llm,
                input_ids=None, inputs_embeds=None):
        fused = self.fusion(sem_feats, ac_feats)
        audio_tokens = self.f2tok(fused)                     # (B, num_audio_tokens, llm_hidden)
        B = audio_tokens.size(0)
        reg = self.reg_token.unsqueeze(0).expand(B, -1, -1)  # (B, 1, llm_hidden)
        # Embed the text prompt with the LLM's own embedding table unless
        # precomputed embeddings are supplied.
        if inputs_embeds is None and input_ids is not None:
            txt_embeds = llm.get_input_embeddings()(input_ids)
        else:
            txt_embeds = inputs_embeds
        # Sequence layout: [regression token, text tokens, audio tokens];
        # the text block is optional.
        parts = [reg]
        if txt_embeds is not None:
            parts.append(txt_embeds)
        parts.append(audio_tokens)
        llm_in = torch.cat(parts, dim=1)
        out = llm(inputs_embeds=llm_in, output_hidden_states=True)
        # Read out the final-layer hidden state at the regression token (position 0).
        # Note: with a strictly causal LLM, position 0 cannot attend to later tokens,
        # so this readout assumes the backbone attends to the full sequence.
        reg_state = out.hidden_states[-1][:, 0, :]
        return self.reg_head(reg_state).squeeze(-1)          # (B,)
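

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the model above).
# _TinyStubLLM is a hypothetical stand-in that only mimics the interface the
# regressor expects from an LLM backbone: get_input_embeddings() and a forward
# that accepts inputs_embeds / output_hidden_states and exposes .hidden_states
# (similar to what Hugging Face transformers models return). All dimensions,
# sequence lengths, and the vocabulary size below are arbitrary example values.
# ---------------------------------------------------------------------------
from types import SimpleNamespace


class _TinyStubLLM(nn.Module):
    def __init__(self, vocab_size: int = 100, hidden: int = 256, n_layers: int = 2, n_heads: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        layer = nn.TransformerEncoderLayer(hidden, n_heads, dim_feedforward=4 * hidden, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

    def get_input_embeddings(self):
        return self.embed

    def forward(self, inputs_embeds, output_hidden_states: bool = False):
        h = self.encoder(inputs_embeds)
        # Mimic the (embeddings, ..., last layer) hidden_states tuple convention.
        return SimpleNamespace(hidden_states=(inputs_embeds, h), last_hidden_state=h)


if __name__ == "__main__":
    torch.manual_seed(0)
    B, T_sem, T_ac = 2, 20, 50
    d_sem, d_ac, llm_hidden = 768, 1024, 256   # example feature/hidden sizes only

    llm = _TinyStubLLM(hidden=llm_hidden)
    model = EAAEmotionRegressor(d_sem=d_sem, d_ac=d_ac, llm_hidden=llm_hidden)

    sem = torch.randn(B, T_sem, d_sem)         # dummy semantic features
    ac = torch.randn(B, T_ac, d_ac)            # dummy acoustic features
    ids = torch.randint(0, 100, (B, 12))       # dummy text prompt ids

    with torch.no_grad():
        scores = model(sem, ac, llm=llm, input_ids=ids)
    print(scores.shape)  # expected: torch.Size([2])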