Add fusion head + config (best MAE 2.594)
- README.md +40 -0
- eaa_config.json +23 -0
- fusion_head.pt +3 -0
- modeling_eaa.py +66 -0
README.md
ADDED
@@ -0,0 +1,40 @@
# EAA Fusion Head for Gemma (LoRA) + w2v-bert-2.0 + emotion2vec

This repo hosts the **fusion head** weights and code for the Emotion-Aware Audio LLM.

- LoRA adapter lives at: **marccgrau/eaa-gemma3-270m-adapter**
- Upstream encoders: `facebook/w2v-bert-2.0` (semantic) and `iic/emotion2vec_base` (acoustic via FunASR)
- LLM: `google/gemma-3-270m`

## Files

- `fusion_head.pt` — PyTorch state_dict of the fusion/regression head
- `eaa_config.json` — minimal config (IDs, dims, hyperparams)
- `modeling_eaa.py` — the fusion architecture (Dual X-Attn + pooling + [REG] head)

## Quickload (Python)

```python
import json
import os
import sys

import torch
from huggingface_hub import hf_hub_download

REPO_ID = "marccgrau/eaa-gemma3-270m-w2vbert-emotion2vec"

# Download artifacts (fetch modeling_eaa.py first so the import below resolves)
code_path = hf_hub_download(repo_id=REPO_ID, filename="modeling_eaa.py")
sys.path.insert(0, os.path.dirname(code_path))
from modeling_eaa import EAAEmotionRegressor

cfg_path = hf_hub_download(repo_id=REPO_ID, filename="eaa_config.json")
with open(cfg_path) as f:
    cfg = json.load(f)

# Recreate Gemma + load LoRA adapter
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

tok = AutoTokenizer.from_pretrained(cfg["gemma_id"], trust_remote_code=True)
llm_base = AutoModelForCausalLM.from_pretrained(
    cfg["gemma_id"], trust_remote_code=True, torch_dtype=torch.float16
).cuda()
llm = PeftModel.from_pretrained(llm_base, cfg["adapter_repo"]).eval()

# Build fusion head and load weights
head = EAAEmotionRegressor(
    d_sem=cfg["d_sem"], d_ac=cfg["d_ac"], llm_hidden=cfg["llm_hidden"],
    fusion_dim=cfg["fusion_dim"], num_audio_tokens=cfg["num_audio_tokens"],
).cuda().eval()
sd_path = hf_hub_download(repo_id=REPO_ID, filename="fusion_head.pt")
head.load_state_dict(torch.load(sd_path, map_location="cpu"))

# Now pass (sem_feats, ac_feats) and (input_ids) to head.forward(..., llm=llm)
```
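Continuing from the Quickload snippet (reusing `cfg`, `tok`, `head`, `llm`), here is a hedged sketch of how `sem_feats` and `ac_feats` might be produced for a 16 kHz mono clip. `"sample.wav"` and the prompt text are placeholders; the `Wav2Vec2BertModel` path is standard transformers usage, while the FunASR call (`generate` with `granularity="frame"` and `extract_embedding=True`, embeddings under `res[0]["feats"]`) follows the emotion2vec examples and should be verified against your installed funasr version:

```python
import soundfile as sf
import torch
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
from funasr import AutoModel as FunASRAutoModel

wav, sr = sf.read("sample.wav")  # placeholder path; expects 16 kHz mono audio

# Semantic stream: w2v-bert-2.0 frame features (hidden size 1024 = cfg["d_sem"])
fe = AutoFeatureExtractor.from_pretrained(cfg["sem_id"])
sem_enc = Wav2Vec2BertModel.from_pretrained(cfg["sem_id"]).cuda().eval()
inputs = fe(wav, sampling_rate=16000, return_tensors="pt").to("cuda")
with torch.no_grad():
    sem_feats = sem_enc(**inputs).last_hidden_state  # (1, T_s, 1024)

# Acoustic stream: emotion2vec frame embeddings via FunASR (768 = cfg["d_ac"])
ac_enc = FunASRAutoModel(model=cfg["acous_id"])
res = ac_enc.generate("sample.wav", granularity="frame", extract_embedding=True)
ac_feats = torch.tensor(res[0]["feats"]).float().unsqueeze(0).cuda()  # (1, T_a, 768)

# Scalar regression score from the fusion head + LoRA-adapted Gemma
ids = tok("Rate the emotional intensity of this clip.", return_tensors="pt").input_ids.cuda()
with torch.no_grad():
    score = head(sem_feats, ac_feats, llm=llm, input_ids=ids)
print(float(score))  # one value per batch item
```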
eaa_config.json
ADDED
@@ -0,0 +1,23 @@
{
  "architectures": [
    "EAAEmotionRegressor"
  ],
  "gemma_id": "google/gemma-3-270m",
  "sem_id": "facebook/w2v-bert-2.0",
  "acous_id": "iic/emotion2vec_base",
  "fusion_dim": 512,
  "num_audio_tokens": 8,
  "llm_hidden": 640,
  "d_sem": 1024,
  "d_ac": 768,
  "best_mae": 2.5936076641082764,
  "created": 1758628308,
  "library": "torch",
  "requires": [
    "transformers>=4.47.0",
    "peft>=0.13.0",
    "funasr>=1.2.7",
    "modelscope"
  ],
  "adapter_repo": "marccgrau/eaa-gemma3-270m-adapter"
}
fusion_head.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cfefad4e803a75d44e2c7525fb383ecb4df6a4065479cefe999cb4bb1851a4e0
size 18430292
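Since the weights ship as a Git LFS pointer, a downloaded copy can be sanity-checked against the SHA-256 and byte size advertised above; a minimal sketch:

```python
import hashlib
import os
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="marccgrau/eaa-gemma3-270m-w2vbert-emotion2vec",
    filename="fusion_head.pt",
)
# Compare the blob's digest and size to the LFS pointer
digest = hashlib.sha256(open(path, "rb").read()).hexdigest()
assert digest == "cfefad4e803a75d44e2c7525fb383ecb4df6a4065479cefe999cb4bb1851a4e0"
assert os.path.getsize(path) == 18430292
```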
modeling_eaa.py
ADDED
@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DualCrossAttentionFusion(nn.Module):
    """Fuses semantic (w2v-bert) and acoustic (emotion2vec) frame features
    with bidirectional cross-attention."""

    def __init__(self, d_sem: int, d_ac: int, d: int, n_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.sem_proj = nn.Linear(d_sem, d)
        self.ac_proj = nn.Linear(d_ac, d)
        self.ln_s = nn.LayerNorm(d)
        self.ln_a = nn.LayerNorm(d)
        self.att_sa = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
        self.att_as = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
        self.out_proj = nn.Linear(d * 4, d)

    def forward(self, S, A):
        S_ = self.ln_s(self.sem_proj(S))
        A_ = self.ln_a(self.ac_proj(A))
        # Zero-pad the shorter stream so both share a common length T.
        T = max(S_.size(1), A_.size(1))
        if S_.size(1) != T:
            S_ = F.pad(S_, (0, 0, 0, T - S_.size(1)))
        if A_.size(1) != T:
            A_ = F.pad(A_, (0, 0, 0, T - A_.size(1)))
        att_s, _ = self.att_sa(query=S_, key=A_, value=A_)  # semantic attends to acoustic
        att_a, _ = self.att_as(query=A_, key=S_, value=S_)  # acoustic attends to semantic
        fused = torch.cat([S_, A_, att_s, att_a], dim=-1)   # (B, T, 4d)
        return self.out_proj(fused)                         # (B, T, d)


class FusedToLLMTokens(nn.Module):
    """Mean-pools the fused sequence into a fixed number of segments and
    projects each segment into the LLM embedding space."""

    def __init__(self, d_in: int, d_llm: int, num_tokens: int = 8):
        super().__init__()
        self.num_tokens = num_tokens
        self.proj = nn.Linear(d_in, d_llm)

    def forward(self, feats):  # renamed from `F` to avoid shadowing torch.nn.functional
        B, T, D = feats.shape
        splits = torch.linspace(0, T, steps=self.num_tokens + 1, device=feats.device).long()
        toks = []
        for i in range(self.num_tokens):
            s, e = splits[i].item(), splits[i + 1].item()
            # Fall back to the last frame if a segment is empty (T < num_tokens).
            seg = feats[:, s:e, :] if e > s else feats[:, -1:, :]
            toks.append(seg.mean(dim=1))
        X = torch.stack(toks, dim=1)  # (B, N, D)
        return self.proj(X)           # (B, N, H)


class EAAEmotionRegressor(nn.Module):
    """[REG] token + text embeddings + audio tokens -> LLM -> scalar regression."""

    def __init__(self, d_sem: int, d_ac: int, llm_hidden: int, fusion_dim=512, n_heads=4, num_audio_tokens=8):
        super().__init__()
        self.fusion = DualCrossAttentionFusion(d_sem, d_ac, d=fusion_dim, n_heads=n_heads)
        self.f2tok = FusedToLLMTokens(d_in=fusion_dim, d_llm=llm_hidden, num_tokens=num_audio_tokens)
        self.reg_token = nn.Parameter(torch.randn(1, llm_hidden) * 0.02)
        self.reg_head = nn.Sequential(
            nn.Linear(llm_hidden, llm_hidden // 2),
            nn.ReLU(),
            nn.Linear(llm_hidden // 2, 1),
        )

    def forward(self, sem_feats: torch.Tensor, ac_feats: torch.Tensor, llm, input_ids=None, inputs_embeds=None):
        fused = self.fusion(sem_feats, ac_feats)   # (B, T, d_fuse)
        audio_tokens = self.f2tok(fused)           # (B, N, H_llm)
        B = audio_tokens.size(0)
        reg = self.reg_token.unsqueeze(0).expand(B, -1, -1)  # (B, 1, H_llm)
        if inputs_embeds is None and input_ids is not None:
            txt_embeds = llm.get_input_embeddings()(input_ids)
        else:
            txt_embeds = inputs_embeds
        if txt_embeds is None:
            raise ValueError("Provide either input_ids or inputs_embeds.")
        # Cast to the LLM embedding dtype (e.g. fp16) before concatenating.
        dt = txt_embeds.dtype
        llm_in = torch.cat([reg.to(dt), txt_embeds, audio_tokens.to(dt)], dim=1)
        out = llm(inputs_embeds=llm_in, output_hidden_states=True)
        # Final hidden state at position 0, i.e. the [REG] token.
        reg_state = out.hidden_states[-1][:, 0, :]
        reg_state = reg_state.to(self.reg_head[0].weight.dtype)
        return self.reg_head(reg_state).squeeze(-1)  # (B,)
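As a quick smoke test, the two sub-modules can be exercised standalone with the dims from `eaa_config.json` (no LLM needed); the batch size and frame counts below are illustrative:

```python
import torch
from modeling_eaa import DualCrossAttentionFusion, FusedToLLMTokens

fusion = DualCrossAttentionFusion(d_sem=1024, d_ac=768, d=512).eval()
f2tok = FusedToLLMTokens(d_in=512, d_llm=640, num_tokens=8).eval()

S = torch.randn(2, 50, 1024)  # semantic frames (B, T_s, d_sem)
A = torch.randn(2, 37, 768)   # acoustic frames (B, T_a, d_ac), shorter on purpose
with torch.no_grad():
    fused = f2tok_in = fusion(S, A)  # (2, 50, 512): A is zero-padded to T=50
    audio = f2tok(f2tok_in)          # (2, 8, 640): 8 pooled tokens in LLM space

print(fused.shape, audio.shape)
```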