# veda-programming/tokenizer.py
"""Subword Tokenizer (BPE-like) for Veda Programming Assistant"""
import json
import re
from typing import List, Dict, Optional, Tuple
class VedaTokenizer:
"""
    Subword tokenizer that learns common subwords/phrases via BPE-style merges.
    Frequent words end up as single tokens while rare words fall back to smaller
    pieces, a middle ground between word-level and char-level tokenization.
"""
def __init__(self, vocab_size: int = 8000):
self.vocab_size = vocab_size
self.token_to_idx: Dict[str, int] = {}
self.idx_to_token: Dict[int, str] = {}
# Base vocabulary (special tokens + ASCII)
self._init_base_vocab()
# Merges for subwords (pair -> new_token)
self.merges: Dict[Tuple[str, str], str] = {}
def _init_base_vocab(self):
"""Initialize base vocabulary"""
special = [
"<PAD>", "<UNK>", "<START>", "<END>",
"<CODE>", "<ENDCODE>",
"<USER>", "<ASSISTANT>"
]
for idx, token in enumerate(special):
self.token_to_idx[token] = idx
self.idx_to_token[idx] = token
# ASCII characters as base tokens
idx = len(special)
# Printable ASCII range
for i in range(32, 127):
char = chr(i)
if char not in self.token_to_idx:
self.token_to_idx[char] = idx
self.idx_to_token[idx] = char
idx += 1
# Common whitespace
        for char in ["\n", "\t", " "]:  # newline/tab for code; " " was already added with printable ASCII and is skipped below
if char not in self.token_to_idx:
self.token_to_idx[char] = idx
self.idx_to_token[idx] = char
idx += 1
self.base_vocab_size = idx
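        # (= 105 here: 8 special tokens + 95 printable ASCII characters, which
        # already include the space, + "\n" and "\t")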
def _get_stats(self, vocab: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, str], int]:
"""Count frequency of adjacent pairs"""
pairs = {}
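        # Illustrative example: for {("d","e","f"): 3, ("d","e","l"): 2} this
        # returns {("d","e"): 5, ("e","f"): 3, ("e","l"): 2}, so ("d","e")
        # would be the first pair merged by fit().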
for word_tuple, freq in vocab.items():
for i in range(len(word_tuple) - 1):
pair = (word_tuple[i], word_tuple[i+1])
pairs[pair] = pairs.get(pair, 0) + freq
return pairs
def _merge_vocab(self, pair: Tuple[str, str], vocab: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, ...], int]:
"""Merge all occurrences of pair in vocabulary"""
new_vocab = {}
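        # Illustrative example: merging ("d", "e") in {("d","e","f"): 3,
        # ("d","e","l"): 2} yields {("de","f"): 3, ("de","l"): 2}.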
        new_token = "".join(pair)
        for word, freq in vocab.items():
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and (word[i], word[i+1]) == pair:
new_word.append(new_token)
i += 2
else:
new_word.append(word[i])
i += 1
new_vocab[tuple(new_word)] = freq
return new_vocab
def fit(self, texts: List[str]):
"""Train BPE tokenizer on texts"""
        # Pre-tokenize into words so merges never cross word boundaries.
        # The regex keeps identifiers/numbers together and splits each
        # punctuation or operator character out on its own, which works
        # reasonably well for code; whitespace is dropped entirely during training.
        word_counts = {}
        for text in texts:
            words = re.findall(r'[a-zA-Z0-9_]+|[^\s\w]', text)
            for word in words:
                # Represent each word as a tuple of single characters
                token_tuple = tuple(word)
                word_counts[token_tuple] = word_counts.get(token_tuple, 0) + 1
# BPE training loop
vocab = word_counts
num_merges = self.vocab_size - self.base_vocab_size
print(f"Training BPE tokenizer (target vocab: {self.vocab_size})...")
for i in range(num_merges):
pairs = self._get_stats(vocab)
if not pairs:
break
# Find most frequent pair
best_pair = max(pairs, key=pairs.get)
# Stop if pair frequency is too low (e.g., 1)
if pairs[best_pair] < 2:
break
# Merge pair
vocab = self._merge_vocab(best_pair, vocab)
# Add new token to vocabulary
new_token = "".join(best_pair)
self.merges[best_pair] = new_token
idx = len(self.token_to_idx)
self.token_to_idx[new_token] = idx
self.idx_to_token[idx] = new_token
if (i + 1) % 100 == 0:
print(f"BPE merge {i+1}/{num_merges}: '{best_pair[0]}' + '{best_pair[1]}' -> '{new_token}'")
print(f"BPE training complete. Final vocab size: {len(self.token_to_idx)}")
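    # Continuing the illustrative example in _get_stats/_merge_vocab: after
    # "de", then "def" and "del" are learned, no adjacent pairs remain, so on
    # such a tiny vocabulary training stops well before num_merges iterations.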
def _tokenize_word(self, word: str) -> List[str]:
"""Tokenize a single word using learned merges"""
if word in self.token_to_idx:
return [word]
# Start with characters
tokens = list(word)
        # Apply merges iteratively.
        # Note: a full BPE implementation applies merges in the order they were
        # learned; here we use a simpler greedy left-to-right pass that merges
        # any adjacent pair whose concatenation is already in the vocabulary,
        # repeating until no further merge applies.
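        # Illustrative example, assuming fit() learned "de" and "def" and no
        # other merges applicable to this word:
        # "define" -> ["d","e","f","i","n","e"] -> ["de","f","i","n","e"]
        #          -> ["def","i","n","e"]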
while True:
merged = False
i = 0
new_tokens = []
while i < len(tokens) - 1:
pair = (tokens[i], tokens[i+1])
pair_str = "".join(pair)
# Check if this pair forms a known token
if pair_str in self.token_to_idx:
new_tokens.append(pair_str)
i += 2
merged = True
else:
new_tokens.append(tokens[i])
i += 1
if i < len(tokens):
new_tokens.append(tokens[i])
if not merged:
break
tokens = new_tokens
return tokens
def encode(self, text: str, max_length: Optional[int] = None) -> List[int]:
"""Encode text to token indices"""
        # Pre-tokenize like fit(), but also capture whitespace runs so that
        # decode() can reconstruct the original spacing; special tokens such as
        # <USER> appearing literally in the text are split like any other characters
        words = re.findall(r'[a-zA-Z0-9_]+|[^\s\w]|\s+', text)
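        # For example, this pattern turns "x = 1" into ['x', ' ', '=', ' ', '1'];
        # with max_length=8 the five resulting ids (all base characters) would
        # then be followed by three <PAD> ids.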
encoded = []
for word in words:
if word in self.token_to_idx:
encoded.append(self.token_to_idx[word])
else:
# Apply BPE
subwords = self._tokenize_word(word)
for sw in subwords:
encoded.append(self.token_to_idx.get(sw, self.token_to_idx["<UNK>"]))
# Truncate or Pad
if max_length:
if len(encoded) > max_length:
encoded = encoded[:max_length]
elif len(encoded) < max_length:
encoded += [self.token_to_idx["<PAD>"]] * (max_length - len(encoded))
return encoded
def decode(self, indices: List[int]) -> str:
"""Decode indices to text"""
tokens = []
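        # Note: any id that maps to <UNK> is dropped below, so characters that
        # were out of vocabulary at encode time do not survive an
        # encode/decode round trip.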
for idx in indices:
            # Drop padding/control tokens; domain markers like <CODE> and
            # <USER> are kept so downstream code can post-process them
if idx in self.idx_to_token:
token = self.idx_to_token[idx]
if token not in ["<PAD>", "<UNK>", "<START>", "<END>"]:
tokens.append(token)
return "".join(tokens)
def save(self, path: str):
"""Save tokenizer"""
data = {
'vocab_size': self.vocab_size,
'token_to_idx': self.token_to_idx,
'idx_to_token': {str(k): v for k, v in self.idx_to_token.items()},
'base_vocab_size': self.base_vocab_size,
'merges': {f"{p[0]}|{p[1]}": m for p, m in self.merges.items()}
}
with open(path, 'w') as f:
json.dump(data, f, indent=2)
def load(self, path: str):
"""Load tokenizer"""
with open(path, 'r') as f:
data = json.load(f)
self.vocab_size = data['vocab_size']
self.token_to_idx = data['token_to_idx']
self.idx_to_token = {int(k): v for k, v in data['idx_to_token'].items()}
        # Fall back to the value computed in __init__ (105 for the current base
        # vocabulary) if the saved file does not include this field
        self.base_vocab_size = data.get('base_vocab_size', self.base_vocab_size)
# Load merges
if 'merges' in data:
self.merges = {}
for k, v in data['merges'].items():
p = k.split('|')
if len(p) == 2:
self.merges[(p[0], p[1])] = v
@property
def vocabulary_size(self) -> int:
return len(self.token_to_idx)
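
if __name__ == "__main__":
    # Minimal usage sketch: the corpus, vocab size, and file name below are
    # illustrative placeholders, not values used elsewhere in the project.
    corpus = [
        "def add(a, b):\n    return a + b",
        "def sub(a, b):\n    return a - b",
        "print(add(1, 2))",
    ]
    tokenizer = VedaTokenizer(vocab_size=300)
    tokenizer.fit(corpus)
    ids = tokenizer.encode("def add(1, 2)", max_length=32)
    print("encoded:", ids)
    print("decoded:", repr(tokenizer.decode(ids)))
    # Check that save/load round-trips the vocabulary
    tokenizer.save("veda_tokenizer.json")
    restored = VedaTokenizer(vocab_size=300)
    restored.load("veda_tokenizer.json")
    assert restored.decode(ids) == tokenizer.decode(ids)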