Huge perplexity value

#20
by zhuqiang

Hi all, I just noticed that gemma-3n gives a huge perplexity value even for a very simple case.

Reproduce

import torch

from transformers import AutoProcessor, AutoModelForImageTextToText


model = AutoModelForImageTextToText.from_pretrained(
        "models/gemma-3n-E2B-it",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        # attn_implementation="eager",
        device_map='cuda',
    ).eval()
processor = AutoProcessor.from_pretrained("models/gemma-3n-E2B-it", )



messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
           {"type": "text", "text": "Hi"}
        ]
    },
    {
        "role": "assistant",
        "content": [
           {"type": "text", "text": "How are you?"}
        ]
    }
]




text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
print(text)
encodings = processor(text=text, images=None, videos=None, padding=False, return_tensors="pt")


input_ids = encodings.input_ids.to('cuda')
target_ids = input_ids.clone()
trg_len = -2
target_ids[:, :trg_len] = -100  # mask everything except the last two tokens


with torch.no_grad():
    outputs = model(input_ids, labels=target_ids)

    nll = outputs.loss  # mean NLL over the unmasked label positions

ppl = torch.exp(nll)
print(ppl)

Output

tensor(43704.4180, device='cuda:0')

packages:

transformers==4.56.2
torch==2.8.0

Hi @zhuqiang
I believe the issue is with your label masking. The line target_ids[:, :-2] = -100 masks everything except the last two tokens, so the perplexity is computed on only those two tokens, which leads to the massive score. To fix this, mask the entire input prompt and compute the loss only on the response tokens you want to evaluate: find the token length of your prompt (prompt_len) and apply the mask as target_ids[:, :prompt_len] = -100.
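For example, a minimal sketch of that masking, reusing the model, processor, and messages from your snippet (it assumes the prompt tokenized with add_generation_prompt=True is an exact prefix of the tokenized full conversation):

import torch

# Tokenize only the system + user turns, plus the "<start_of_turn>model\n" header.
prompt_ids = processor.apply_chat_template(
    messages[:-1],                 # drop the assistant turn
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
).input_ids
prompt_len = prompt_ids.shape[1]

# Tokenize the full conversation including the assistant response.
encodings = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=False,
    return_dict=True,
    return_tensors="pt",
)
input_ids = encodings.input_ids.to("cuda")

labels = input_ids.clone()
labels[:, :prompt_len] = -100      # score only the assistant response tokens

with torch.no_grad():
    loss = model(input_ids, labels=labels).loss

print(torch.exp(loss))             # perplexity of the response given the prompt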
Thank you

Hi,

masking the input sequence still returns a massive ppl. So I tried printing the ppl for each token, and it turns out some tokens give huge values.

Token 0: 2 -> '<bos>', ppl: nan
Token 1: 105 -> '<start_of_turn>', ppl: 4.323620912617226e+17
Token 2: 2364 -> 'user', ppl: 91463056.0
Token 3: 107 -> '\n', ppl: 154.09774780273438
Token 4: 3048 -> 'You', ppl: 2894.81201171875
Token 5: 659 -> ' are', ppl: 593.8463134765625
Token 6: 496 -> ' a', ppl: 1.6754709482192993
Token 7: 11045 -> ' helpful', ppl: 538.2894287109375
Token 8: 16326 -> ' assistant', ppl: 75.0553207397461
Token 9: 236761 -> '.', ppl: 4.012104034423828
Token 10: 108 -> '\n\n', ppl: 880.7503051757812
Token 11: 10979 -> 'Hi', ppl: 12539.9169921875
Token 12: 236764 -> ',', ppl: 12.089090347290039
Token 13: 1217 -> ' how', ppl: 14.097077369689941
Token 14: 659 -> ' are', ppl: 18.346384048461914
Token 15: 611 -> ' you', ppl: 1.0000014305114746
Token 16: 236881 -> '?', ppl: 20.15067481994629
Token 17: 106 -> '<end_of_turn>', ppl: 1397942784.0
Token 18: 107 -> '\n', ppl: 1.033643126487732
Token 19: 105 -> '<start_of_turn>', ppl: 479222112.0
Token 20: 4368 -> 'model', ppl: 2953670656.0
Token 21: 107 -> '\n', ppl: 262706.875
Token 22: 236777 -> 'I', ppl: 2114.218017578125
Token 23: 1006 -> ' am', ppl: 1.0164035558700562
Token 24: 5851 -> ' fine', ppl: 4462355456.0
Token 25: 236764 -> ',', ppl: 1.0000015497207642
Token 26: 7806 -> ' thank', ppl: 1.0000077486038208
Token 27: 611 -> ' you', ppl: 1.0
Token 28: 236888 -> '!', ppl: 89641.359375
Token 29: 106 -> '<end_of_turn>', ppl: 1817642532864.0
Token 30: 107 -> '\n', ppl: 199582.5625

Decoded (joined): 

<bos><start_of_turn>user
You are a helpful assistant.

Hi, how are you?<end_of_turn>
<start_of_turn>model
I am fine, thank you!<end_of_turn>

Here is the updated code:


import torch

from transformers import AutoProcessor, AutoModelForImageTextToText


model = AutoModelForImageTextToText.from_pretrained(
        "./gemma-3n-E2B-it",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        # attn_implementation="eager",
        device_map='cuda',
    ).eval()
processor = AutoProcessor.from_pretrained("./gemma-3n-E2B-it", )



messages = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant."}]
    },
    {
        "role": "user",
        "content": [
           {"type": "text", "text": "Hi, how are you?"}
        ]
    },
    
]




encodings = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=False, return_dict=True,
    return_tensors="pt",)
# encodings = processor(text=text, images=None, videos=None, padding=False, return_tensors="pt")
input_ids = encodings.input_ids.to('cuda')
trg_len = input_ids.shape[1]  # prompt length (system + user turns only)
print(input_ids, input_ids.shape)

messages.append({
    "role": "model",
    "content": [
        {"type": "text", "text": "I am fine, thank you!"}
    ]
})

encodings = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=False, return_dict=True,
    return_tensors="pt",)
# text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
input_ids = encodings.input_ids.to('cuda')
print(input_ids, input_ids.shape)

# Decode the input_ids token one by one
decoded_tokens = []
for i in range(input_ids.shape[1]):
    token_id = input_ids[0, i].item()
    token_str = processor.tokenizer.decode([token_id])
    decoded_tokens.append(token_str)

    # Mask every label except position i, so the loss is the NLL of predicting
    # token i from the preceding tokens. (Position 0 yields nan: labels are
    # shifted internally, so the first token is never scored.)
    target_ids = input_ids.clone()
    target_ids[:, :] = -100
    target_ids[:, i] = input_ids[:, i]


    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        nll = outputs.loss

    ppl = torch.exp(nll)
    print(f"Token {i}: {token_id} -> {repr(token_str)}, ppl: {ppl}")


# Optionally, print full sentence/text
print("Decoded (joined):", "".join(decoded_tokens))

Google org

Hi
The observed high perplexity values in your example occur because the language model is being asked to predict structural and conversation-boundary tokens such as <start_of_turn>, <end_of_turn>, and the user/model role labels. Please recalculate the perplexity after excluding these special tokens. A better approach overall is the "sliding window" strategy, which ensures the model always has a good amount of recent context for each prediction; this gives a more accurate and usually lower perplexity score.
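For example, a rough sketch of that strided evaluation, adapted to the model, processor, and input_ids from your updated snippet (max_length and stride are illustrative values, not official recommendations):

import torch

# ids of <bos>, <start_of_turn>, <end_of_turn>, etc.
special_ids = torch.tensor(processor.tokenizer.all_special_ids, device=input_ids.device)

max_length = 1024   # tokens fed to the model per window (example value)
stride = 512        # how far the window advances each step (example value)
seq_len = input_ids.size(1)

nll_sum = 0.0
n_tokens = 0
prev_end = 0
for begin in range(0, seq_len, stride):
    end = min(begin + max_length, seq_len)
    trg_len = end - prev_end                      # tokens newly scored in this window
    window_ids = input_ids[:, begin:end]

    labels = window_ids.clone()
    labels[:, :-trg_len] = -100                           # don't re-score earlier windows
    labels[torch.isin(window_ids, special_ids)] = -100    # exclude special/boundary tokens

    with torch.no_grad():
        loss = model(window_ids, labels=labels).loss      # mean NLL over unmasked labels

    # Weight each window roughly by its number of newly scored tokens
    # (this ignores the masked special tokens and the one-token label shift).
    nll_sum = nll_sum + loss * trg_len
    n_tokens += trg_len
    prev_end = end
    if end == seq_len:
        break

print(torch.exp(nll_sum / n_tokens))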
Please check this guide for reference: https://huggingface.co/docs/transformers/en/perplexity
Thank you
