marccgrau committed
Commit 2f3d362 · verified · 1 Parent(s): 280ab77

Add fusion head + config (best MAE 2.594)

Files changed (4)
  1. README.md +40 -0
  2. eaa_config.json +23 -0
  3. fusion_head.pt +3 -0
  4. modeling_eaa.py +66 -0
README.md ADDED
@@ -0,0 +1,40 @@
+ # EAA Fusion Head for Gemma (LoRA) + w2v-bert-2.0 + emotion2vec
+
+ This repo hosts the **fusion head** weights and code for the Emotion-Aware Audio LLM.
+ - LoRA adapter lives at: **marccgrau/eaa-gemma3-270m-adapter**
+ - Upstream encoders: `facebook/w2v-bert-2.0` (semantic) and `iic/emotion2vec_base` (acoustic, via FunASR)
+ - LLM: `google/gemma-3-270m`
+
+ ## Files
+ - `fusion_head.pt` — PyTorch state_dict of the fusion/regression head
+ - `eaa_config.json` — minimal config (model IDs, dims, hyperparameters)
+ - `modeling_eaa.py` — the fusion architecture (dual cross-attention + pooling + [REG] head)
+
+ ## Quickload (Python)
+ ```python
+ import torch, json
+ from huggingface_hub import hf_hub_download
+ from modeling_eaa import EAAEmotionRegressor
+
+ # Download artifacts
+ cfg_path = hf_hub_download(repo_id="marccgrau/eaa-gemma3-270m-w2vbert-emotion2vec", filename="eaa_config.json")
+ with open(cfg_path) as f:
+     cfg = json.load(f)
+
+ # Recreate Gemma + load the LoRA adapter
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+ tok = AutoTokenizer.from_pretrained(cfg["gemma_id"], trust_remote_code=True)
+ llm_base = AutoModelForCausalLM.from_pretrained(cfg["gemma_id"], trust_remote_code=True, torch_dtype=torch.float16).cuda()
+ llm = PeftModel.from_pretrained(llm_base, cfg["adapter_repo"]).eval()
+
+ # Build the fusion head and load its weights
+ head = EAAEmotionRegressor(
+     d_sem=cfg["d_sem"], d_ac=cfg["d_ac"], llm_hidden=cfg["llm_hidden"],
+     fusion_dim=cfg["fusion_dim"], num_audio_tokens=cfg["num_audio_tokens"],
+ ).cuda().eval()
+ sd_path = hf_hub_download(repo_id="marccgrau/eaa-gemma3-270m-w2vbert-emotion2vec", filename="fusion_head.pt")
+ head.load_state_dict(torch.load(sd_path, map_location="cpu"))
+
+ # Now pass (sem_feats, ac_feats) together with input_ids to head.forward(..., llm=llm).
+ # Note: the LLM above is in fp16, so cast the head and the audio features to the same dtype
+ # (or load the LLM in fp32) before the forward pass, since the embeddings are concatenated.
+ ```
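
Beyond loading, here is a hedged sketch of one way to produce `sem_feats` / `ac_feats` and get a prediction, continuing from the Quickload snippet above (it reuses `cfg`, `tok`, `llm`, `llm_base`, `head`). The w2v-bert preprocessing, the FunASR call (`granularity="frame"`, `extract_embedding=True`), the 16 kHz mono resampling, the prompt text, and the fp16 cast are illustration-only assumptions, not the repo's documented pipeline:

```python
import torch, torchaudio
from transformers import AutoFeatureExtractor, AutoModel
from funasr import AutoModel as FunASRAutoModel

# Semantic stream: frame-level hidden states from w2v-bert-2.0 (hidden size 1024)
fe = AutoFeatureExtractor.from_pretrained(cfg["sem_id"])
w2v = AutoModel.from_pretrained(cfg["sem_id"]).cuda().eval()
wav, sr = torchaudio.load("example.wav")                           # hypothetical input clip
wav = torchaudio.functional.resample(wav.mean(dim=0), sr, 16000)   # mono, 16 kHz (assumed)
inputs = fe(wav.numpy(), sampling_rate=16000, return_tensors="pt").to("cuda")
with torch.no_grad():
    sem_feats = w2v(**inputs).last_hidden_state                    # (1, T_s, 1024)

# Acoustic stream: emotion2vec frame embeddings via FunASR
# (call shape per the emotion2vec docs; verify against your installed funasr version)
e2v = FunASRAutoModel(model=cfg["acous_id"])
res = e2v.generate("example.wav", granularity="frame", extract_embedding=True)
ac_feats = torch.tensor(res[0]["feats"]).unsqueeze(0).cuda()       # (1, T_a, 768)

# Prompt + scalar prediction; cast to the fp16 LLM's dtype so the concatenated
# embeddings inside head.forward share one dtype
prompt = "Rate the emotional intensity of this clip."              # hypothetical prompt
input_ids = tok(prompt, return_tensors="pt").input_ids.cuda()
head = head.to(llm_base.dtype)
with torch.no_grad():
    pred = head(sem_feats.to(llm_base.dtype), ac_feats.to(llm_base.dtype), llm=llm, input_ids=input_ids)
print(pred.item())
```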
eaa_config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "architectures": [
+     "EAAEmotionRegressor"
+   ],
+   "gemma_id": "google/gemma-3-270m",
+   "sem_id": "facebook/w2v-bert-2.0",
+   "acous_id": "iic/emotion2vec_base",
+   "fusion_dim": 512,
+   "num_audio_tokens": 8,
+   "llm_hidden": 640,
+   "d_sem": 1024,
+   "d_ac": 768,
+   "best_mae": 2.5936076641082764,
+   "created": 1758628308,
+   "library": "torch",
+   "requires": [
+     "transformers>=4.47.0",
+     "peft>=0.13.0",
+     "funasr>=1.2.7",
+     "modelscope"
+   ],
+   "adapter_repo": "marccgrau/eaa-gemma3-270m-adapter"
+ }
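
For orientation, the dims track the upstream models: `d_sem=1024` is the w2v-bert-2.0 hidden size, `d_ac=768` the emotion2vec_base embedding size, and `llm_hidden=640` the Gemma-3-270m hidden width as used here. A small sanity check against `modeling_eaa.py`, assuming the repo files sit in the working directory:

```python
import json
import torch
from modeling_eaa import EAAEmotionRegressor

with open("eaa_config.json") as f:
    cfg = json.load(f)

head = EAAEmotionRegressor(
    d_sem=cfg["d_sem"], d_ac=cfg["d_ac"], llm_hidden=cfg["llm_hidden"],
    fusion_dim=cfg["fusion_dim"], num_audio_tokens=cfg["num_audio_tokens"],
)
# Input projections must match the encoder dims; the token projection must match the LLM width
assert head.fusion.sem_proj.in_features == cfg["d_sem"]   # 1024
assert head.fusion.ac_proj.in_features == cfg["d_ac"]     # 768
assert head.f2tok.proj.out_features == cfg["llm_hidden"]  # 640
print(sum(p.numel() for p in head.parameters()), "fusion-head parameters")
```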
fusion_head.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cfefad4e803a75d44e2c7525fb383ecb4df6a4065479cefe999cb4bb1851a4e0
+ size 18430292
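
The file above is a Git LFS pointer; the checkpoint itself is roughly 18 MB. A quick way to peek at the downloaded state_dict (e.g. after `hf_hub_download` as in the README); the parameter names are simply whatever `modeling_eaa.py` produces:

```python
import torch

# Load the raw state_dict on CPU and print a few parameter names and shapes
sd = torch.load("fusion_head.pt", map_location="cpu")
for name in list(sd)[:5]:
    print(name, tuple(sd[name].shape))
print(f"{sum(t.numel() for t in sd.values()):,} parameters total")
```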
modeling_eaa.py ADDED
@@ -0,0 +1,66 @@
+ import torch, torch.nn as nn, torch.nn.functional as F
+
+
+ class DualCrossAttentionFusion(nn.Module):
+     """Fuses semantic (w2v-bert) and acoustic (emotion2vec) frame features via dual cross-attention."""
+
+     def __init__(self, d_sem: int, d_ac: int, d: int, n_heads: int = 4, dropout: float = 0.1):
+         super().__init__()
+         self.sem_proj = nn.Linear(d_sem, d)
+         self.ac_proj = nn.Linear(d_ac, d)
+         self.ln_s = nn.LayerNorm(d); self.ln_a = nn.LayerNorm(d)
+         self.att_sa = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
+         self.att_as = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
+         self.out_proj = nn.Linear(d * 4, d)
+
+     def forward(self, S, A):
+         # Project both streams to the shared fusion dim and normalize
+         S_ = self.ln_s(self.sem_proj(S))
+         A_ = self.ln_a(self.ac_proj(A))
+         # Right-pad the shorter stream along time so both have length T
+         T = max(S_.size(1), A_.size(1))
+         if S_.size(1) != T: S_ = F.pad(S_, (0, 0, 0, T - S_.size(1)))
+         if A_.size(1) != T: A_ = F.pad(A_, (0, 0, 0, T - A_.size(1)))
+         # Cross-attention in both directions: semantic queries acoustic, and vice versa
+         Att_s, _ = self.att_sa(query=S_, key=A_, value=A_)
+         Att_a, _ = self.att_as(query=A_, key=S_, value=S_)
+         Fused = torch.cat([S_, A_, Att_s, Att_a], dim=-1)
+         return self.out_proj(Fused)
+
+
+ class FusedToLLMTokens(nn.Module):
+     """Mean-pools the fused sequence into a fixed number of tokens and projects to the LLM width."""
+
+     def __init__(self, d_in: int, d_llm: int, num_tokens: int = 8):
+         super().__init__()
+         self.num_tokens = num_tokens
+         self.proj = nn.Linear(d_in, d_llm)
+
+     def forward(self, feats):
+         B, T, D = feats.shape
+         # Segment the time axis into num_tokens roughly equal spans and mean-pool each
+         splits = torch.linspace(0, T, steps=self.num_tokens + 1, device=feats.device).long()
+         toks = []
+         for i in range(self.num_tokens):
+             s, e = splits[i].item(), splits[i + 1].item()
+             seg = feats[:, s:e, :] if e > s else feats[:, -1:, :]
+             toks.append(seg.mean(dim=1))
+         X = torch.stack(toks, dim=1)  # (B, N, D)
+         return self.proj(X)           # (B, N, H)
+
+
+ class EAAEmotionRegressor(nn.Module):
+     def __init__(self, d_sem: int, d_ac: int, llm_hidden: int, fusion_dim=512, n_heads=4, num_audio_tokens=8):
+         super().__init__()
+         self.fusion = DualCrossAttentionFusion(d_sem, d_ac, d=fusion_dim, n_heads=n_heads)
+         self.f2tok = FusedToLLMTokens(d_in=fusion_dim, d_llm=llm_hidden, num_tokens=num_audio_tokens)
+         self.reg_token = nn.Parameter(torch.randn(1, llm_hidden) * 0.02)  # learned [REG] token
+         self.reg_head = nn.Sequential(
+             nn.Linear(llm_hidden, llm_hidden // 2),
+             nn.ReLU(),
+             nn.Linear(llm_hidden // 2, 1)
+         )
+
+     def forward(self, sem_feats: torch.Tensor, ac_feats: torch.Tensor, llm, input_ids=None, inputs_embeds=None):
+         S = sem_feats
+         A = ac_feats
+         Fused = self.fusion(S, A)         # (B, T, d_fuse)
+         audio_tokens = self.f2tok(Fused)  # (B, N, H_llm)
+         B = audio_tokens.size(0)
+         reg = self.reg_token.unsqueeze(0).expand(B, -1, -1)  # (B, 1, H_llm)
+         if inputs_embeds is None and input_ids is not None:
+             txt_embeds = llm.get_input_embeddings()(input_ids)
+         else:
+             txt_embeds = inputs_embeds
+         # LLM input: [REG] token, then text embeddings, then pooled audio tokens
+         llm_in = torch.cat([reg, txt_embeds, audio_tokens], dim=1)
+         out = llm(inputs_embeds=llm_in, output_hidden_states=True)
+         reg_state = out.hidden_states[-1][:, 0, :]  # final hidden state at the [REG] position
+         return self.reg_head(reg_state).squeeze(-1)
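
A minimal, self-contained shape check for the modules above, using random tensors in place of real encoder features (dims taken from `eaa_config.json`); the full `forward` additionally needs the Gemma/LoRA stack loaded as in the README:

```python
import torch
from modeling_eaa import EAAEmotionRegressor

head = EAAEmotionRegressor(d_sem=1024, d_ac=768, llm_hidden=640,
                           fusion_dim=512, num_audio_tokens=8).eval()

# Random stand-ins for w2v-bert-2.0 frames (B, T_s, 1024) and emotion2vec frames (B, T_a, 768)
sem_feats = torch.randn(2, 120, 1024)
ac_feats = torch.randn(2, 95, 768)

with torch.no_grad():
    fused = head.fusion(sem_feats, ac_feats)   # (2, 120, 512): padded to the longer stream
    audio_tokens = head.f2tok(fused)           # (2, 8, 640): eight pooled audio tokens
print(fused.shape, audio_tokens.shape)

# head(sem_feats, ac_feats, llm=llm, input_ids=input_ids) would return a (2,) tensor of
# scalar predictions, read out from the [REG] token at position 0 of the LLM input.
```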