"""
Tests for the TouchGrass loss functions (TouchGrassLoss and MusicAwareLoss).
"""
import pytest
import torch
import torch.nn.functional as F
from TouchGrass.training.losses import TouchGrassLoss, MusicAwareLoss
class TestTouchGrassLoss:
"""Test suite for TouchGrassLoss."""
def setup_method(self):
"""Set up test fixtures."""
self.batch_size = 4
self.seq_len = 10
self.vocab_size = 32000
self.loss_fn = TouchGrassLoss(
lm_loss_weight=1.0,
eq_loss_weight=0.1,
music_module_loss_weight=0.05
)
def test_loss_initialization(self):
"""Test loss function initialization."""
assert self.loss_fn.lm_loss_weight == 1.0
assert self.loss_fn.eq_loss_weight == 0.1
assert self.loss_fn.music_module_loss_weight == 0.05
def test_forward_with_all_outputs(self):
"""Test forward pass with all outputs."""
logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
eq_outputs = {
"frustration": torch.rand(self.batch_size, self.seq_len, 1),
"emotion": torch.randn(self.batch_size, self.seq_len, 4)
}
eq_labels = {
"frustration": torch.rand(self.batch_size, self.seq_len, 1),
"emotion": torch.randint(0, 4, (self.batch_size, self.seq_len))
}
music_outputs = {
"tab_validator": torch.rand(self.batch_size, self.seq_len, 1),
"difficulty": torch.randn(self.batch_size, self.seq_len, 3),
"interval_logits": torch.randn(self.batch_size, self.seq_len, 12)
}
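        # Label dicts mirror the output keys; the "interval_logits" labels are
        # integer class indices (12 classes, presumably semitone intervals),
        # despite the "_logits" name.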
music_labels = {
"tab_validator": torch.rand(self.batch_size, self.seq_len, 1),
"difficulty": torch.randint(0, 3, (self.batch_size, self.seq_len)),
"interval_logits": torch.randint(0, 12, (self.batch_size, self.seq_len))
}
loss_dict = self.loss_fn(
logits=logits,
labels=labels,
eq_outputs=eq_outputs,
eq_labels=eq_labels,
music_outputs=music_outputs,
music_labels=music_labels
)
assert "total_loss" in loss_dict
assert "lm_loss" in loss_dict
assert "eq_loss" in loss_dict
assert "music_loss" in loss_dict
assert isinstance(loss_dict["total_loss"], torch.Tensor)
assert loss_dict["total_loss"].shape == ()
def test_forward_without_auxiliary_losses(self):
"""Test forward pass with only LM loss."""
logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
loss_dict = self.loss_fn(logits=logits, labels=labels)
assert "total_loss" in loss_dict
assert "lm_loss" in loss_dict
assert loss_dict["eq_loss"] == 0.0
assert loss_dict["music_loss"] == 0.0
# Total should equal LM loss only
assert torch.isclose(loss_dict["total_loss"], loss_dict["lm_loss"])
def test_lm_loss_calculation(self):
"""Test that LM loss is computed correctly."""
logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
loss_dict = self.loss_fn(logits=logits, labels=labels)
lm_loss = loss_dict["lm_loss"]
        # Manual recomputation of the causal LM objective: position t predicts
        # token t+1, so drop the last logit and the first label before cross-entropy.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
expected_lm_loss = F.cross_entropy(
shift_logits.view(-1, self.vocab_size),
shift_labels.view(-1)
)
assert torch.isclose(lm_loss, expected_lm_loss, rtol=1e-4)
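    def test_lm_loss_near_zero_for_confident_correct_logits(self):
        """LM loss should be near zero when logits strongly favor the shifted labels.

        A property-based sketch: it relies only on the standard causal shift
        convention verified in test_lm_loss_calculation above.
        """
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
        logits = torch.full((self.batch_size, self.seq_len, self.vocab_size), -10.0)
        # Give a large logit to the *next* token at every position.
        logits[..., :-1, :].scatter_(-1, labels[..., 1:].unsqueeze(-1), 10.0)
        loss_dict = self.loss_fn(logits=logits, labels=labels)
        assert loss_dict["lm_loss"] < 0.01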
def test_eq_loss_frustration_mse(self):
"""Test that frustration loss uses MSE."""
eq_outputs = {"frustration": torch.rand(self.batch_size, self.seq_len, 1)}
eq_labels = {"frustration": torch.rand(self.batch_size, self.seq_len, 1)}
logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
loss_dict = self.loss_fn(
logits=logits, labels=labels,
eq_outputs=eq_outputs, eq_labels=eq_labels
)
# EQ loss should be non-zero
assert loss_dict["eq_loss"] > 0
def test_eq_loss_emotion_cross_entropy(self):
"""Test that emotion loss uses cross-entropy."""
eq_outputs = {"emotion": torch.randn(self.batch_size, self.seq_len, 4)}
eq_labels = {"emotion": torch.randint(0, 4, (self.batch_size, self.seq_len))}
logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
loss_dict = self.loss_fn(
logits=logits, labels=labels,
eq_outputs=eq_outputs, eq_labels=eq_labels
)
assert loss_dict["eq_loss"] > 0
def test_music_loss_components(self):
"""Test that music module loss aggregates multiple components."""
music_outputs = {
"tab_validator": torch.rand(self.batch_size, self.seq_len, 1),
"difficulty": torch.randn(self.batch_size, self.seq_len, 3),
"interval_logits": torch.randn(self.batch_size, self.seq_len, 12)
}
music_labels = {
"tab_validator": torch.rand(self.batch_size, self.seq_len, 1),
"difficulty": torch.randint(0, 3, (self.batch_size, self.seq_len)),
"interval_logits": torch.randint(0, 12, (self.batch_size, self.seq_len))
}
logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
loss_dict = self.loss_fn(
logits=logits, labels=labels,
music_outputs=music_outputs, music_labels=music_labels
)
assert loss_dict["music_loss"] > 0
    def test_loss_weighting(self):
        """Test that loss weights are applied correctly."""
        logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
        # Weights are constructor arguments, so build two loss functions that
        # differ only in lm_loss_weight. With no auxiliary outputs, the LM loss
        # is the only active component.
        loss_fn_1x = TouchGrassLoss(lm_loss_weight=1.0, eq_loss_weight=0.1, music_module_loss_weight=0.05)
        loss_fn_2x = TouchGrassLoss(lm_loss_weight=2.0, eq_loss_weight=0.1, music_module_loss_weight=0.05)
        loss1 = loss_fn_1x(logits=logits, labels=labels)
        loss2 = loss_fn_2x(logits=logits, labels=labels)
        # With double weight, the total loss should double.
        assert torch.isclose(loss2["total_loss"], 2 * loss1["total_loss"], rtol=1e-3)
def test_gradient_computation(self):
"""Test that gradients can be computed."""
logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size, requires_grad=True)
labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
loss_dict = self.loss_fn(logits=logits, labels=labels)
loss_dict["total_loss"].backward()
assert logits.grad is not None
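    def test_gradient_flows_through_eq_outputs(self):
        """Gradients should also reach auxiliary EQ head outputs.

        A sketch under the same interface assumptions as the tests above; it
        relies only on the frustration objective being differentiable with
        respect to the head output.
        """
        logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
        frustration = torch.rand(self.batch_size, self.seq_len, 1, requires_grad=True)
        loss_dict = self.loss_fn(
            logits=logits, labels=labels,
            eq_outputs={"frustration": frustration},
            eq_labels={"frustration": torch.rand(self.batch_size, self.seq_len, 1)}
        )
        loss_dict["total_loss"].backward()
        assert frustration.grad is not None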
def test_different_batch_sizes(self):
"""Test loss with different batch sizes."""
for batch_size in [1, 2, 8]:
seq_len = 10
logits = torch.randn(batch_size, seq_len, self.vocab_size)
labels = torch.randint(0, self.vocab_size, (batch_size, seq_len))
loss_dict = self.loss_fn(logits=logits, labels=labels)
assert loss_dict["total_loss"].shape == ()
def test_different_seq_lengths(self):
"""Test loss with different sequence lengths."""
for seq_len in [5, 20, 50, 100]:
logits = torch.randn(self.batch_size, seq_len, self.vocab_size)
labels = torch.randint(0, self.vocab_size, (self.batch_size, seq_len))
loss_dict = self.loss_fn(logits=logits, labels=labels)
assert loss_dict["total_loss"].shape == ()
def test_loss_dict_keys(self):
"""Test that loss dictionary contains expected keys."""
logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
loss_dict = self.loss_fn(logits=logits, labels=labels)
expected_keys = ["total_loss", "lm_loss", "eq_loss", "music_loss"]
for key in expected_keys:
assert key in loss_dict
def test_loss_values_are_finite(self):
"""Test that all loss values are finite."""
logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
loss_dict = self.loss_fn(logits=logits, labels=labels)
        for key, value in loss_dict.items():
            # Zero-weight components may come back as plain Python floats;
            # coerce so the finiteness check works either way.
            assert torch.isfinite(torch.as_tensor(value)).all(), f"Loss {key} is not finite: {value}"
def test_loss_weights_accumulate(self):
"""Test that total loss properly accumulates weighted components."""
logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
eq_outputs = {"frustration": torch.rand(self.batch_size, self.seq_len, 1)}
eq_labels = {"frustration": torch.rand(self.batch_size, self.seq_len, 1)}
music_outputs = {"difficulty": torch.randn(self.batch_size, self.seq_len, 3)}
music_labels = {"difficulty": torch.randint(0, 3, (self.batch_size, self.seq_len))}
loss_fn = TouchGrassLoss(lm_loss_weight=1.0, eq_loss_weight=0.5, music_module_loss_weight=0.25)
loss_dict = loss_fn(
logits=logits, labels=labels,
eq_outputs=eq_outputs, eq_labels=eq_labels,
music_outputs=music_outputs, music_labels=music_labels
)
        # Total should be the weighted sum of the components.
expected_total = (
1.0 * loss_dict["lm_loss"] +
0.5 * loss_dict["eq_loss"] +
0.25 * loss_dict["music_loss"]
)
assert torch.isclose(loss_dict["total_loss"], expected_total, rtol=1e-4)
def test_with_custom_loss_weights(self):
"""Test initializing with custom loss weights."""
custom_loss_fn = TouchGrassLoss(
lm_loss_weight=2.0,
eq_loss_weight=0.5,
music_module_loss_weight=0.2
)
assert custom_loss_fn.lm_loss_weight == 2.0
assert custom_loss_fn.eq_loss_weight == 0.5
assert custom_loss_fn.music_module_loss_weight == 0.2
def test_missing_auxiliary_outputs(self):
"""Test that missing auxiliary outputs are handled gracefully."""
logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))
# Should work without eq_outputs or music_outputs
loss_dict = self.loss_fn(logits=logits, labels=labels)
assert loss_dict["total_loss"] > 0
class TestMusicAwareLoss:
"""Test suite for MusicAwareLoss (alternative implementation)."""
def test_music_aware_loss_initialization(self):
"""Test MusicAwareLoss initialization."""
loss_fn = MusicAwareLoss()
assert hasattr(loss_fn, "forward")
def test_music_aware_loss_forward(self):
"""Test MusicAwareLoss forward pass."""
loss_fn = MusicAwareLoss()
logits = torch.randn(2, 10, 1000)
labels = torch.randint(0, 1000, (2, 10))
# Should work with just LM loss
loss = loss_fn(logits, labels)
assert isinstance(loss, torch.Tensor)
assert loss.shape == ()
def test_music_aware_loss_with_weights(self):
"""Test MusicAwareLoss with custom weights."""
loss_fn = MusicAwareLoss(
lm_weight=1.0,
music_weight=0.1,
eq_weight=0.05
)
logits = torch.randn(2, 10, 1000)
labels = torch.randint(0, 1000, (2, 10))
loss = loss_fn(logits, labels)
assert torch.isfinite(loss)
if __name__ == "__main__":
pytest.main([__file__, "-v"])