Vietnamese Emotion Models (Text, Voice, Multimodal)

Three Vietnamese emotion recognition models (text, voice, and multimodal), packaged for Hugging Face with their configs, label mappings, metrics, and inference snippets. Only the best checkpoint is kept for each model.

Structure

  • text-phobert-focalloss/: PhoBERT fine-tuned with focal loss for text emotion classification.
  • voice-wav2vec2-vi-emotion/: Wav2Vec2-base-vi-250h fine-tuned for Vietnamese speech emotion recognition (SER).
  • multimodal/: audio + text fusion weights (best.pt) and the label mapping (labels.json).
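The text branch is listed as PhoBERT trained with focal loss. The training script and its hyperparameters are not part of this release, so the following is only a sketch of the standard multi-class focal loss, FL(p_t) = -α_t (1 - p_t)^γ log p_t, in PyTorch. The default `gamma=2.0` and the optional per-class `alpha` weights are assumptions, not the settings used for this checkpoint.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def focal_loss(logits: torch.Tensor, targets: torch.Tensor,
               gamma: float = 2.0, alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Batch mean of -alpha_t * (1 - p_t)**gamma * log(p_t).

    logits:  (batch, num_classes) raw scores
    targets: (batch,) integer class ids
    alpha:   optional (num_classes,) per-class weights (hypothetical, not from this repo)
    """
    log_probs = F.log_softmax(logits, dim=-1)                      # log p for every class
    log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log p of the true class
    pt = log_pt.exp()
    loss = -((1.0 - pt) ** gamma) * log_pt                         # down-weights easy, confident examples
    if alpha is not None:
        loss = alpha[targets] * loss                               # optional class re-weighting
    return loss.mean()
```

With `gamma=0` and no `alpha` this reduces to ordinary cross-entropy; a larger `gamma` shrinks the loss on confidently correct examples, which is why focal loss is often chosen for imbalanced label sets.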

Setup

pip install transformers torch torchaudio soundfile

The voice and multimodal models expect 16 kHz mono audio; resample (and downmix multi-channel files) before inference.

Text model (PhoBERT focal loss)

from pathlib import Path
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

base = Path(__file__).resolve().parent  # .../hf-release
repo = base / "text-phobert-focalloss"
tok = AutoTokenizer.from_pretrained(repo, use_fast=False)
model = AutoModelForSequenceClassification.from_pretrained(repo)

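# "I am very happy and excited"
inputs = tok("Tôi đang rất vui và hào hứng", return_tensors="pt", truncation=True, max_length=256)
with torch.no_grad():
    probs = model(**inputs).logits.softmax(-1)[0]   # probabilities over the emotion labels
pred = model.config.id2label[int(probs.argmax())]   # id2label keys are ints once the config is loaded
print(pred, float(probs.max()))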
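The snippet prints only the top label. To inspect the whole distribution over the emotion labels, the logits can be turned into a ranked list; `rank_emotions` is a hypothetical helper, not part of this repository.

```python
import torch


def rank_emotions(logits: torch.Tensor, id2label: dict) -> list:
    """Return (label, probability) pairs for one example, highest probability first."""
    probs = logits.softmax(dim=-1).flatten()         # works for (num_labels,) or (1, num_labels)
    order = torch.argsort(probs, descending=True)
    return [(id2label[int(i)], float(probs[i])) for i in order]
```

With the model above: `rank_emotions(model(**inputs).logits[0], model.config.id2label)` (inside `torch.no_grad()`).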

Voice model (Wav2Vec2 SER)

from pathlib import Path
import torch, torchaudio
from transformers import Wav2Vec2ForSequenceClassification, AutoProcessor

base = Path(__file__).resolve().parent
repo = base / "voice-wav2vec2-vi-emotion"
processor = AutoProcessor.from_pretrained(repo)
model = Wav2Vec2ForSequenceClassification.from_pretrained(repo)

wav, sr = torchaudio.load("audio.wav")               # wav: (channels, samples)
wav = wav.mean(dim=0)                                  # downmix to a 1-D mono waveform
if sr != 16000:
    wav = torchaudio.functional.resample(wav, sr, 16000)
inputs = processor(wav.numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    probs = model(**inputs).logits.softmax(-1)[0]     # probabilities over the emotion labels
pred = model.config.id2label[int(probs.argmax())]
print(pred, float(probs.max()))
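Besides batching and tensor conversion, a Wav2Vec2 feature extractor configured with `do_normalize=True` rescales each clip to zero mean and unit variance. Whether this repo's `preprocessor_config.json` sets that flag is an assumption to verify. If it does, the NumPy sketch below reproduces the step, which can help when checking preprocessing done outside `transformers`; the epsilon is an assumed small constant for numerical stability.

```python
import numpy as np


def zero_mean_unit_var(wav: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Normalize one mono waveform to zero mean and unit variance.

    Sketch of a Wav2Vec2 feature extractor's do_normalize step
    (assumed enabled for this repo; check preprocessor_config.json).
    """
    wav = np.asarray(wav, dtype=np.float32)
    normed = (wav - wav.mean()) / np.sqrt(wav.var() + eps)  # eps keeps silent clips finite
    return normed.astype(np.float32)
```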

Multimodal model (audio + transcript)

from pathlib import Path
import sys
import torch, torchaudio
from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor

base = Path(__file__).resolve().parent          # .../hf-release
sys.path.append(str(base.parent))               # add repo root so the training code is importable
from multimodal.multimodal_train_eval import FusionXMerlin  # must come after the sys.path change

text_repo = base / "text-phobert-focalloss"
audio_repo = base / "voice-wav2vec2-vi-emotion"
ckpt_path = base / "multimodal" / "best.pt"

ckpt = torch.load(ckpt_path, map_location="cpu")
label2id = ckpt["label2id"]
id2label = {v: k for k, v in label2id.items()}

tokenizer = AutoTokenizer.from_pretrained(text_repo, use_fast=False)
processor = Wav2Vec2FeatureExtractor.from_pretrained(audio_repo)

model = FusionXMerlin(
    text_model_path=text_repo,
    audio_model_path=audio_repo,
    num_classes=len(label2id),
    freeze_encoders=True,
).eval()
model.load_state_dict(ckpt["model_state"])

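transcript = "Tôi rất thất vọng về dịch vụ."     # "I am very disappointed with the service."
wav, sr = torchaudio.load("audio.wav")
wav = wav.mean(dim=0)                            # downmix to a 1-D mono waveform
if sr != 16000:
    wav = torchaudio.functional.resample(wav, sr, 16000)

t_inputs = tokenizer(transcript, return_tensors="pt", padding=True, truncation=True, max_length=256)
a_inputs = processor(
    wav.numpy(), sampling_rate=16000, return_tensors="pt",
    return_attention_mask=True,                  # not returned by default for every Wav2Vec2 config
)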

with torch.no_grad():
    logits, _ = model(
        t_inputs["input_ids"],
        t_inputs["attention_mask"],
        a_inputs["input_values"],
        a_inputs["attention_mask"],
    )
    probs = torch.softmax(logits, dim=-1)[0]
pred = id2label[int(probs.argmax())]
print(pred, float(probs.max()))
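The audio `attention_mask` passed to `FusionXMerlin` is sample-level (one entry per 16 kHz sample). Inside a Wav2Vec2 encoder, the convolutional feature extractor downsamples the signal to roughly one frame every 20 ms, and the mask is shortened to match. Assuming this checkpoint keeps the default Wav2Vec2-base convolution stack (kernels 10, 3, 3, 3, 3, 2, 2 and strides 5, 2, 2, 2, 2, 2, 2; verify `conv_kernel` and `conv_stride` in the voice model's `config.json`), the number of encoder frames for a clip can be computed as below. How `FusionXMerlin` pools those frames is not documented here.

```python
# Default Wav2Vec2-base feature-encoder layout (assumed; check conv_kernel / conv_stride in config.json)
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def num_encoder_frames(num_samples: int,
                       kernels=CONV_KERNELS, strides=CONV_STRIDES) -> int:
    """Output length of a stack of unpadded 1-D convolutions."""
    length = num_samples
    for k, s in zip(kernels, strides):
        length = (length - k) // s + 1  # conv1d output length with padding=0, dilation=1
    return length
```

Under these assumptions a 1-second clip (16,000 samples) yields 49 frames and a 2-second clip yields 99.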

Extra info

  • Label set: Anger, Disgust, Enjoyment, Fear, Neutral, Sadness, Surprise (the label-to-id mappings are stored in each branch's config.json or labels.json).
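Because the multimodal model builds on the text and audio encoders, it can be worth confirming that all three branches map the seven labels to the same ids before comparing their outputs. A minimal sketch, assuming each mapping has already been loaded as a label-to-id dict (for example `model.config.label2id` or the checkpoint's `ckpt["label2id"]`; the exact layout of `labels.json` is not documented here). The helper name `check_label_maps` is hypothetical.

```python
EXPECTED_LABELS = {"Anger", "Disgust", "Enjoyment", "Fear", "Neutral", "Sadness", "Surprise"}


def check_label_maps(**maps: dict) -> dict:
    """Verify that every label2id mapping covers the expected labels with identical ids.

    Returns the shared mapping; raises ValueError on any mismatch.
    """
    if not maps:
        raise ValueError("no mappings given")
    names = list(maps)
    reference = {str(k): int(v) for k, v in maps[names[0]].items()}
    if set(reference) != EXPECTED_LABELS:
        raise ValueError(f"{names[0]}: unexpected label set {sorted(reference)}")
    for name in names[1:]:
        other = {str(k): int(v) for k, v in maps[name].items()}
        if other != reference:
            raise ValueError(f"{name} disagrees with {names[0]}: {other} vs {reference}")
    return reference
```

Example call: `check_label_maps(text=text_model.config.label2id, voice=voice_model.config.label2id, multimodal=ckpt["label2id"])`.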