from torch.nn import functional as F
# Alias so the Dataset class below can subclass it without shadowing the name.
from torch.utils.data import Dataset as TorchDataset
import random
import torch
import re
# Vocabulary: the ten digits, newline as a record separator, and a run of
# twelve zeros used as a start-of-record marker (a single token, id 11).
stoi = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '\n': 10, '000000000000': 11}
itos = {i: s for s, i in stoi.items()}

# Longest match first: the twelve-zero marker, then a single digit, then newline.
tok_chars = re.compile(r'000000000000|\d|\n')

def encode(text, stoi, tokenizer):
    # Split the text with the regex and map tokens to ids; anything the
    # vocabulary does not cover is silently dropped.
    matches = tokenizer.findall(text)
    return [stoi[c] for c in matches if c in stoi]

def decode(encoded, itos):
    return ''.join([itos[i] for i in encoded])
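
# A minimal round-trip sketch (assumption: the input uses only vocabulary
# characters, so nothing is dropped; the leading twelve zeros collapse into
# the single start-marker token):
def _demo_encode_decode():
    sample = '000000000000' + '0750968\n'
    ids = encode(sample, stoi, tok_chars)   # [11, 0, 7, 5, 0, 9, 6, 8, 10]
    assert decode(ids, itos) == sample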

class Dataset(TorchDataset):
    def __init__(self, data, ctx_len, epoch_length_fixed, time_aug=True):
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
        self.start_token = '000000000000'
        self.tokenizer = tok_chars
        self.stoi = stoi
        self.itos = itos
        self.vocab_size = len(stoi)
        print('vocab size:', self.vocab_size)
        self.data = encode(data, self.stoi, self.tokenizer)
        self.data_size = len(self.data)
        print(f'data has {self.data_size} tokens')

    def __len__(self):
        # The epoch length is fixed up front; __getitem__ samples random
        # positions, so the index passed by the DataLoader is ignored.
        return self.epoch_length_fixed
    def __getitem__(self, idx):
        # Pick a random position, then scan forward (wrapping around) to
        # the next start-of-record marker so every sample begins on a
        # record boundary. The scan is bounded by the data length so that
        # marker-free data returns None instead of looping forever.
        cues = []
        idx_randm = random.randint(0, len(self.data) - self.ctx_len * 4)
        i = idx_randm
        for _ in range(len(self.data)):
            if self.data[i] == self.stoi[self.start_token]:
                cues = [i]
                break
            i = (i + 1) % len(self.data)
        if not cues:
            return None
        start_idx = cues[0]
        dix = self.data[start_idx : start_idx + self.ctx_len + 2]
        # 96-tick resolution: two cyclic offset patterns used for data
        # augmentation (transposition of the digit fields).
        time_shift = [
            [0, 0, 0, 0, 0, 7, 6, 8, 0, 7, 6, 8, 0],
            [0, 0, 0, 0, 1, 5, 3, 6, 1, 5, 3, 6, 0],
        ]
        data_aug = random.choice([True, False])
        t = dix[2:2 + self.ctx_len]  # drop the start marker and the token after it
        if data_aug:
            # Tile the chosen pattern so it covers the whole context; keep
            # it as a plain list so the entries of t stay Python ints.
            ts_rndm = random.choice(time_shift)
            ts = ts_rndm * ((self.ctx_len - 1) // len(ts_rndm) + 1)
            tsx = ts[:self.ctx_len]
            for j in reversed(range(len(t))):
                # Only digit fields at positions 2..11 of each 13-token
                # record are shifted.
                if j % 13 not in range(2, 12):
                    continue
                aug_int = t[j] + tsx[j]
                if aug_int >= 10 and (aug_int not in [10, 11] or j not in [9, 10]):
                    # Decimal carry: keep the ones digit here and add the
                    # tens digit to the token on the left. Shifted values
                    # of 10 or 11 at absolute positions 9 and 10 are stored
                    # directly instead of being split.
                    left_int = aug_int // 10
                    right_int = aug_int % 10
                    if j > 0:
                        t[j - 1] += left_int
                    t[j] = right_int
                else:
                    t[j] = aug_int
            x = t
            # Next-token targets; the final target repeats the last token.
            y = t[1:] + [t[-1]]
        else:
            x = dix[:-1][:self.ctx_len]
            y = dix[1:][:self.ctx_len]
        x = torch.tensor(x, dtype=torch.int64)
        y = torch.tensor(y, dtype=torch.int64)
        return x, y
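
# A minimal usage sketch. The corpus layout is an assumption for the demo:
# each record is the start marker followed by ten digits and a newline
# (12 tokens), repeated enough times for the random-start bound to hold.
def _demo_dataset():
    corpus = ('000000000000' + '0123456789\n') * 200
    ds = Dataset(corpus, ctx_len=64, epoch_length_fixed=1000)
    x, y = ds[0]  # the index is ignored; a random record is sampled
    print(x.shape, y.shape)  # torch.Size([64]) torch.Size([64])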

class TOKENIZER:
    def __init__(self):
        self.tokenizer = tok_chars
        self.stoi = stoi
        self.itos = itos
        self.vocab_size = len(self.stoi)

    def encode(self, text):
        matches = self.tokenizer.findall(text)
        return [self.stoi[c] for c in matches if c in self.stoi]

    def decode(self, encoded):
        return ''.join([self.itos[i] for i in encoded])
    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_k=50):
        # Top-k sampling with a temperature exponent. `x` and `ctx_len`
        # are unused but kept for interface compatibility.
        probs = F.softmax(torch.tensor(out), dim=-1)
        if top_k > 0:
            top_k = min(top_k, probs.size(-1))
            # Zero every probability outside the k most likely tokens.
            sorted_probs, sorted_indices = torch.topk(probs, top_k)
            probs.fill_(0)
            probs.scatter_(dim=-1, index=sorted_indices, src=sorted_probs)
        if temperature != 1.0:
            # torch.multinomial accepts unnormalized weights, so the
            # tempered distribution is not renormalized here.
            probs = probs.pow(1.0 / temperature)
        return torch.multinomial(probs, num_samples=1)[0]
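
# A minimal sampling sketch with dummy logits (an assumption for the demo;
# in use, `out` would be a model's logit vector over the 12-token vocab):
def _demo_sampling():
    tok = TOKENIZER()
    logits = torch.randn(tok.vocab_size).tolist()
    next_id = tok.sample_logits(logits, x=None, ctx_len=None,
                                temperature=0.9, top_k=5)
    print(tok.decode([int(next_id)]))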