---
datasets:
  - id4thomas/emotion-prediction-comet-atomic-2020
language:
  - en
base_model:
  - Qwen/Qwen2.5-3B-Instruct
---

emotion-predictor-Qwen2.5-3B-Instruct

An LLM fine-tuned from Qwen/Qwen2.5-3B-Instruct to predict a character's emotional response to a given situation.

  • Trained to output its predictions in a structured (JSON) format.

Quickstart

The model is trained to produce predictions that follow this schema:

from enum import Enum
from pydantic import BaseModel

# Intensity level assigned to each emotion
class RelationshipStatus(str, Enum):
    na = "na"
    low = "low"
    medium = "medium"
    high = "high"

# Plutchik's eight basic emotions, each rated with one intensity level
class EmotionLabel(BaseModel):
    joy: RelationshipStatus
    trust: RelationshipStatus
    fear: RelationshipStatus
    surprise: RelationshipStatus
    sadness: RelationshipStatus
    disgust: RelationshipStatus
    anger: RelationshipStatus
    anticipation: RelationshipStatus

# Full prediction: per-emotion intensities plus a free-text reason
class EntryResult(BaseModel):
    emotion: EmotionLabel
    reason: str
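Because every prediction must conform to this schema, a raw model response is a JSON object with an `emotion` object holding all eight emotions and a free-text `reason`. The sketch below checks a response string against that shape using only the standard library. `validate_prediction` is a hypothetical helper, not part of this repository, and the sample response is made up for illustration. With Pydantic v2 installed, `EntryResult.model_validate_json(...)` does a similar job, although by default it ignores extra keys, whereas this sketch rejects them.

```python
import json

# Allowed intensity levels, mirroring the RelationshipStatus enum above
LEVELS = {"na", "low", "medium", "high"}
# The eight emotion fields of EmotionLabel
EMOTIONS = ("joy", "trust", "fear", "surprise",
            "sadness", "disgust", "anger", "anticipation")


def validate_prediction(raw: str) -> dict:
    """Parse a JSON response and check it matches the EntryResult shape.

    Hypothetical helper for illustration; raises ValueError on a mismatch
    (json.JSONDecodeError is itself a subclass of ValueError).
    """
    data = json.loads(raw)
    if not isinstance(data, dict) or set(data) != {"emotion", "reason"}:
        raise ValueError("expected exactly the keys 'emotion' and 'reason'")
    emotion = data["emotion"]
    if not isinstance(emotion, dict) or set(emotion) != set(EMOTIONS):
        raise ValueError("'emotion' must contain exactly the eight emotion fields")
    for name in EMOTIONS:
        level = emotion[name]
        if not isinstance(level, str) or level not in LEVELS:
            raise ValueError(f"invalid level for {name!r}: {level!r}")
    if not isinstance(data["reason"], str):
        raise ValueError("'reason' must be a string")
    return data


# Made-up response for illustration only (not real model output)
sample = json.dumps({
    "emotion": {"joy": "high", "trust": "medium", "fear": "na", "surprise": "low",
                "sadness": "na", "disgust": "na", "anger": "na", "anticipation": "medium"},
    "reason": "The character just received good news.",
})
print(validate_prediction(sample)["emotion"]["joy"])  # high
```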

Using the outlines package to generate structured predictions

  • The system prompt and user template are provided here
import outlines
from outlines import models
from transformers import AutoTokenizer, AutoModelForCausalLM

# Uses the outlines 0.x API (models.Transformers / outlines.generate.json)
model_id = "id4thomas/emotion-predictor-Qwen2.5-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Initialize the outlines generator, constrained to the EntryResult schema defined above
outlines_model = models.Transformers(model, tokenizer)
generator = outlines.generate.json(outlines_model, EntryResult)

# Build the prompt; system_message and user_message come from the provided prompt templates
messages = [
    {"role": "system", "content": system_message},
    {"role": "user", "content": user_message}
]
input_text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

# Generate; the output is parsed into an EntryResult instance
prediction = generator(input_text)
# EntryResult(emotion=EmotionLabel(joy=<RelationshipStatus.na: 'na'>, ...)
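If you only need the emotions the model actually predicted, the `na` entries can be filtered out. The sketch below is a hypothetical post-processing helper, not part of the model card: it works on a plain dict of strings such as `prediction.emotion.model_dump(mode="json")` (Pydantic v2), drops `na` entries, and orders the rest using an assumed ordinal mapping (`na`=0 up to `high`=3) that the card itself does not define.

```python
# Hypothetical ordinal scores for the four levels (not defined by the model card)
LEVEL_SCORE = {"na": 0, "low": 1, "medium": 2, "high": 3}


def salient_emotions(emotion: dict[str, str]) -> list[tuple[str, str]]:
    """Return (emotion, level) pairs that are not 'na', strongest first.

    Ties keep the original field order because sorted() is stable,
    even with reverse=True.
    """
    present = [(name, level) for name, level in emotion.items() if level != "na"]
    return sorted(present, key=lambda item: LEVEL_SCORE[item[1]], reverse=True)


# Stand-in for prediction.emotion.model_dump(mode="json"); values are made up
example = {"joy": "low", "trust": "medium", "fear": "na", "surprise": "high",
           "sadness": "na", "disgust": "na", "anger": "na", "anticipation": "medium"}
print(salient_emotions(example))
# [('surprise', 'high'), ('trust', 'medium'), ('anticipation', 'medium'), ('joy', 'low')]
```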

Using an endpoint served with vLLM and the OpenAI client package

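from openai import OpenAI

# Point the client at the vLLM OpenAI-compatible server (connection arguments elided)
client = OpenAI(...)
json_schema = EntryResult.model_json_schema()
completion = client.chat.completions.create(
    model="id4thomas/emotion-predictor-Qwen2.5-3B-Instruct",
    messages=messages,  # same messages as in the outlines example
    # vLLM guided-decoding parameter, passed through extra_body
    extra_body={"guided_json": json_schema},
)
# The response content is a JSON string matching EntryResult
print(completion.choices[0].message.content)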
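The OpenAI client merges `extra_body` into the top level of the JSON request body, so the call above sends vLLM's `guided_json` field alongside the standard chat-completion fields to the server's `/v1/chat/completions` route. The sketch below builds that body by hand with the standard library, which can help when debugging the request. `build_request_body` is a hypothetical helper, the schema literal is an abbreviated, hand-written stand-in for `EntryResult.model_json_schema()`, and the message contents are placeholders. Newer vLLM releases may expect structured-output options under a different parameter, so check the documentation for the version you deploy.

```python
import json

MODEL_ID = "id4thomas/emotion-predictor-Qwen2.5-3B-Instruct"


def build_request_body(messages: list[dict], json_schema: dict) -> str:
    """Serialize a chat-completions request carrying vLLM's guided_json field.

    Hypothetical helper: mirrors what client.chat.completions.create(...,
    extra_body={"guided_json": ...}) sends, with guided_json at the top level.
    """
    body = {
        "model": MODEL_ID,
        "messages": messages,
        "guided_json": json_schema,
    }
    return json.dumps(body)


# Abbreviated, hand-written stand-in for EntryResult.model_json_schema()
schema = {
    "type": "object",
    "properties": {"emotion": {"type": "object"}, "reason": {"type": "string"}},
    "required": ["emotion", "reason"],
}
# Placeholder prompts; use the provided system prompt and user template instead
messages = [
    {"role": "system", "content": "<system prompt>"},
    {"role": "user", "content": "<user message>"},
]
payload = build_request_body(messages, schema)
print(json.loads(payload)["guided_json"]["required"])  # ['emotion', 'reason']
```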