"""Subword Tokenizer (BPE-like) for Veda Programming Assistant""" import json import re from typing import List, Dict, Optional, Tuple class VedaTokenizer: """ Subword tokenizer that learns common subwords/phrases. Better than word-level or char-level tokenization. """ def __init__(self, vocab_size: int = 8000): self.vocab_size = vocab_size self.token_to_idx: Dict[str, int] = {} self.idx_to_token: Dict[int, str] = {} # Base vocabulary (special tokens + ASCII) self._init_base_vocab() # Merges for subwords (pair -> new_token) self.merges: Dict[Tuple[str, str], str] = {} def _init_base_vocab(self): """Initialize base vocabulary""" special = [ "", "", "", "", "", "", "", "" ] for idx, token in enumerate(special): self.token_to_idx[token] = idx self.idx_to_token[idx] = token # ASCII characters as base tokens idx = len(special) # Printable ASCII range for i in range(32, 127): char = chr(i) if char not in self.token_to_idx: self.token_to_idx[char] = idx self.idx_to_token[idx] = char idx += 1 # Common whitespace for char in ["\n", "\t", " "]: # spaces for indentation if char not in self.token_to_idx: self.token_to_idx[char] = idx self.idx_to_token[idx] = char idx += 1 self.base_vocab_size = idx def _get_stats(self, vocab: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, str], int]: """Count frequency of adjacent pairs""" pairs = {} for word_tuple, freq in vocab.items(): for i in range(len(word_tuple) - 1): pair = (word_tuple[i], word_tuple[i+1]) pairs[pair] = pairs.get(pair, 0) + freq return pairs def _merge_vocab(self, pair: Tuple[str, str], vocab: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, ...], int]: """Merge all occurrences of pair in vocabulary""" new_vocab = {} bigram = pair new_token = "".join(pair) for word, freq in vocab.items(): new_word = [] i = 0 while i < len(word): if i < len(word) - 1 and word[i] == bigram[0] and word[i+1] == bigram[1]: new_word.append(new_token) i += 2 else: new_word.append(word[i]) i += 1 new_vocab[tuple(new_word)] = freq return new_vocab def fit(self, texts: List[str]): """Train BPE tokenizer on texts""" # Pre-tokenize into words to avoid merging across word boundaries # This regex splits by whitespace but keeps punctuation # Also handles code symbols better word_counts = {} for text in texts: # Simple pre-tokenization for code words = re.findall(r'[a-zA-Z0-9_]+|[^\s\w]', text) for word in words: # Convert word to tuple of characters token_tuple = tuple(c for c in word) word_counts[token_tuple] = word_counts.get(token_tuple, 0) + 1 # BPE training loop vocab = word_counts num_merges = self.vocab_size - self.base_vocab_size print(f"Training BPE tokenizer (target vocab: {self.vocab_size})...") for i in range(num_merges): pairs = self._get_stats(vocab) if not pairs: break # Find most frequent pair best_pair = max(pairs, key=pairs.get) # Stop if pair frequency is too low (e.g., 1) if pairs[best_pair] < 2: break # Merge pair vocab = self._merge_vocab(best_pair, vocab) # Add new token to vocabulary new_token = "".join(best_pair) self.merges[best_pair] = new_token idx = len(self.token_to_idx) self.token_to_idx[new_token] = idx self.idx_to_token[idx] = new_token if (i + 1) % 100 == 0: print(f"BPE merge {i+1}/{num_merges}: '{best_pair[0]}' + '{best_pair[1]}' -> '{new_token}'") print(f"BPE training complete. 
    def _tokenize_word(self, word: str) -> List[str]:
        """Tokenize a single word using learned merges"""
        if word in self.token_to_idx:
            return [word]

        # Start with characters
        tokens = list(word)

        # Apply merges iteratively.
        # Note: a real BPE implementation would apply merges in the order they
        # were learned; here we do a simpler greedy left-to-right pass that
        # joins any adjacent pair already present in the vocabulary.
        while True:
            merged = False
            i = 0
            new_tokens = []
            while i < len(tokens) - 1:
                pair = (tokens[i], tokens[i + 1])
                pair_str = "".join(pair)
                # Check if this pair forms a known token
                if pair_str in self.token_to_idx:
                    new_tokens.append(pair_str)
                    i += 2
                    merged = True
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            if i < len(tokens):
                new_tokens.append(tokens[i])
            if not merged:
                break
            tokens = new_tokens

        return tokens

    def encode(self, text: str, max_length: Optional[int] = None) -> List[int]:
        """Encode text to token indices"""
        # Pre-tokenize the same way as training, but also keep whitespace runs
        words = re.findall(r'[a-zA-Z0-9_]+|[^\s\w]|\s+', text)

        encoded = []
        for word in words:
            if word in self.token_to_idx:
                encoded.append(self.token_to_idx[word])
            else:
                # Apply BPE; fall back to <UNK> for unknown subwords
                subwords = self._tokenize_word(word)
                for sw in subwords:
                    encoded.append(self.token_to_idx.get(sw, self.token_to_idx["<UNK>"]))

        # Truncate or pad
        if max_length:
            if len(encoded) > max_length:
                encoded = encoded[:max_length]
            elif len(encoded) < max_length:
                encoded += [self.token_to_idx["<PAD>"]] * (max_length - len(encoded))

        return encoded

    def decode(self, indices: List[int]) -> str:
        """Decode indices to text"""
        tokens = []
        for idx in indices:
            # Skip special tokens; post-processing can handle any further cleanup
            if idx in self.idx_to_token:
                token = self.idx_to_token[idx]
                if token not in ["<PAD>", "<UNK>", "<BOS>", "<EOS>"]:
                    tokens.append(token)
        return "".join(tokens)

    def save(self, path: str):
        """Save tokenizer"""
        data = {
            'vocab_size': self.vocab_size,
            'token_to_idx': self.token_to_idx,
            'idx_to_token': {str(k): v for k, v in self.idx_to_token.items()},
            'base_vocab_size': self.base_vocab_size,
            'merges': {f"{p[0]}|{p[1]}": m for p, m in self.merges.items()}
        }
        with open(path, 'w') as f:
            json.dump(data, f, indent=2)

    def load(self, path: str):
        """Load tokenizer"""
        with open(path, 'r') as f:
            data = json.load(f)

        self.vocab_size = data['vocab_size']
        self.token_to_idx = data['token_to_idx']
        self.idx_to_token = {int(k): v for k, v in data['idx_to_token'].items()}
        self.base_vocab_size = data.get('base_vocab_size', 100)

        # Load merges
        if 'merges' in data:
            self.merges = {}
            for k, v in data['merges'].items():
                p = k.split('|')
                if len(p) == 2:
                    self.merges[(p[0], p[1])] = v

    @property
    def vocabulary_size(self) -> int:
        return len(self.token_to_idx)
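

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the class above): trains the
# tokenizer on a tiny assumed corpus, encodes/decodes a line of code, and
# round-trips the vocabulary through JSON. The corpus, the vocab_size of 300,
# and the "veda_tokenizer.json" path are assumptions chosen for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    corpus = [
        "def add(a, b):\n    return a + b",
        "def multiply(a, b):\n    return a * b",
        "print(add(2, 3))",
    ]

    tokenizer = VedaTokenizer(vocab_size=300)  # small target vocab for a tiny corpus
    tokenizer.fit(corpus)

    ids = tokenizer.encode("def add(a, b):", max_length=32)
    print("encoded:", ids)
    print("decoded:", tokenizer.decode(ids))  # padding tokens are skipped on decode

    # Save and reload, then check the vocabulary survived the round trip.
    tokenizer.save("veda_tokenizer.json")
    restored = VedaTokenizer()
    restored.load("veda_tokenizer.json")
    assert restored.vocabulary_size == tokenizer.vocabulary_size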