import torch
import torch.nn as nn
import torch.nn.functional as F


class DualCrossAttentionFusion(nn.Module):
    """Fuse a semantic and an acoustic feature sequence with bidirectional cross-attention.

    Both inputs are linearly projected to a shared width ``d`` and layer-normalised.
    If the two sequences have different lengths, the shorter one is zero-padded at
    the end so the streams can be concatenated position-wise. The semantic stream
    attends to the acoustic one and vice versa. Padded positions are masked out as
    attention keys. ``[S, A, S->A attention, A->S attention]`` is then projected
    back to width ``d``.

    Args:
        d_sem: Feature width of the semantic input.
        d_ac: Feature width of the acoustic input.
        d: Shared model width (``nn.MultiheadAttention`` requires it to be
            divisible by ``n_heads``).
        n_heads: Number of heads in each cross-attention layer.
        dropout: Attention dropout probability.
    """

    def __init__(self, d_sem: int, d_ac: int, d: int, n_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.sem_proj = nn.Linear(d_sem, d)
        self.ac_proj = nn.Linear(d_ac, d)
        self.ln_s = nn.LayerNorm(d)
        self.ln_a = nn.LayerNorm(d)
        # query = semantic, key/value = acoustic
        self.att_sa = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
        # query = acoustic, key/value = semantic
        self.att_as = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
        self.out_proj = nn.Linear(d * 4, d)

    @staticmethod
    def _pad_to(x, length):
        """Zero-pad ``x`` along dim 1 to ``length``; return it with a key-padding mask or None."""
        n = x.size(1)
        if n == length:
            return x, None
        x = F.pad(x, (0, 0, 0, length - n))
        if n == 0:
            # A fully-masked key row would make the softmax NaN; nothing real to attend to anyway.
            return x, None
        # True marks a padded (ignored) key position, as nn.MultiheadAttention expects.
        mask = torch.arange(length, device=x.device) >= n
        return x, mask.unsqueeze(0).expand(x.size(0), -1)

    def forward(self, S, A):
        """Fuse the two streams.

        Args:
            S: Semantic features, shape (B, T_s, d_sem). The attention layers are
                built with ``batch_first=True`` and padding is applied to dim 1.
            A: Acoustic features, shape (B, T_a, d_ac).

        Returns:
            Tensor of shape (B, max(T_s, T_a), d).
        """
        S_ = self.ln_s(self.sem_proj(S))
        A_ = self.ln_a(self.ac_proj(A))
        T = max(S_.size(1), A_.size(1))
        S_, s_pad = self._pad_to(S_, T)
        A_, a_pad = self._pad_to(A_, T)
        # The padded rows are artificial: mask them as keys so they do not absorb
        # softmax weight the way real frames would.
        att_s, _ = self.att_sa(query=S_, key=A_, value=A_, key_padding_mask=a_pad)
        att_a, _ = self.att_as(query=A_, key=S_, value=S_, key_padding_mask=s_pad)
        fused = torch.cat([S_, A_, att_s, att_a], dim=-1)
        return self.out_proj(fused)


class FusedToLLMTokens(nn.Module):
    """Pool a (B, T, d_in) sequence into ``num_tokens`` tokens projected to ``d_llm``.

    The time axis is split into ``num_tokens`` contiguous segments with boundaries
    ``floor(i * T / num_tokens)``. Each segment is mean-pooled, and the pooled
    vectors are linearly projected.
    """

    def __init__(self, d_in: int, d_llm: int, num_tokens: int = 8):
        super().__init__()
        self.num_tokens = num_tokens
        self.proj = nn.Linear(d_in, d_llm)

    def forward(self, F):
        """Return pooled tokens of shape (B, num_tokens, d_llm).

        Raises:
            ValueError: if ``F`` is not 3-D or has an empty time axis.
        """
        # NOTE: the parameter name ``F`` shadows torch.nn.functional inside this
        # method; it is kept so keyword callers (``forward(F=...)``) keep working.
        _, T, _ = F.shape
        if T == 0:
            raise ValueError("FusedToLLMTokens requires a non-empty time axis (T == 0)")
        n = self.num_tokens
        # Exact integer boundaries: no float rounding and no per-segment .item() host sync.
        bounds = [i * T // n for i in range(n + 1)]
        toks = []
        for s, e in zip(bounds[:-1], bounds[1:]):
            # When T < n some segments are empty (s == e). Such a segment's fractional
            # span lies inside frame s (s < T always holds), so use that frame instead of
            # the last one. This keeps the tokens in temporal order.
            toks.append(F[:, s:max(e, s + 1), :].mean(dim=1))
        X = torch.stack(toks, dim=1)  # (B, N, d_in)
        return self.proj(X)  # (B, N, d_llm)


class EAAEmotionRegressor(nn.Module):
    """Predict one scalar per example from semantic + acoustic features through an LLM.

    Pipeline:
        1. Dual cross-attention fusion of the two feature streams.
        2. ``num_audio_tokens`` pooled audio tokens at the LLM width.
        3. The sequence ``[text embeddings (optional), audio tokens, learned
           regression token]`` is fed to ``llm`` as ``inputs_embeds``.
        4. The final-layer hidden state at the regression token goes through an MLP head.

    Args:
        d_sem: Semantic feature width.
        d_ac: Acoustic feature width.
        llm_hidden: Hidden width of the LLM the tokens are fed into.
        fusion_dim: Width of the fusion module.
        n_heads: Heads per cross-attention layer in the fusion module.
        num_audio_tokens: Number of pooled audio tokens given to the LLM.
    """

    def __init__(self, d_sem: int, d_ac: int, llm_hidden: int, fusion_dim=512, n_heads=4, num_audio_tokens=8):
        super().__init__()
        self.fusion = DualCrossAttentionFusion(d_sem, d_ac, d=fusion_dim, n_heads=n_heads)
        self.f2tok = FusedToLLMTokens(d_in=fusion_dim, d_llm=llm_hidden, num_tokens=num_audio_tokens)
        self.reg_token = nn.Parameter(torch.randn(1, llm_hidden) * 0.02)
        self.reg_head = nn.Sequential(
            nn.Linear(llm_hidden, llm_hidden // 2),
            nn.ReLU(),
            nn.Linear(llm_hidden // 2, 1),
        )

    def forward(
        self,
        sem_feats: torch.Tensor,
        ac_feats: torch.Tensor,
        llm,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
    ):
        """Return predictions of shape (B,).

        Args:
            sem_feats: Semantic features (B, T_s, d_sem).
            ac_feats: Acoustic features (B, T_a, d_ac).
            llm: Model exposing ``get_input_embeddings()`` and accepting
                ``inputs_embeds=`` / ``output_hidden_states=True``. The output must
                have ``.hidden_states``. This is presumably a Hugging Face
                transformers model; confirm against the caller.
            input_ids: Optional text token ids (B, L). Ignored if ``inputs_embeds``
                is given.
            inputs_embeds: Optional precomputed text embeddings (B, L, llm_hidden).
            attention_mask: Optional text mask (B, L). It is extended with ones for
                the appended audio and regression tokens. It is used only when text
                is present.
        """
        fused = self.fusion(sem_feats, ac_feats)  # (B, T, fusion_dim)
        audio_tokens = self.f2tok(fused)  # (B, N, llm_hidden)
        batch = audio_tokens.size(0)
        reg = self.reg_token.unsqueeze(0).expand(batch, -1, -1)  # (B, 1, llm_hidden)

        # inputs_embeds takes precedence over input_ids.
        if inputs_embeds is None and input_ids is not None:
            txt_embeds = llm.get_input_embeddings()(input_ids)
        else:
            txt_embeds = inputs_embeds

        llm_kwargs = {}
        # The regression token goes LAST. Under causal (decoder-only) attention a token
        # at position 0 sees only itself, so its state would ignore text and audio. The
        # last position sees the whole sequence under causal and bidirectional attention.
        if txt_embeds is None:
            llm_in = torch.cat([audio_tokens, reg], dim=1)
        else:
            # Match the LLM's embedding dtype (e.g. bf16) to avoid dtype mismatches.
            dtype = txt_embeds.dtype
            llm_in = torch.cat([txt_embeds, audio_tokens.to(dtype), reg.to(dtype)], dim=1)
            if attention_mask is not None:
                tail = attention_mask.new_ones(batch, audio_tokens.size(1) + 1)
                llm_kwargs["attention_mask"] = torch.cat([attention_mask, tail], dim=1)

        out = llm(inputs_embeds=llm_in, output_hidden_states=True, **llm_kwargs)
        reg_state = out.hidden_states[-1][:, -1, :]
        # Bring the hidden state back to the head's dtype (LLM may run in lower precision).
        head_dtype = self.reg_head[0].weight.dtype
        return self.reg_head(reg_state.to(head_dtype)).squeeze(-1)