"""Tests for Music Tokenizer Extension."""

from unittest.mock import MagicMock, patch

import pytest

from TouchGrass.tokenizer.music_token_extension import MusicTokenizerExtension

# Model identifier handed to the extension in every test.
_MODEL_NAME = "Qwen/Qwen3.5-3B-Instruct"

# AutoTokenizer is patched where the extension module looks it up, so the
# tests never load a real tokenizer.
_AUTO_TOKENIZER_TARGET = "TouchGrass.tokenizer.music_token_extension.AutoTokenizer"

# Vocab size reported by the mocked base tokenizer before any tokens are added.
_BASE_VOCAB_SIZE = 32000

# Vocab size reported by the mock in tests that model an already-extended
# tokenizer.
_EXTENDED_VOCAB_SIZE = 32021


class TestMusicTokenizerExtension:
    """Test suite for MusicTokenizerExtension."""

    def setup_method(self):
        """Create the token fixtures and patch AutoTokenizer for one test."""
        self.special_tokens = {
            "[GUITAR]": 32000,
            "[PIANO]": 32001,
            "[DRUMS]": 32002,
            "[VOCALS]": 32003,
            "[THEORY]": 32004,
            "[PRODUCTION]": 32005,
            "[FRUSTRATED]": 32006,
            "[CONFUSED]": 32007,
            "[EXCITED]": 32008,
            "[CONFIDENT]": 32009,
            "[EASY]": 32010,
            "[MEDIUM]": 32011,
            "[HARD]": 32012,
            "[TAB]": 32013,
            "[CHORD]": 32014,
            "[SCALE]": 32015,
            "[INTERVAL]": 32016,
            "[PROGRESSION]": 32017,
            "[SIMPLIFY]": 32018,
            "[ENCOURAGE]": 32019,
        }
        self.music_vocab_extensions = [
            "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B",
        ]

        self.mock_tokenizer = MagicMock()
        self.mock_tokenizer.vocab_size = _BASE_VOCAB_SIZE

        # Started last so that nothing in this method can fail while the patch
        # is active; teardown_method stops it again.
        self._patcher = patch(_AUTO_TOKENIZER_TARGET)
        self.mock_tokenizer_class = self._patcher.start()
        self.mock_tokenizer_class.from_pretrained.return_value = self.mock_tokenizer

    def teardown_method(self):
        """Undo the AutoTokenizer patch installed by setup_method."""
        self._patcher.stop()

    def _build_extension(
        self,
        *,
        special_tokens,
        music_vocab_extensions,
        vocab_size=_BASE_VOCAB_SIZE,
    ):
        """Construct a MusicTokenizerExtension on top of the mocked tokenizer.

        Args:
            special_tokens: Mapping passed through as ``special_tokens``.
            music_vocab_extensions: List passed through as
                ``music_vocab_extensions``.
            vocab_size: Value the mocked base tokenizer reports as
                ``vocab_size`` before the extension is constructed.

        Returns:
            The constructed MusicTokenizerExtension.
        """
        self.mock_tokenizer.vocab_size = vocab_size
        return MusicTokenizerExtension(
            _MODEL_NAME,
            special_tokens=special_tokens,
            music_vocab_extensions=music_vocab_extensions,
        )

    def test_tokenizer_initialization(self):
        """Test that tokenizer initializes correctly with special tokens."""
        ext = self._build_extension(
            special_tokens=self.special_tokens,
            music_vocab_extensions=self.music_vocab_extensions,
        )

        assert ext.base_tokenizer == self.mock_tokenizer
        self.mock_tokenizer_class.from_pretrained.assert_called_once_with(_MODEL_NAME)

    def test_special_tokens_added(self):
        """Test that special tokens are added to tokenizer."""
        self._build_extension(
            special_tokens=self.special_tokens,
            music_vocab_extensions=[],
        )

        expected_tokens = list(self.special_tokens)
        self.mock_tokenizer.add_special_tokens.assert_called_once_with(
            {"additional_special_tokens": expected_tokens}
        )

    def test_music_vocab_extensions_added(self):
        """Test that music vocabulary extensions are added."""
        self._build_extension(
            special_tokens={},
            music_vocab_extensions=self.music_vocab_extensions,
        )

        # Check that add_tokens was called with music vocab extensions.
        assert self.mock_tokenizer.add_tokens.called
        added_tokens = self.mock_tokenizer.add_tokens.call_args[0][0]
        assert set(added_tokens) == set(self.music_vocab_extensions)

    def test_tokenizer_vocab_size_increased(self):
        """Test that vocab size is correctly increased after adding tokens."""
        num_special = len(self.special_tokens)
        num_music = len(self.music_vocab_extensions)
        expected_new_vocab_size = _BASE_VOCAB_SIZE + num_special + num_music

        ext = self._build_extension(
            special_tokens=self.special_tokens,
            music_vocab_extensions=self.music_vocab_extensions,
        )

        # NOTE(review): base_tokenizer is a MagicMock, so vocab_size only
        # changes if MusicTokenizerExtension assigns it explicitly. On real
        # Hugging Face tokenizers vocab_size excludes added tokens and
        # len(tokenizer) is the count that grows -- confirm this assertion
        # matches what the implementation actually does.
        assert ext.base_tokenizer.vocab_size == expected_new_vocab_size

    def test_encode_with_music_tokens(self):
        """Test encoding text with music tokens."""
        self.mock_tokenizer.encode.return_value = [1, 2, 32000, 3, 4]
        ext = self._build_extension(
            special_tokens=self.special_tokens,
            music_vocab_extensions=[],
            vocab_size=_EXTENDED_VOCAB_SIZE,
        )

        result = ext.encode("Play a [GUITAR] chord")

        assert result == [1, 2, 32000, 3, 4]
        self.mock_tokenizer.encode.assert_called_once_with("Play a [GUITAR] chord")

    def test_decode_with_music_tokens(self):
        """Test decoding token IDs with music tokens."""
        self.mock_tokenizer.decode.return_value = "Play a [GUITAR] chord"
        ext = self._build_extension(
            special_tokens=self.special_tokens,
            music_vocab_extensions=[],
            vocab_size=_EXTENDED_VOCAB_SIZE,
        )

        result = ext.decode([1, 2, 32000, 3, 4])

        assert result == "Play a [GUITAR] chord"
        self.mock_tokenizer.decode.assert_called_once_with([1, 2, 32000, 3, 4])

    def test_get_music_token_id(self):
        """Test retrieving token ID for a music token."""
        self.mock_tokenizer.convert_tokens_to_ids.return_value = 32000
        ext = self._build_extension(
            special_tokens=self.special_tokens,
            music_vocab_extensions=[],
            vocab_size=_EXTENDED_VOCAB_SIZE,
        )

        token_id = ext.get_music_token_id("[GUITAR]")

        assert token_id == 32000
        self.mock_tokenizer.convert_tokens_to_ids.assert_called_with("[GUITAR]")

    def test_has_music_token(self):
        """Test checking if a token is a music token."""
        ext = self._build_extension(
            special_tokens=self.special_tokens,
            music_vocab_extensions=[],
            vocab_size=_EXTENDED_VOCAB_SIZE,
        )

        assert ext.has_music_token("[GUITAR]") is True
        assert ext.has_music_token("[UNKNOWN]") is False

    def test_get_music_domain_tokens(self):
        """Test retrieving all domain tokens."""
        ext = self._build_extension(
            special_tokens=self.special_tokens,
            music_vocab_extensions=[],
            vocab_size=_EXTENDED_VOCAB_SIZE,
        )

        domain_tokens = ext.get_music_domain_tokens()

        expected = ["[GUITAR]", "[PIANO]", "[DRUMS]", "[VOCALS]", "[THEORY]", "[PRODUCTION]"]
        assert domain_tokens == expected

    def test_get_emotion_tokens(self):
        """Test retrieving emotion tokens."""
        ext = self._build_extension(
            special_tokens=self.special_tokens,
            music_vocab_extensions=[],
            vocab_size=_EXTENDED_VOCAB_SIZE,
        )

        emotion_tokens = ext.get_emotion_tokens()

        expected = ["[FRUSTRATED]", "[CONFUSED]", "[EXCITED]", "[CONFIDENT]"]
        assert emotion_tokens == expected

    def test_get_difficulty_tokens(self):
        """Test retrieving difficulty tokens."""
        ext = self._build_extension(
            special_tokens=self.special_tokens,
            music_vocab_extensions=[],
            vocab_size=_EXTENDED_VOCAB_SIZE,
        )

        difficulty_tokens = ext.get_difficulty_tokens()

        expected = ["[EASY]", "[MEDIUM]", "[HARD]"]
        assert difficulty_tokens == expected

    def test_get_music_function_tokens(self):
        """Test retrieving music function tokens."""
        ext = self._build_extension(
            special_tokens=self.special_tokens,
            music_vocab_extensions=[],
            vocab_size=_EXTENDED_VOCAB_SIZE,
        )

        function_tokens = ext.get_music_function_tokens()

        expected = ["[TAB]", "[CHORD]", "[SCALE]", "[INTERVAL]", "[PROGRESSION]"]
        assert function_tokens == expected

    def test_get_eq_tokens(self):
        """Test retrieving EQ (emotional intelligence) tokens."""
        ext = self._build_extension(
            special_tokens=self.special_tokens,
            music_vocab_extensions=[],
            vocab_size=_EXTENDED_VOCAB_SIZE,
        )

        eq_tokens = ext.get_eq_tokens()

        expected = [
            "[FRUSTRATED]",
            "[CONFUSED]",
            "[EXCITED]",
            "[CONFIDENT]",
            "[SIMPLIFY]",
            "[ENCOURAGE]",
        ]
        assert eq_tokens == expected

    def test_token_count_with_music_tokens(self):
        """Test that token count increases after adding music tokens."""
        num_special = len(self.special_tokens)
        num_music = len(self.music_vocab_extensions)

        ext = self._build_extension(
            special_tokens=self.special_tokens,
            music_vocab_extensions=self.music_vocab_extensions,
        )

        expected_vocab_size = _BASE_VOCAB_SIZE + num_special + num_music
        # NOTE(review): same caveat as test_tokenizer_vocab_size_increased --
        # this reads an attribute of a MagicMock; confirm the implementation
        # really updates base_tokenizer.vocab_size.
        assert ext.base_tokenizer.vocab_size == expected_vocab_size
        assert ext.base_tokenizer.vocab_size > _BASE_VOCAB_SIZE


if __name__ == "__main__":
    pytest.main([__file__, "-v"])