| """Subword Tokenizer (BPE-like) for Veda Programming Assistant""" | |
| import json | |
| import re | |
| from typing import List, Dict, Optional, Tuple | |
| class VedaTokenizer: | |
| """ | |
| Subword tokenizer that learns common subwords/phrases. | |
| Better than word-level or char-level tokenization. | |
| """ | |

    def __init__(self, vocab_size: int = 8000):
        self.vocab_size = vocab_size
        self.token_to_idx: Dict[str, int] = {}
        self.idx_to_token: Dict[int, str] = {}
        # Base vocabulary (special tokens + ASCII)
        self._init_base_vocab()
        # Learned merges for subwords (pair -> new_token)
        self.merges: Dict[Tuple[str, str], str] = {}

    def _init_base_vocab(self):
        """Initialize base vocabulary"""
        special = [
            "<PAD>", "<UNK>", "<START>", "<END>",
            "<CODE>", "<ENDCODE>",
            "<USER>", "<ASSISTANT>"
        ]
        for idx, token in enumerate(special):
            self.token_to_idx[token] = idx
            self.idx_to_token[idx] = token

        # ASCII characters as base tokens
        idx = len(special)
        # Printable ASCII range
        for i in range(32, 127):
            char = chr(i)
            if char not in self.token_to_idx:
                self.token_to_idx[char] = idx
                self.idx_to_token[idx] = char
                idx += 1

        # Common whitespace
        for char in ["\n", "\t", " "]:  # spaces for indentation
            if char not in self.token_to_idx:
                self.token_to_idx[char] = idx
                self.idx_to_token[idx] = char
                idx += 1

        self.base_vocab_size = idx

    def _get_stats(self, vocab: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, str], int]:
        """Count frequency of adjacent pairs"""
        pairs = {}
        for word_tuple, freq in vocab.items():
            for i in range(len(word_tuple) - 1):
                pair = (word_tuple[i], word_tuple[i + 1])
                pairs[pair] = pairs.get(pair, 0) + freq
        return pairs
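
    # Worked example (illustrative values): for a vocab of
    # {('l','o','w'): 5, ('l','o','w','e','r'): 2}, _get_stats returns
    # {('l','o'): 7, ('o','w'): 7, ('w','e'): 2, ('e','r'): 2} --
    # each adjacent pair is counted once per word, weighted by the word's frequency.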

    def _merge_vocab(self, pair: Tuple[str, str], vocab: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, ...], int]:
        """Merge all occurrences of pair in vocabulary"""
        new_vocab = {}
        bigram = pair
        new_token = "".join(pair)
        for word, freq in vocab.items():
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and word[i] == bigram[0] and word[i + 1] == bigram[1]:
                    new_word.append(new_token)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_vocab[tuple(new_word)] = freq
        return new_vocab
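
    # Worked example (illustrative values): merging the pair ('l', 'o') turns
    # {('l','o','w'): 5, ('l','o','w','e','r'): 2} into
    # {('lo','w'): 5, ('lo','w','e','r'): 2}; the merged symbol 'lo' becomes a
    # vocabulary entry and can itself take part in later merges, e.g. ('lo','w') -> 'low'.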

    def fit(self, texts: List[str]):
        """Train BPE tokenizer on texts"""
        # Pre-tokenize into words to avoid merging across word boundaries.
        # This regex splits by whitespace but keeps punctuation,
        # which also handles code symbols better.
        word_counts = {}
        for text in texts:
            # Simple pre-tokenization for code
            words = re.findall(r'[a-zA-Z0-9_]+|[^\s\w]', text)
            for word in words:
                # Convert word to tuple of characters
                token_tuple = tuple(c for c in word)
                word_counts[token_tuple] = word_counts.get(token_tuple, 0) + 1

        # BPE training loop
        vocab = word_counts
        num_merges = self.vocab_size - self.base_vocab_size
        print(f"Training BPE tokenizer (target vocab: {self.vocab_size})...")

        for i in range(num_merges):
            pairs = self._get_stats(vocab)
            if not pairs:
                break

            # Find most frequent pair
            best_pair = max(pairs, key=pairs.get)

            # Stop if pair frequency is too low (e.g., 1)
            if pairs[best_pair] < 2:
                break

            # Merge pair
            vocab = self._merge_vocab(best_pair, vocab)

            # Add new token to vocabulary
            new_token = "".join(best_pair)
            self.merges[best_pair] = new_token
            idx = len(self.token_to_idx)
            self.token_to_idx[new_token] = idx
            self.idx_to_token[idx] = new_token

            if (i + 1) % 100 == 0:
                print(f"BPE merge {i + 1}/{num_merges}: '{best_pair[0]}' + '{best_pair[1]}' -> '{new_token}'")

        print(f"BPE training complete. Final vocab size: {len(self.token_to_idx)}")

    def _tokenize_word(self, word: str) -> List[str]:
        """Tokenize a single word using learned merges"""
        if word in self.token_to_idx:
            return [word]

        # Start with characters
        tokens = list(word)

        # Apply merges iteratively.
        # Note: a standard BPE encoder applies merges in the order they were learned;
        # here we use a simpler greedy left-to-right pass that joins any adjacent
        # pair whose concatenation is already a known token, repeating until no
        # pair can be merged.
        while True:
            merged = False
            i = 0
            new_tokens = []
            while i < len(tokens) - 1:
                pair = (tokens[i], tokens[i + 1])
                pair_str = "".join(pair)
                # Check if this pair forms a known token
                if pair_str in self.token_to_idx:
                    new_tokens.append(pair_str)
                    i += 2
                    merged = True
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            # Keep the trailing token if the loop stopped one short of the end
            if i < len(tokens):
                new_tokens.append(tokens[i])
            if not merged:
                break
            tokens = new_tokens
        return tokens
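
    # Worked example (illustrative): if training learned the merges 'lo' and 'low'
    # (and no other merges apply to this word), _tokenize_word("lowest") goes
    # ['l','o','w','e','s','t'] -> ['lo','w','e','s','t'] -> ['low','e','s','t'],
    # stopping once no adjacent pair forms a known token.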

    def encode(self, text: str, max_length: Optional[int] = None) -> List[int]:
        """Encode text to token indices"""
        # Pre-tokenize as in training, but also keep whitespace runs
        # so spacing survives a round trip through decode()
        words = re.findall(r'[a-zA-Z0-9_]+|[^\s\w]|\s+', text)

        encoded = []
        for word in words:
            if word in self.token_to_idx:
                encoded.append(self.token_to_idx[word])
            else:
                # Apply BPE
                subwords = self._tokenize_word(word)
                for sw in subwords:
                    encoded.append(self.token_to_idx.get(sw, self.token_to_idx["<UNK>"]))

        # Truncate or pad to max_length
        if max_length:
            if len(encoded) > max_length:
                encoded = encoded[:max_length]
            elif len(encoded) < max_length:
                encoded += [self.token_to_idx["<PAD>"]] * (max_length - len(encoded))
        return encoded
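
    # Illustrative behaviour of max_length: encode("x = 1", max_length=8) yields
    # the token ids for "x", " ", "=", " ", "1" followed by three <PAD> ids,
    # while max_length=3 would truncate to the first three ids.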

    def decode(self, indices: List[int]) -> str:
        """Decode indices to text"""
        tokens = []
        for idx in indices:
            # Drop padding/control tokens; keep everything else
            # (including <CODE>/<ENDCODE> markers) for post-processing
            if idx in self.idx_to_token:
                token = self.idx_to_token[idx]
                if token not in ["<PAD>", "<UNK>", "<START>", "<END>"]:
                    tokens.append(token)
        return "".join(tokens)

    def save(self, path: str):
        """Save tokenizer"""
        data = {
            'vocab_size': self.vocab_size,
            'token_to_idx': self.token_to_idx,
            'idx_to_token': {str(k): v for k, v in self.idx_to_token.items()},
            'base_vocab_size': self.base_vocab_size,
            'merges': {f"{p[0]}|{p[1]}": m for p, m in self.merges.items()}
        }
        with open(path, 'w') as f:
            json.dump(data, f, indent=2)

    def load(self, path: str):
        """Load tokenizer"""
        with open(path, 'r') as f:
            data = json.load(f)
        self.vocab_size = data['vocab_size']
        self.token_to_idx = data['token_to_idx']
        self.idx_to_token = {int(k): v for k, v in data['idx_to_token'].items()}
        # Fallback matches _init_base_vocab: 8 special tokens + 95 printable ASCII + "\n" and "\t"
        self.base_vocab_size = data.get('base_vocab_size', 105)

        # Load merges
        if 'merges' in data:
            self.merges = {}
            for k, v in data['merges'].items():
                p = k.split('|')
                if len(p) == 2:
                    self.merges[(p[0], p[1])] = v

    def vocabulary_size(self) -> int:
        return len(self.token_to_idx)
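

# Minimal usage sketch (illustrative only, not part of the assistant's training
# pipeline): train on a tiny throwaway corpus, round-trip a string, and persist
# the tokenizer. The corpus, vocab size, and file name below are made up for
# demonstration.
if __name__ == "__main__":
    sample_corpus = [
        "def add(a, b):\n    return a + b",
        "def sub(a, b):\n    return a - b",
        "print(add(1, 2))",
    ]

    tokenizer = VedaTokenizer(vocab_size=300)
    tokenizer.fit(sample_corpus)

    ids = tokenizer.encode("def add(a, b):", max_length=32)
    print("encoded:", ids)
    print("decoded:", tokenizer.decode(ids))

    # Save and reload to verify persistence
    tokenizer.save("veda_tokenizer.json")
    restored = VedaTokenizer()
    restored.load("veda_tokenizer.json")
    print("restored vocab size:", restored.vocabulary_size())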