Setup and Inference Code

#11
by E10H1M - opened


Inference speed: Stupid fast.

CODE (SETUP INSTRUCTIONS ARE UNDERNEATH):

import torch
from transformers import AutoModelForCausalLM, AutoProcessor
import numpy as np
import soundfile as sf

model_id = "/path/to/your/weights/FlashLabs/Chroma-4B"

# ----- MODEL LOADING 
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map={"": 0},
    dtype=torch.bfloat16,
)

processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)




# ----- PROMPT CONFIG
system_prompt = (
    "You are Chroma, an advanced virtual human created by the FlashLabs. "
    "You possess the ability to understand auditory inputs and generate both text and speech."
)

conversation = [[
    {
        "role": "system",
        "content": [{"type": "text", "text": system_prompt}],
    },
    {
        "role": "user",
        "content": [{"type": "audio", "audio": "example/make_taco.wav"}],
    },
]]
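To point the model at a different input clip, only the user turn's audio path changes. A small helper mirroring the exact structure above (the helper name is mine, not part of the repo):

```python
def make_conversation(system_prompt: str, wav_path: str):
    """Build the batched chat structure the processor expects:
    a list of conversations, each a list of role/content turns."""
    return [[
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [{"type": "audio", "audio": wav_path}],
        },
    ]]

# e.g. (hypothetical clip name):
# conversation = make_conversation(system_prompt, "example/another_clip.wav")
```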



# ------ SETUP
def load_prompt(speaker_name):
    text_path = f"example/prompt_text/{speaker_name}.txt"
    audio_path = f"example/prompt_audio/{speaker_name}.wav"
    with open(text_path, "r", encoding="utf-8") as f:
        prompt_text = f.read()
    return [prompt_text], [audio_path]

prompt_text, prompt_audio = load_prompt("donald_trump") # find other speakers in example/
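The repo ships several speaker prompts. To see what's available without opening a file browser, a quick sketch (the `list_speakers` helper is mine; it just assumes the `example/prompt_audio/<name>.wav` layout used above):

```python
from pathlib import Path

def list_speakers(prompt_dir: str = "example/prompt_audio"):
    """Return the speaker name for every .wav file under prompt_dir."""
    return sorted(p.stem for p in Path(prompt_dir).glob("*.wav"))

# print(list_speakers())  # should include "donald_trump", per the example above
```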

inputs = processor(
    conversation,
    add_generation_prompt=True,
    tokenize=False,
    prompt_audio=prompt_audio,
    prompt_text=prompt_text,
)

device = model.device
for k, v in inputs.items():
    if torch.is_tensor(v):
        inputs[k] = v.to(device)

for k in ("input_values", "prompt_input_values"):
    if k in inputs and torch.is_tensor(inputs[k]):
        inputs[k] = inputs[k].to(dtype=model.dtype)


# CONTROLLING SEED
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)




# ------ INFERENCE
output = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    use_cache=True,
)



# DECODING
audio_values = model.codec_model.decode(output.permute(0, 2, 1)).audio_values

# ------ SAVING THE OUTPUT
a = audio_values[0, 0].float().detach().cpu().numpy()

sf.write("chroma_out.wav", a, 24_000, subtype="PCM_16")
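One caveat when writing PCM_16: float samples outside [-1, 1] can distort badly during 16-bit quantization, so clipping first is cheap insurance. A minimal, self-contained sketch (a dummy array stands in for the codec output; the helper name is mine):

```python
import numpy as np

def to_pcm_safe(a) -> np.ndarray:
    """Clip float audio to [-1, 1] so 16-bit quantization cannot overflow."""
    return np.clip(np.asarray(a, dtype=np.float32), -1.0, 1.0)

# then: sf.write("chroma_out.wav", to_pcm_safe(a), 24_000, subtype="PCM_16")
```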

ENVIRONMENT

git clone https://github.com/FlashLabs-AI-Corp/FlashLabs-Chroma.git
cd FlashLabs-Chroma
python3 -m venv .venv

# activate your environment
. .venv/bin/activate

(I'm using python 3.12)

Some general Python tooling you should always update first thing inside your venv:
python -m pip install -U pip setuptools wheel

Installing Torch

torch must be at least 2.7.1, I believe; I used 2.10.0+cu130.
pip3 install --index-url https://download.pytorch.org/whl/cu130 torch torchvision torchaudio

Find other versions here: https://pytorch.org/get-started/previous-versions/

transformers: make sure it's 5.0.0rc1 or newer.

python -m pip install transformers==5.0.0rc1

Be sure to update to the full release once transformers 5.0 drops.
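If you'd rather assert the transformers requirement in code than remember it, a tiny check (the helper is mine; it only gates on the major version, which is the part that matters here, so "5.0.0rc1" passes):

```python
import importlib.metadata

def major_at_least(version: str, minimum: int = 5) -> bool:
    """True if the leading major component of `version` is >= minimum.
    '5.0.0rc1' and '5.0.0' both pass; '4.57.1' does not."""
    head = version.split(".")[0]
    digits = "".join(ch for ch in head if ch.isdigit())
    return bool(digits) and int(digits) >= minimum

# e.g. at startup:
# assert major_at_least(importlib.metadata.version("transformers")), \
#     "need transformers >= 5.0.0rc1"
```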

Audio processing

python -m pip install "av>=14.0.0" "librosa>=0.11.0" "audioread>=3.0.0" "soundfile>=0.13.0"

Other dependencies

python -m pip install "pillow>=11.0.0" "accelerate>=1.7.0" "numpy>=2.2.0" "safetensors>=0.5.0" "huggingface-hub>=1.3.0"

POSSIBLES (transformers may want these, but run the script first and see if it complains; I was cycling through different versions, so I'm not entirely sure):

python -m pip install -U protobuf
python -m pip install -U sentencepiece tiktoken tokenizers
python -m pip install -U torchcodec
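Rather than installing all of these up front, you can check which ones are actually missing (note that protobuf's import name is google.protobuf; the missing_modules helper is my own):

```python
import importlib.util

def missing_modules(names):
    """Return the subset of `names` that are not importable."""
    missing = []
    for name in names:
        try:
            if importlib.util.find_spec(name) is None:
                missing.append(name)
        except ModuleNotFoundError:  # raised when a parent package is absent
            missing.append(name)
    return missing

maybe_needed = ["google.protobuf", "sentencepiece", "tiktoken",
                "tokenizers", "torchcodec"]
print("still missing:", missing_modules(maybe_needed))
```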

WEIGHTS CAN BE FOUND HERE:

https://huggingface.co/FlashLabs/Chroma-4B

Note: you don't technically need to clone the GitHub repo; the code runs via torch + transformers alone.
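If you'd rather script the weight download than click through the Hub, huggingface_hub's snapshot_download can fetch the repo to a local path. The wrapper function below is mine; the returned path is what you'd use as model_id above:

```python
def fetch_chroma_weights(repo_id: str = "FlashLabs/Chroma-4B", local_dir=None):
    """Download (or reuse the cached copy of) the model repo and return
    the local directory path. huggingface_hub is imported lazily so this
    file still parses if it isn't installed yet."""
    from huggingface_hub import snapshot_download
    return snapshot_download(repo_id, local_dir=local_dir)

# model_id = fetch_chroma_weights()  # then pass model_id to from_pretrained
```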


I tested this setup in Google Colab, but got a numpy error:

      4 import soundfile as sf

... 26 frames elided ...
/usr/local/lib/python3.12/dist-packages/numpy/_core/strings.py in <module>
     20 from numpy._core.multiarray import _vec_string
     21 from numpy._core.overrides import array_function_dispatch, set_module
---> 22 from numpy._core.umath import (
     23     _center,
     24     _expandtabs,

ImportError: cannot import name '_center' from 'numpy._core.umath' (/usr/local/lib/python3.12/dist-packages/numpy/_core/umath.py)

