# TouchGrass-3b / tests/test_tokenizer.py
# Uploaded by Zandy-Wandy ("Upload 39 files"), commit 9071ef9 (verified).
"""
Tests for Music Tokenizer Extension.
"""
from contextlib import contextmanager
from unittest.mock import MagicMock, patch

import pytest

from TouchGrass.tokenizer.music_token_extension import MusicTokenizerExtension
class TestMusicTokenizerExtension:
    """Test suite for MusicTokenizerExtension.

    Every test patches ``AutoTokenizer`` as seen by
    ``TouchGrass.tokenizer.music_token_extension`` so that
    ``AutoTokenizer.from_pretrained`` returns a ``MagicMock``; no real
    tokenizer files are loaded.
    """

    # Model id handed to MusicTokenizerExtension; expected to be forwarded
    # verbatim to AutoTokenizer.from_pretrained.
    MODEL_NAME = "Qwen/Qwen3.5-3B-Instruct"
    # Patch the name where the module under test looks it up, not where it
    # is defined (see unittest.mock "where to patch").
    AUTO_TOKENIZER_TARGET = "TouchGrass.tokenizer.music_token_extension.AutoTokenizer"
    # vocab_size reported by the mocked base tokenizer before extension.
    BASE_VOCAB_SIZE = 32000
    # vocab_size reported by the mock in tests that do not check sizing.
    # NOTE(review): setup_method defines 20 special tokens (ids 32000-32019),
    # so an extended size would be 32020; no test asserts on this value.
    EXTENDED_VOCAB_SIZE = 32021

    def setup_method(self):
        """Build the special-token map and note-name list shared by the tests."""
        # 20 special tokens with ids 32000-32019, listed in the groups the
        # getter tests below expect: domains, emotions, difficulty levels,
        # music functions, then [SIMPLIFY]/[ENCOURAGE].
        self.special_tokens = {
            "[GUITAR]": 32000,
            "[PIANO]": 32001,
            "[DRUMS]": 32002,
            "[VOCALS]": 32003,
            "[THEORY]": 32004,
            "[PRODUCTION]": 32005,
            "[FRUSTRATED]": 32006,
            "[CONFUSED]": 32007,
            "[EXCITED]": 32008,
            "[CONFIDENT]": 32009,
            "[EASY]": 32010,
            "[MEDIUM]": 32011,
            "[HARD]": 32012,
            "[TAB]": 32013,
            "[CHORD]": 32014,
            "[SCALE]": 32015,
            "[INTERVAL]": 32016,
            "[PROGRESSION]": 32017,
            "[SIMPLIFY]": 32018,
            "[ENCOURAGE]": 32019,
        }
        # The twelve chromatic pitch names (sharp spellings only).
        self.music_vocab_extensions = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]

    @contextmanager
    def _patched_tokenizer(self, vocab_size=BASE_VOCAB_SIZE):
        """Patch ``AutoTokenizer`` for the duration of the ``with`` block.

        Args:
            vocab_size: Value assigned to the mock tokenizer's ``vocab_size``.

        Yields:
            ``(mock_tokenizer_class, mock_tokenizer)``, where
            ``mock_tokenizer_class.from_pretrained`` returns ``mock_tokenizer``.
            Configure any extra return values on ``mock_tokenizer`` *before*
            calling ``_build_extension``, in case the constructor already
            uses them.
        """
        with patch(self.AUTO_TOKENIZER_TARGET) as mock_tokenizer_class:
            mock_tokenizer = MagicMock()
            mock_tokenizer.vocab_size = vocab_size
            mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
            yield mock_tokenizer_class, mock_tokenizer

    def _build_extension(self, special_tokens, music_vocab_extensions):
        """Construct the extension for MODEL_NAME; call inside ``_patched_tokenizer``."""
        return MusicTokenizerExtension(
            self.MODEL_NAME,
            special_tokens=special_tokens,
            music_vocab_extensions=music_vocab_extensions,
        )

    def test_tokenizer_initialization(self):
        """Test that tokenizer initializes correctly with special tokens."""
        with self._patched_tokenizer() as (mock_tokenizer_class, mock_tokenizer):
            ext = self._build_extension(self.special_tokens, self.music_vocab_extensions)
            # The object returned by from_pretrained is stored as-is.
            assert ext.base_tokenizer is mock_tokenizer
            mock_tokenizer_class.from_pretrained.assert_called_once_with(self.MODEL_NAME)

    def test_special_tokens_added(self):
        """Test that special tokens are added to tokenizer."""
        with self._patched_tokenizer() as (_, mock_tokenizer):
            self._build_extension(self.special_tokens, [])
            # Keys are expected in the dict's insertion order.
            mock_tokenizer.add_special_tokens.assert_called_once_with(
                {"additional_special_tokens": list(self.special_tokens)}
            )

    def test_music_vocab_extensions_added(self):
        """Test that music vocabulary extensions are added."""
        with self._patched_tokenizer() as (_, mock_tokenizer):
            self._build_extension({}, self.music_vocab_extensions)
            assert mock_tokenizer.add_tokens.called
            # First positional argument of the most recent add_tokens call;
            # compared as sets, so ordering is not checked.
            added_tokens = mock_tokenizer.add_tokens.call_args[0][0]
            assert set(added_tokens) == set(self.music_vocab_extensions)

    def test_tokenizer_vocab_size_increased(self):
        """Test that vocab size is correctly increased after adding tokens."""
        expected_new_vocab_size = (
            self.BASE_VOCAB_SIZE + len(self.special_tokens) + len(self.music_vocab_extensions)
        )
        with self._patched_tokenizer(self.BASE_VOCAB_SIZE):
            ext = self._build_extension(self.special_tokens, self.music_vocab_extensions)
            # NOTE(review): base_tokenizer is the MagicMock, whose vocab_size is
            # a plain attribute set to BASE_VOCAB_SIZE above. This only passes if
            # MusicTokenizerExtension reassigns base_tokenizer.vocab_size itself.
            # Hugging Face tokenizers exclude added tokens from vocab_size
            # (len(tokenizer) includes them) -- confirm against the implementation.
            assert ext.base_tokenizer.vocab_size == expected_new_vocab_size

    def test_encode_with_music_tokens(self):
        """Test encoding text with music tokens."""
        with self._patched_tokenizer(self.EXTENDED_VOCAB_SIZE) as (_, mock_tokenizer):
            mock_tokenizer.encode.return_value = [1, 2, 32000, 3, 4]
            ext = self._build_extension(self.special_tokens, [])
            # encode is expected to delegate straight to the base tokenizer.
            assert ext.encode("Play a [GUITAR] chord") == [1, 2, 32000, 3, 4]
            mock_tokenizer.encode.assert_called_once_with("Play a [GUITAR] chord")

    def test_decode_with_music_tokens(self):
        """Test decoding token IDs with music tokens."""
        with self._patched_tokenizer(self.EXTENDED_VOCAB_SIZE) as (_, mock_tokenizer):
            mock_tokenizer.decode.return_value = "Play a [GUITAR] chord"
            ext = self._build_extension(self.special_tokens, [])
            # decode is expected to delegate straight to the base tokenizer.
            assert ext.decode([1, 2, 32000, 3, 4]) == "Play a [GUITAR] chord"
            mock_tokenizer.decode.assert_called_once_with([1, 2, 32000, 3, 4])

    def test_get_music_token_id(self):
        """Test retrieving token ID for a music token."""
        with self._patched_tokenizer(self.EXTENDED_VOCAB_SIZE) as (_, mock_tokenizer):
            # Configured before construction in case __init__ already maps ids.
            mock_tokenizer.convert_tokens_to_ids.return_value = 32000
            ext = self._build_extension(self.special_tokens, [])
            assert ext.get_music_token_id("[GUITAR]") == 32000
            mock_tokenizer.convert_tokens_to_ids.assert_called_with("[GUITAR]")

    def test_has_music_token(self):
        """Test checking if a token is a music token."""
        with self._patched_tokenizer(self.EXTENDED_VOCAB_SIZE):
            ext = self._build_extension(self.special_tokens, [])
            assert ext.has_music_token("[GUITAR]") is True
            assert ext.has_music_token("[UNKNOWN]") is False

    def test_get_music_domain_tokens(self):
        """Test retrieving all domain tokens."""
        with self._patched_tokenizer(self.EXTENDED_VOCAB_SIZE):
            ext = self._build_extension(self.special_tokens, [])
            expected = ["[GUITAR]", "[PIANO]", "[DRUMS]", "[VOCALS]", "[THEORY]", "[PRODUCTION]"]
            assert ext.get_music_domain_tokens() == expected

    def test_get_emotion_tokens(self):
        """Test retrieving emotion tokens."""
        with self._patched_tokenizer(self.EXTENDED_VOCAB_SIZE):
            ext = self._build_extension(self.special_tokens, [])
            expected = ["[FRUSTRATED]", "[CONFUSED]", "[EXCITED]", "[CONFIDENT]"]
            assert ext.get_emotion_tokens() == expected

    def test_get_difficulty_tokens(self):
        """Test retrieving difficulty tokens."""
        with self._patched_tokenizer(self.EXTENDED_VOCAB_SIZE):
            ext = self._build_extension(self.special_tokens, [])
            expected = ["[EASY]", "[MEDIUM]", "[HARD]"]
            assert ext.get_difficulty_tokens() == expected

    def test_get_music_function_tokens(self):
        """Test retrieving music function tokens."""
        with self._patched_tokenizer(self.EXTENDED_VOCAB_SIZE):
            ext = self._build_extension(self.special_tokens, [])
            expected = ["[TAB]", "[CHORD]", "[SCALE]", "[INTERVAL]", "[PROGRESSION]"]
            assert ext.get_music_function_tokens() == expected

    def test_get_eq_tokens(self):
        """Test retrieving EQ (emotional intelligence) tokens."""
        with self._patched_tokenizer(self.EXTENDED_VOCAB_SIZE):
            ext = self._build_extension(self.special_tokens, [])
            # Emotion tokens followed by the two response-style tokens.
            expected = ["[FRUSTRATED]", "[CONFUSED]", "[EXCITED]", "[CONFIDENT]", "[SIMPLIFY]", "[ENCOURAGE]"]
            assert ext.get_eq_tokens() == expected

    def test_token_count_with_music_tokens(self):
        """Test that token count increases after adding music tokens."""
        expected_vocab_size = (
            self.BASE_VOCAB_SIZE + len(self.special_tokens) + len(self.music_vocab_extensions)
        )
        with self._patched_tokenizer(self.BASE_VOCAB_SIZE):
            ext = self._build_extension(self.special_tokens, self.music_vocab_extensions)
            # NOTE(review): same caveat as test_tokenizer_vocab_size_increased --
            # this reads a plain MagicMock attribute, so it only passes if the
            # extension reassigns base_tokenizer.vocab_size.
            assert ext.base_tokenizer.vocab_size == expected_vocab_size
            assert ext.base_tokenizer.vocab_size > self.BASE_VOCAB_SIZE
if __name__ == "__main__":
    # Run this file's tests directly. pytest.main returns an exit code;
    # raising SystemExit with it makes `python test_tokenizer.py` exit
    # non-zero when any test fails (the original discarded it and always exited 0).
    raise SystemExit(pytest.main([__file__, "-v"]))