---
datasets:
- id4thomas/emotion-prediction-comet-atomic-2020
language:
- en
base_model:
- Qwen/Qwen2.5-3B-Instruct
---

# emotion-predictor-Qwen2.5-3B-Instruct

An LLM trained to predict a character's emotional response in a given situation.
* Predictions are produced in a structured (JSON) output format.

## Quickstart

The model is trained to output predictions that follow this schema:

```python
from enum import Enum
from pydantic import BaseModel

# Intensity level assigned to each emotion
class RelationshipStatus(str, Enum):
    na = "na"
    low = "low"
    medium = "medium"
    high = "high"

# One intensity level per emotion
class EmotionLabel(BaseModel):
    joy: RelationshipStatus
    trust: RelationshipStatus
    fear: RelationshipStatus
    surprise: RelationshipStatus
    sadness: RelationshipStatus
    disgust: RelationshipStatus
    anger: RelationshipStatus
    anticipation: RelationshipStatus

# Full prediction: emotion levels plus a free-text explanation
class EntryResult(BaseModel):
    emotion: EmotionLabel
    reason: str
```

Generate structured predictions with the `outlines` package:
* The system prompt and user template are provided [here](./assets/inference_prompt.yaml).

```python
import outlines
from outlines import models
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("id4thomas/emotion-predictor-Qwen2.5-3B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("id4thomas/emotion-predictor-Qwen2.5-3B-Instruct")

# Initialize the outlines generator, constrained to the EntryResult schema
outlines_model = models.Transformers(model, tokenizer)
generator = outlines.generate.json(outlines_model, EntryResult)

# Generate
# system_message / user_message are built from ./assets/inference_prompt.yaml
messages = [
    {"role": "system", "content": system_message},
    {"role": "user", "content": user_message},
]
input_text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
prediction = generator(input_text)
# prediction -> EntryResult(emotion=EmotionLabel(joy=..., ...), ...)
```

Using an endpoint served with vLLM and the OpenAI client package:

```python
from openai import OpenAI

# Client for the OpenAI-compatible server started with vLLM
client = OpenAI(...)

# Constrain decoding to the EntryResult schema via vLLM guided decoding
json_schema = EntryResult.model_json_schema()
completion = client.chat.completions.create(
    model="id4thomas/emotion-predictor-Qwen2.5-3B-Instruct",
    messages=messages,  # same system/user messages as in the outlines example
    extra_body={"guided_json": json_schema},
)
print(completion.choices[0].message.content)
```
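Unlike the `outlines` generator, the vLLM endpoint returns the prediction as a JSON string in `completion.choices[0].message.content`. The snippet below is a minimal sketch, assuming pydantic v2, of validating that string back into the `EntryResult` schema. The schema classes are repeated so the snippet runs on its own, and `parse_prediction` and the `raw` response string are hypothetical illustrations, not real model output.

```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel, ValidationError


class RelationshipStatus(str, Enum):
    na = "na"
    low = "low"
    medium = "medium"
    high = "high"


class EmotionLabel(BaseModel):
    joy: RelationshipStatus
    trust: RelationshipStatus
    fear: RelationshipStatus
    surprise: RelationshipStatus
    sadness: RelationshipStatus
    disgust: RelationshipStatus
    anger: RelationshipStatus
    anticipation: RelationshipStatus


class EntryResult(BaseModel):
    emotion: EmotionLabel
    reason: str


def parse_prediction(content: str) -> Optional[EntryResult]:
    """Hypothetical helper: validate raw JSON text against EntryResult.

    Returns None if the text is not valid JSON or does not match the schema.
    """
    try:
        return EntryResult.model_validate_json(content)
    except ValidationError:
        return None


# Hypothetical response text (in practice: completion.choices[0].message.content)
raw = (
    '{"emotion": {"joy": "high", "trust": "medium", "fear": "na", '
    '"surprise": "low", "sadness": "na", "disgust": "na", "anger": "na", '
    '"anticipation": "medium"}, "reason": "They just received good news."}'
)

result = parse_prediction(raw)
if result is not None:
    # Iterating a pydantic model yields (field_name, value) pairs;
    # keep only the emotions predicted at a level other than "na"
    active = {
        name: level.value
        for name, level in result.emotion
        if level != RelationshipStatus.na
    }
    print(active)
```

With guided decoding the output should already match the schema, so validation mainly guards against truncated generations (for example, when `max_tokens` cuts the JSON short). Returning `None` instead of raising is just one option; re-raising the `ValidationError` may suit batch pipelines better.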