| | """
|
| | Tests for TouchGrass Loss Functions.
|
| | """
|
| |
|
| | import pytest
|
| | import torch
|
| | import torch.nn.functional as F
|
| |
|
| | from TouchGrass.training.losses import TouchGrassLoss, MusicAwareLoss
|
| |
|
| |
|
class TestTouchGrassLoss:
    """Test suite for TouchGrassLoss.

    Covers the combined language-model / EQ-head / music-head loss:
    weight wiring, per-component activation, gradient flow, and shape
    robustness across batch sizes and sequence lengths.
    """

    def setup_method(self):
        """Set up shared fixtures: tensor dimensions and a default loss instance."""
        self.batch_size = 4
        self.seq_len = 10
        self.vocab_size = 32000
        self.loss_fn = TouchGrassLoss(
            lm_loss_weight=1.0,
            eq_loss_weight=0.1,
            music_module_loss_weight=0.05,
        )

    def test_loss_initialization(self):
        """Constructor stores each component weight as an attribute."""
        assert self.loss_fn.lm_loss_weight == 1.0
        assert self.loss_fn.eq_loss_weight == 0.1
        assert self.loss_fn.music_module_loss_weight == 0.05

    def test_forward_with_all_outputs(self):
        """Forward pass with LM, EQ, and music outputs returns a scalar total."""
        logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))

        # EQ heads: frustration is a regression target in [0, 1];
        # emotion is a 4-way classification head.
        eq_outputs = {
            "frustration": torch.rand(self.batch_size, self.seq_len, 1),
            "emotion": torch.randn(self.batch_size, self.seq_len, 4),
        }
        eq_labels = {
            "frustration": torch.rand(self.batch_size, self.seq_len, 1),
            "emotion": torch.randint(0, 4, (self.batch_size, self.seq_len)),
        }

        # Music heads: tab_validator is a regression target; difficulty (3-way)
        # and interval_logits (12-way) are classification heads.
        music_outputs = {
            "tab_validator": torch.rand(self.batch_size, self.seq_len, 1),
            "difficulty": torch.randn(self.batch_size, self.seq_len, 3),
            "interval_logits": torch.randn(self.batch_size, self.seq_len, 12),
        }
        music_labels = {
            "tab_validator": torch.rand(self.batch_size, self.seq_len, 1),
            "difficulty": torch.randint(0, 3, (self.batch_size, self.seq_len)),
            "interval_logits": torch.randint(0, 12, (self.batch_size, self.seq_len)),
        }

        loss_dict = self.loss_fn(
            logits=logits,
            labels=labels,
            eq_outputs=eq_outputs,
            eq_labels=eq_labels,
            music_outputs=music_outputs,
            music_labels=music_labels,
        )

        assert "total_loss" in loss_dict
        assert "lm_loss" in loss_dict
        assert "eq_loss" in loss_dict
        assert "music_loss" in loss_dict
        assert isinstance(loss_dict["total_loss"], torch.Tensor)
        # Total loss must be a scalar (0-dim tensor) so .backward() works directly.
        assert loss_dict["total_loss"].shape == ()

    def test_forward_without_auxiliary_losses(self):
        """With only LM inputs, auxiliary losses are zero and total == lm_loss."""
        logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))

        loss_dict = self.loss_fn(logits=logits, labels=labels)

        assert "total_loss" in loss_dict
        assert "lm_loss" in loss_dict
        assert loss_dict["eq_loss"] == 0.0
        assert loss_dict["music_loss"] == 0.0

        # With lm_loss_weight == 1.0 and no auxiliary terms, the total
        # must equal the LM loss exactly (up to float tolerance).
        assert torch.isclose(loss_dict["total_loss"], loss_dict["lm_loss"])

    def test_lm_loss_calculation(self):
        """LM loss matches a hand-computed shifted cross-entropy.

        NOTE(review): assumes TouchGrassLoss uses the standard causal-LM
        shift (predict token t+1 from position t) with mean reduction —
        confirm against the implementation.
        """
        logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))

        loss_dict = self.loss_fn(logits=logits, labels=labels)
        lm_loss = loss_dict["lm_loss"]

        # Reference computation: shift logits left, labels right, then
        # flatten into (N, vocab) / (N,) for F.cross_entropy.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        expected_lm_loss = F.cross_entropy(
            shift_logits.view(-1, self.vocab_size),
            shift_labels.view(-1),
        )

        assert torch.isclose(lm_loss, expected_lm_loss, rtol=1e-4)

    def test_eq_loss_frustration_mse(self):
        """Frustration head (regression) produces a positive EQ loss."""
        eq_outputs = {"frustration": torch.rand(self.batch_size, self.seq_len, 1)}
        eq_labels = {"frustration": torch.rand(self.batch_size, self.seq_len, 1)}

        logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))

        loss_dict = self.loss_fn(
            logits=logits, labels=labels,
            eq_outputs=eq_outputs, eq_labels=eq_labels,
        )

        # Random predictions vs. random targets: MSE is positive with
        # probability ~1 for continuous random inputs.
        assert loss_dict["eq_loss"] > 0

    def test_eq_loss_emotion_cross_entropy(self):
        """Emotion head (4-way classification) produces a positive EQ loss."""
        eq_outputs = {"emotion": torch.randn(self.batch_size, self.seq_len, 4)}
        eq_labels = {"emotion": torch.randint(0, 4, (self.batch_size, self.seq_len))}

        logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))

        loss_dict = self.loss_fn(
            logits=logits, labels=labels,
            eq_outputs=eq_outputs, eq_labels=eq_labels,
        )

        assert loss_dict["eq_loss"] > 0

    def test_music_loss_components(self):
        """Music module loss aggregates validator, difficulty, and interval heads."""
        music_outputs = {
            "tab_validator": torch.rand(self.batch_size, self.seq_len, 1),
            "difficulty": torch.randn(self.batch_size, self.seq_len, 3),
            "interval_logits": torch.randn(self.batch_size, self.seq_len, 12),
        }
        music_labels = {
            "tab_validator": torch.rand(self.batch_size, self.seq_len, 1),
            "difficulty": torch.randint(0, 3, (self.batch_size, self.seq_len)),
            "interval_logits": torch.randint(0, 12, (self.batch_size, self.seq_len)),
        }

        logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))

        loss_dict = self.loss_fn(
            logits=logits, labels=labels,
            music_outputs=music_outputs, music_labels=music_labels,
        )

        assert loss_dict["music_loss"] > 0

    def test_loss_weighting(self):
        """Doubling lm_loss_weight doubles the total loss (LM-only inputs).

        FIX: the original test passed ``lm_loss_weight`` as a keyword to the
        loss *call*, but the weights are constructor attributes (see
        test_loss_initialization) — forward() would reject the unexpected
        keyword. Construct two differently-weighted instances instead.
        """
        logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))

        # Zero out auxiliary weights so the total is purely the weighted LM loss.
        loss_fn_1x = TouchGrassLoss(
            lm_loss_weight=1.0, eq_loss_weight=0.0, music_module_loss_weight=0.0
        )
        loss_fn_2x = TouchGrassLoss(
            lm_loss_weight=2.0, eq_loss_weight=0.0, music_module_loss_weight=0.0
        )

        loss1 = loss_fn_1x(logits=logits, labels=labels)
        loss2 = loss_fn_2x(logits=logits, labels=labels)

        assert torch.isclose(loss2["total_loss"], 2 * loss1["total_loss"], rtol=1e-3)

    def test_gradient_computation(self):
        """Backprop through the total loss populates logits.grad."""
        logits = torch.randn(
            self.batch_size, self.seq_len, self.vocab_size, requires_grad=True
        )
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))

        loss_dict = self.loss_fn(logits=logits, labels=labels)
        loss_dict["total_loss"].backward()

        assert logits.grad is not None

    def test_different_batch_sizes(self):
        """Loss stays a scalar across batch sizes, including batch of 1."""
        for batch_size in [1, 2, 8]:
            seq_len = 10
            logits = torch.randn(batch_size, seq_len, self.vocab_size)
            labels = torch.randint(0, self.vocab_size, (batch_size, seq_len))

            loss_dict = self.loss_fn(logits=logits, labels=labels)
            assert loss_dict["total_loss"].shape == ()

    def test_different_seq_lengths(self):
        """Loss stays a scalar across a range of sequence lengths."""
        for seq_len in [5, 20, 50, 100]:
            logits = torch.randn(self.batch_size, seq_len, self.vocab_size)
            labels = torch.randint(0, self.vocab_size, (self.batch_size, seq_len))

            loss_dict = self.loss_fn(logits=logits, labels=labels)
            assert loss_dict["total_loss"].shape == ()

    def test_loss_dict_keys(self):
        """Returned dictionary always exposes all four loss keys."""
        logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))

        loss_dict = self.loss_fn(logits=logits, labels=labels)

        expected_keys = ["total_loss", "lm_loss", "eq_loss", "music_loss"]
        for key in expected_keys:
            assert key in loss_dict

    def test_loss_values_are_finite(self):
        """Every entry in the loss dict is finite (no NaN/inf).

        FIX: inactive auxiliary losses may be plain Python floats (0.0, per
        test_forward_without_auxiliary_losses), and torch.isfinite rejects
        non-tensor input — coerce with torch.as_tensor before checking.
        """
        logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))

        loss_dict = self.loss_fn(logits=logits, labels=labels)

        for key, value in loss_dict.items():
            assert torch.isfinite(torch.as_tensor(value)), (
                f"Loss {key} is not finite: {value}"
            )

    def test_loss_weights_accumulate(self):
        """Total loss equals the weighted sum of its components."""
        logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))

        eq_outputs = {"frustration": torch.rand(self.batch_size, self.seq_len, 1)}
        eq_labels = {"frustration": torch.rand(self.batch_size, self.seq_len, 1)}

        music_outputs = {"difficulty": torch.randn(self.batch_size, self.seq_len, 3)}
        music_labels = {"difficulty": torch.randint(0, 3, (self.batch_size, self.seq_len))}

        loss_fn = TouchGrassLoss(
            lm_loss_weight=1.0, eq_loss_weight=0.5, music_module_loss_weight=0.25
        )
        loss_dict = loss_fn(
            logits=logits, labels=labels,
            eq_outputs=eq_outputs, eq_labels=eq_labels,
            music_outputs=music_outputs, music_labels=music_labels,
        )

        # Reconstruct the total from the reported (unweighted) components.
        expected_total = (
            1.0 * loss_dict["lm_loss"]
            + 0.5 * loss_dict["eq_loss"]
            + 0.25 * loss_dict["music_loss"]
        )
        assert torch.isclose(loss_dict["total_loss"], expected_total, rtol=1e-4)

    def test_with_custom_loss_weights(self):
        """Non-default constructor weights are stored verbatim."""
        custom_loss_fn = TouchGrassLoss(
            lm_loss_weight=2.0,
            eq_loss_weight=0.5,
            music_module_loss_weight=0.2,
        )
        assert custom_loss_fn.lm_loss_weight == 2.0
        assert custom_loss_fn.eq_loss_weight == 0.5
        assert custom_loss_fn.music_module_loss_weight == 0.2

    def test_missing_auxiliary_outputs(self):
        """Omitting all auxiliary outputs still yields a positive total loss."""
        logits = torch.randn(self.batch_size, self.seq_len, self.vocab_size)
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len))

        loss_dict = self.loss_fn(logits=logits, labels=labels)
        assert loss_dict["total_loss"] > 0
| |
|
| |
|
class TestMusicAwareLoss:
    """Test suite for MusicAwareLoss (alternative implementation)."""

    def test_music_aware_loss_initialization(self):
        """A default-constructed MusicAwareLoss exposes a forward method."""
        criterion = MusicAwareLoss()
        assert hasattr(criterion, "forward")

    def test_music_aware_loss_forward(self):
        """Calling the loss on logits/labels yields a scalar tensor."""
        criterion = MusicAwareLoss()
        token_logits = torch.randn(2, 10, 1000)
        token_ids = torch.randint(0, 1000, (2, 10))

        result = criterion(token_logits, token_ids)

        assert isinstance(result, torch.Tensor)
        assert result.shape == ()

    def test_music_aware_loss_with_weights(self):
        """Custom component weights still produce a finite loss value."""
        criterion = MusicAwareLoss(
            lm_weight=1.0,
            music_weight=0.1,
            eq_weight=0.05,
        )
        token_logits = torch.randn(2, 10, 1000)
        token_ids = torch.randint(0, 1000, (2, 10))

        result = criterion(token_logits, token_ids)
        assert torch.isfinite(result)
| |
|
| |
|
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
| |
|