"""
LookingGlass - A DNA Language Model

Pure PyTorch implementation of LookingGlass, a pretrained language model for DNA sequences.
Based on AWD-LSTM architecture, originally trained with fastai v1.

Paper: Hoarfrost et al., "Deep learning of a bacterial and archaeal universal language
of life enables transfer learning and illuminates microbial dark matter",
Nature Communications, 2022.

Usage:
    from lookingglass import LookingGlass, LookingGlassTokenizer

    # Load from HuggingFace Hub
    model = LookingGlass.from_pretrained('HoarfrostLab/lookingglass-v1')
    tokenizer = LookingGlassTokenizer()

    # Or load from local path
    model = LookingGlass.from_pretrained('./lookingglass-v1')

    inputs = tokenizer(["GATTACA", "ATCGATCG"], return_tensors=True)
    embeddings = model.get_embeddings(inputs['input_ids'])  # (batch, 104)
"""
| |
|
| | import json |
| | import os |
| | import warnings |
| | from dataclasses import dataclass, asdict |
| | from typing import Optional, Tuple, List, Dict, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | try: |
| | from huggingface_hub import hf_hub_download |
| | HF_HUB_AVAILABLE = True |
| | except ImportError: |
| | HF_HUB_AVAILABLE = False |
| |
|
| |
|
| | __version__ = "1.1.0" |
| |
|
| |
|
| | def _is_hf_hub_id(path: str) -> bool: |
| | """Check if path looks like a HuggingFace Hub model ID (e.g., 'user/model').""" |
| | if os.path.exists(path): |
| | return False |
| | return '/' in path and not path.startswith(('.', '/')) |
| |
|
| |
|
def _download_from_hub(repo_id: str, filename: str) -> str:
    """Fetch one file of a HuggingFace Hub repository.

    Args:
        repo_id: Hub repository identifier, e.g. ``'user/model'``.
        filename: Name of the file inside the repository.

    Returns:
        The local path reported by ``hf_hub_download``.

    Raises:
        ImportError: If the optional ``huggingface_hub`` package could not be
            imported when this module was loaded.
    """
    if HF_HUB_AVAILABLE:
        return hf_hub_download(repo_id=repo_id, filename=filename)
    raise ImportError(
        "huggingface_hub is required to load models from the Hub. "
        "Install it with: pip install huggingface_hub"
    )


# Public API of this module.
__all__ = [
    "LookingGlassConfig",
    "LookingGlass",
    "LookingGlassLM",
    "LookingGlassTokenizer",
]
| |
|
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class LookingGlassConfig:
    """
    Configuration for LookingGlass model.

    Default values match the original pretrained LookingGlass model.
    The configuration is serialised to / restored from ``config.json``.
    """
    # Number of rows of the token embedding / size of the LM output layer.
    vocab_size: int = 8
    # Embedding width; also the output width of the last LSTM layer.
    hidden_size: int = 104
    # Output width of every LSTM layer except the last one.
    intermediate_size: int = 1152
    # Number of stacked LSTM layers.
    num_hidden_layers: int = 3
    pad_token_id: int = 1
    bos_token_id: int = 2
    eos_token_id: int = 3
    bidirectional: bool = False
    # Dropout probabilities used by the encoder and the LM head.
    output_dropout: float = 0.1
    hidden_dropout: float = 0.15
    input_dropout: float = 0.25
    embed_dropout: float = 0.02
    weight_dropout: float = 0.2
    # Share the decoder weight with the token embedding in the LM head.
    tie_weights: bool = True
    output_bias: bool = True
    model_type: str = "lookingglass"

    def to_dict(self) -> Dict:
        """Return all configuration fields as a plain dictionary."""
        return asdict(self)

    def save_pretrained(self, save_directory: str):
        """Write the configuration to ``<save_directory>/config.json``.

        The directory is created if it does not exist yet.
        """
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, "config.json"), 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_path: str) -> "LookingGlassConfig":
        """Load a configuration from the Hub, a directory, or a JSON file.

        Args:
            pretrained_path: A Hub repo ID, a directory containing
                ``config.json``, or the path of a JSON file itself.

        Returns:
            The loaded configuration. Keys in the JSON file that are not
            dataclass fields are ignored. If no configuration file can be
            found, the default configuration is returned.
        """
        if _is_hf_hub_id(pretrained_path):
            try:
                config_path = _download_from_hub(pretrained_path, "config.json")
            except Exception as err:
                # Best-effort fallback, but no longer silent: a wrong repo ID or
                # a network failure would otherwise go unnoticed.
                warnings.warn(
                    f"Could not fetch config.json from '{pretrained_path}' "
                    f"({err!r}); falling back to the default configuration.",
                    stacklevel=2,
                )
                return cls()
        elif os.path.isdir(pretrained_path):
            config_path = os.path.join(pretrained_path, "config.json")
        else:
            config_path = pretrained_path

        if not os.path.exists(config_path):
            return cls()

        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
        # Drop unknown keys so files written by other tools do not break loading.
        known_fields = set(cls.__dataclass_fields__)
        return cls(**{k: v for k, v in config_dict.items() if k in known_fields})
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Token inventory: four special tokens followed by the four nucleotides.
# A token's ID is its position in this list.
VOCAB = ['xxunk', 'xxpad', 'xxbos', 'xxeos', 'G', 'A', 'C', 'T']
VOCAB_TO_ID = {tok: i for i, tok in enumerate(VOCAB)}
ID_TO_VOCAB = {i: tok for i, tok in enumerate(VOCAB)}


class LookingGlassTokenizer:
    """
    Tokenizer for DNA sequences.

    Each nucleotide (G, A, C, T) is a single token. By default, adds BOS token
    at the start of each sequence (matching original LookingGlass training).

    Special tokens:
    - xxunk (0): Unknown
    - xxpad (1): Padding
    - xxbos (2): Beginning of sequence
    - xxeos (3): End of sequence
    """

    vocab = VOCAB
    vocab_to_id = VOCAB_TO_ID
    id_to_vocab = ID_TO_VOCAB

    def __init__(
        self,
        add_bos_token: bool = True,
        add_eos_token: bool = False,
        padding_side: str = "right",
    ):
        """
        Args:
            add_bos_token: Prepend the BOS token when encoding.
            add_eos_token: Append the EOS token when encoding.
            padding_side: ``"right"`` pads after the sequence; any other value
                pads before it.
        """
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.padding_side = padding_side

        self.unk_token_id = 0
        self.pad_token_id = 1
        self.bos_token_id = 2
        self.eos_token_id = 3

    @property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self.vocab)

    def encode(self, sequence: str, add_special_tokens: bool = True) -> List[int]:
        """Encode a DNA sequence to token IDs.

        The sequence is upper-cased first. Whitespace characters are skipped;
        every other character outside the vocabulary maps to the unknown token.

        Args:
            sequence: Nucleotide string.
            add_special_tokens: When False, BOS/EOS are never added regardless
                of the tokenizer settings.

        Returns:
            List of token IDs.
        """
        tokens = []

        if add_special_tokens and self.add_bos_token:
            tokens.append(self.bos_token_id)

        for char in sequence.upper():
            if char in self.vocab_to_id:
                tokens.append(self.vocab_to_id[char])
            elif char.strip():
                # Non-whitespace character that is not a known nucleotide.
                tokens.append(self.unk_token_id)

        if add_special_tokens and self.add_eos_token:
            tokens.append(self.eos_token_id)

        return tokens

    def decode(self, token_ids: Union[List[int], torch.Tensor], skip_special_tokens: bool = True) -> str:
        """Decode token IDs back to DNA sequence.

        Args:
            token_ids: Flat list (or 1-D tensor) of token IDs.
            skip_special_tokens: Drop IDs 0-3 (unk, pad, bos, eos) from the
                output.

        Returns:
            The concatenated token strings; IDs outside the vocabulary are
            rendered as ``'xxunk'``.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        special_ids = {0, 1, 2, 3}
        tokens = []
        for tid in token_ids:
            if skip_special_tokens and tid in special_ids:
                continue
            tokens.append(self.id_to_vocab.get(tid, 'xxunk'))
        return ''.join(tokens)

    def __call__(
        self,
        sequences: Union[str, List[str]],
        padding: Union[bool, str] = False,
        max_length: Optional[int] = None,
        truncation: bool = False,
        return_tensors: Union[bool, str] = False,
        return_attention_mask: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Tokenize DNA sequence(s).

        Args:
            sequences: One sequence or a list of sequences.
            padding: Truthy to pad; ``'max_length'`` pads to ``max_length``
                when it is given, otherwise to the longest sequence. Batches of
                more than one sequence are always padded.
            max_length: Target length for truncation / ``'max_length'`` padding.
            truncation: Cut every encoded sequence to ``max_length`` tokens.
            return_tensors: ``True`` or ``'pt'`` returns ``torch.long`` tensors,
                anything else returns Python lists.
            return_attention_mask: Include ``'attention_mask'`` (1 for real
                tokens, 0 for padding) in the result.

        Returns:
            Dict with ``'input_ids'`` and optionally ``'attention_mask'``. For a
            single string input without tensors, the values are flat lists.
        """
        single = isinstance(sequences, str)
        if single:
            sequences = [sequences]

        encoded = [self.encode(seq) for seq in sequences]

        if truncation and max_length:
            encoded = [e[:max_length] for e in encoded]

        if padding or len(encoded) > 1:
            if padding == 'max_length' and max_length:
                pad_len = max_length
            else:
                # default=0 keeps an empty batch from raising in max().
                pad_len = max((len(e) for e in encoded), default=0)

            padded = []
            masks = []
            for ids in encoded:
                # Capture the real length BEFORE padding so the mask always has
                # exactly the same length as the padded ID list.
                n_real = len(ids)
                n_pad = max(pad_len - n_real, 0)
                pad = [self.pad_token_id] * n_pad
                if self.padding_side == 'right':
                    padded.append(ids + pad)
                    masks.append([1] * n_real + [0] * n_pad)
                else:
                    padded.append(pad + ids)
                    masks.append([0] * n_pad + [1] * n_real)
            encoded = padded
        else:
            masks = [[1] * len(e) for e in encoded]

        result = {}
        if return_tensors in ('pt', True):
            result['input_ids'] = torch.tensor(encoded, dtype=torch.long)
            if return_attention_mask:
                result['attention_mask'] = torch.tensor(masks, dtype=torch.long)
        else:
            result['input_ids'] = encoded[0] if single else encoded
            if return_attention_mask:
                result['attention_mask'] = masks[0] if single else masks

        return result

    def save_pretrained(self, save_directory: str):
        """Write ``vocab.json`` and ``tokenizer_config.json`` to a directory.

        The directory is created if it does not exist yet.
        """
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, "vocab.json"), 'w', encoding='utf-8') as f:
            json.dump(self.vocab_to_id, f, indent=2)
        with open(os.path.join(save_directory, "tokenizer_config.json"), 'w', encoding='utf-8') as f:
            json.dump({
                "add_bos_token": self.add_bos_token,
                "add_eos_token": self.add_eos_token,
                "padding_side": self.padding_side,
            }, f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_path: str) -> "LookingGlassTokenizer":
        """Create a tokenizer from a Hub repo ID or a local directory.

        Only ``tokenizer_config.json`` is read; its entries are passed to the
        constructor as keyword arguments. If the file is unavailable the
        default settings are used.
        """
        kwargs = {}
        if _is_hf_hub_id(pretrained_path):
            try:
                config_path = _download_from_hub(pretrained_path, "tokenizer_config.json")
                with open(config_path, 'r', encoding='utf-8') as f:
                    kwargs = json.load(f)
            except Exception as err:
                # Keep the best-effort fallback but make the failure visible.
                warnings.warn(
                    f"Could not load tokenizer_config.json from '{pretrained_path}' "
                    f"({err!r}); using default tokenizer settings.",
                    stacklevel=2,
                )
        else:
            config_path = os.path.join(pretrained_path, "tokenizer_config.json")
            if os.path.exists(config_path):
                with open(config_path, 'r', encoding='utf-8') as f:
                    kwargs = json.load(f)
        return cls(**kwargs)
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | def _dropout_mask(x: torch.Tensor, size: Tuple[int, ...], p: float) -> torch.Tensor: |
| | """Create dropout mask with inverted scaling.""" |
| | return x.new_empty(*size).bernoulli_(1 - p).div_(1 - p) |
| |
|
| |
|
| | class _RNNDropout(nn.Module): |
| | """Dropout consistent across sequence dimension.""" |
| |
|
| | def __init__(self, p: float = 0.5): |
| | super().__init__() |
| | self.p = p |
| |
|
| | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| | if not self.training or self.p == 0.: |
| | return x |
| | mask = _dropout_mask(x.data, (x.size(0), 1, x.size(2)), self.p) |
| | return x * mask |
| |
|
| |
|
| | class _EmbeddingDropout(nn.Module): |
| | """Dropout applied to entire embedding rows.""" |
| |
|
| | def __init__(self, embedding: nn.Embedding, p: float): |
| | super().__init__() |
| | self.embedding = embedding |
| | self.p = p |
| |
|
| | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| | if self.training and self.p != 0: |
| | mask = _dropout_mask(self.embedding.weight.data, |
| | (self.embedding.weight.size(0), 1), self.p) |
| | masked_weight = self.embedding.weight * mask |
| | else: |
| | masked_weight = self.embedding.weight |
| |
|
| | padding_idx = self.embedding.padding_idx if self.embedding.padding_idx is not None else -1 |
| | return F.embedding(x, masked_weight, padding_idx, |
| | self.embedding.max_norm, self.embedding.norm_type, |
| | self.embedding.scale_grad_by_freq, self.embedding.sparse) |
| |
|
| |
|
| | class _WeightDropout(nn.Module): |
| | """DropConnect applied to RNN hidden-to-hidden weights.""" |
| |
|
| | def __init__(self, module: nn.Module, p: float, layer_names='weight_hh_l0'): |
| | super().__init__() |
| | self.module = module |
| | self.p = p |
| | self.layer_names = [layer_names] if isinstance(layer_names, str) else layer_names |
| |
|
| | for layer in self.layer_names: |
| | w = getattr(self.module, layer) |
| | delattr(self.module, layer) |
| | self.register_parameter(f'{layer}_raw', nn.Parameter(w.data)) |
| | setattr(self.module, layer, w.clone()) |
| |
|
| | if isinstance(self.module, nn.RNNBase): |
| | self.module.flatten_parameters = lambda: None |
| |
|
| | def _set_weights(self): |
| | for layer in self.layer_names: |
| | raw_w = getattr(self, f'{layer}_raw') |
| | w = F.dropout(raw_w, p=self.p, training=self.training) if self.training else raw_w.clone() |
| | setattr(self.module, layer, w) |
| |
|
| | def forward(self, *args): |
| | self._set_weights() |
| | with warnings.catch_warnings(): |
| | warnings.simplefilter("ignore", category=UserWarning) |
| | return self.module(*args) |
| |
|
| |
|
class _AWDLSTMEncoder(nn.Module):
    """AWD-LSTM encoder backbone.

    Embeds token IDs and runs them through a stack of single-layer LSTMs whose
    hidden-to-hidden weights use DropConnect. The encoder is stateful: the
    final LSTM state of each call is kept (detached) and used as the initial
    state of the next call until :meth:`reset` is called.
    """

    _init_range = 0.1

    def __init__(self, config: LookingGlassConfig):
        """
        Args:
            config: Model configuration (sizes, dropout rates, pad token).
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_layers = config.num_hidden_layers
        self.num_directions = 2 if config.bidirectional else 1
        self._batch_size = 1

        # Token embedding, initialised uniformly in [-_init_range, _init_range].
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
                                         padding_idx=config.pad_token_id)
        self.embed_tokens.weight.data.uniform_(-self._init_range, self._init_range)
        self.embed_dropout = _EmbeddingDropout(self.embed_tokens, config.embed_dropout)

        # LSTM stack: the first layer consumes the embedding, the others consume
        # intermediate_size features.
        self.layers = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            input_size = config.hidden_size if i == 0 else config.intermediate_size
            lstm = nn.LSTM(input_size, self._layer_width(i), num_layers=1,
                           batch_first=True, bidirectional=config.bidirectional)
            self.layers.append(_WeightDropout(lstm, config.weight_dropout))

        # One hidden-dropout module per layer is registered; forward() skips
        # the one belonging to the last layer.
        self.input_dropout = _RNNDropout(config.input_dropout)
        self.hidden_dropout = nn.ModuleList([
            _RNNDropout(config.hidden_dropout) for _ in range(config.num_hidden_layers)
        ])

        self._hidden_state = None
        self.reset()

    def _layer_width(self, layer_idx: int) -> int:
        """Per-direction output width of the LSTM at ``layer_idx``."""
        is_last = layer_idx == self.num_layers - 1
        total = self.hidden_size if is_last else self.intermediate_size
        return total // self.num_directions

    def reset(self):
        """Reset LSTM hidden states to zeros at the current batch size."""
        self._hidden_state = [self._init_hidden(i) for i in range(self.num_layers)]

    def _init_hidden(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Zero ``(h, c)`` for one layer, matching the parameters' dtype/device."""
        nh = self._layer_width(layer_idx)
        weight = next(self.parameters())
        return (weight.new_zeros(self.num_directions, self._batch_size, nh),
                weight.new_zeros(self.num_directions, self._batch_size, nh))

    def _resize_hidden(self, batch_size: int):
        """Grow (zero-pad) or shrink (truncate) the cached state to ``batch_size``."""
        new_hidden = []
        for i in range(self.num_layers):
            nh = self._layer_width(i)
            h, c = self._hidden_state[i]

            if self._batch_size < batch_size:
                extra = batch_size - self._batch_size
                h = torch.cat([h, h.new_zeros(self.num_directions, extra, nh)], dim=1)
                c = torch.cat([c, c.new_zeros(self.num_directions, extra, nh)], dim=1)
            elif self._batch_size > batch_size:
                h = h[:, :batch_size].contiguous()
                c = c[:, :batch_size].contiguous()
            new_hidden.append((h, c))

        self._hidden_state = new_hidden
        self._batch_size = batch_size

    def _sync_hidden_state(self):
        """Move the cached state to the parameters' current device and dtype.

        The cached state is a plain Python list, so ``.to(...)``, ``.cuda()``,
        ``.half()`` or ``.double()`` on the module do not affect it. Without
        this step a stateful forward pass after moving the model fails with a
        device/dtype mismatch between the input and the hidden state.
        """
        if not self._hidden_state:
            return
        ref = next(self.parameters())
        h0 = self._hidden_state[0][0]
        if h0.device != ref.device or h0.dtype != ref.dtype:
            self._hidden_state = [
                (h.to(device=ref.device, dtype=ref.dtype),
                 c.to(device=ref.device, dtype=ref.dtype))
                for h, c in self._hidden_state
            ]

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """Returns hidden states for all positions: (batch, seq_len, hidden_size)

        Args:
            input_ids: Token indices of shape (batch, seq_len).
        """
        batch_size, _ = input_ids.shape

        self._sync_hidden_state()
        if batch_size != self._batch_size:
            self._resize_hidden(batch_size)

        hidden = self.input_dropout(self.embed_dropout(input_ids))

        new_hidden = []
        for i, (layer, hdp) in enumerate(zip(self.layers, self.hidden_dropout)):
            hidden, h = layer(hidden, self._hidden_state[i])
            new_hidden.append(h)
            if i != self.num_layers - 1:
                hidden = hdp(hidden)

        # Detach so gradients do not flow into earlier calls.
        self._hidden_state = [(h.detach(), c.detach()) for h, c in new_hidden]
        return hidden
| |
|
| |
|
class _LMHead(nn.Module):
    """Language modeling head: output dropout followed by a linear decoder."""

    _init_range = 0.1

    def __init__(self, config: LookingGlassConfig, embed_tokens: Optional[nn.Embedding] = None):
        """
        Args:
            config: Supplies hidden/vocab sizes, the output dropout rate and
                the bias / weight-tying switches.
            embed_tokens: Embedding whose weight is shared with the decoder
                when ``config.tie_weights`` is set; ignored otherwise.
        """
        super().__init__()
        self.output_dropout = _RNNDropout(config.output_dropout)

        decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.output_bias)
        decoder.weight.data.uniform_(-self._init_range, self._init_range)
        if config.output_bias:
            decoder.bias.data.zero_()
        self.decoder = decoder

        tie = config.tie_weights and embed_tokens is not None
        if tie:
            # Share one Parameter between the embedding and the decoder.
            self.decoder.weight = embed_tokens.weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states to vocabulary logits."""
        dropped = self.output_dropout(hidden_states)
        return self.decoder(dropped)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class LookingGlass(nn.Module):
    """
    LookingGlass encoder model.

    Outputs sequence embeddings for downstream tasks (classification, clustering, etc.).
    Uses last-token embedding by default, matching original LookingGlass.

    Example:
        >>> model = LookingGlass.from_pretrained('lookingglass-v1')
        >>> tokenizer = LookingGlassTokenizer()
        >>> inputs = tokenizer("GATTACA", return_tensors=True)
        >>> embeddings = model.get_embeddings(inputs['input_ids'])  # (1, 104)
    """

    config_class = LookingGlassConfig

    def __init__(self, config: Optional[LookingGlassConfig] = None):
        """
        Args:
            config: Model configuration; the default configuration is used
                when omitted.
        """
        super().__init__()
        self.config = config or LookingGlassConfig()
        self.encoder = _AWDLSTMEncoder(self.config)

    def reset(self):
        """Reset hidden states."""
        self.encoder.reset()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass. Returns last-token embeddings.

        Args:
            input_ids: Token indices (batch, seq_len)
            attention_mask: Optional (batch, seq_len) mask, non-zero for real
                tokens. When given, pooling uses the last real token of each
                row instead of the last position.
            **kwargs: Accepted and ignored.

        Returns:
            Embeddings (batch, hidden_size)
        """
        return self.get_embeddings(input_ids, attention_mask=attention_mask)

    @staticmethod
    def _pool_last_token(hidden: torch.Tensor,
                         attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Pick one hidden vector per sequence.

        Without a mask this is ``hidden[:, -1]``. With a mask it is the hidden
        state at the last position whose mask entry is non-zero, which works
        for both right- and left-padded batches.
        """
        if attention_mask is None:
            return hidden[:, -1]
        real = (torch.as_tensor(attention_mask, device=hidden.device) != 0).long()
        # Weight each real position by its 1-based index; the argmax is then
        # the last real position of the row.
        positions = torch.arange(1, hidden.size(1) + 1, device=hidden.device)
        last_idx = (real * positions).argmax(dim=1)
        rows = torch.arange(hidden.size(0), device=hidden.device)
        return hidden[rows, last_idx]

    def get_embeddings(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Get sequence embeddings using last-token pooling (original LG method).

        Resets hidden state before encoding for deterministic results.

        Args:
            input_ids: Token indices (batch, seq_len)
            attention_mask: Optional (batch, seq_len) mask, non-zero for real
                tokens. Pass it for padded batches; otherwise the embedding of
                a right-padded sequence is read at a padding position.

        Returns:
            Embeddings (batch, hidden_size)
        """
        self.encoder.reset()
        hidden = self.encoder(input_ids)
        return self._pool_last_token(hidden, attention_mask)

    def get_hidden_states(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Get hidden states for all positions.

        Resets hidden state before encoding for deterministic results.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Hidden states (batch, seq_len, hidden_size)
        """
        self.encoder.reset()
        return self.encoder(input_ids)

    def save_pretrained(self, save_directory: str):
        """Write ``config.json`` and ``pytorch_model.bin`` to ``save_directory``."""
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_path: str, config: Optional[LookingGlassConfig] = None) -> "LookingGlass":
        """Build a model and load ``pytorch_model.bin`` from the Hub or a directory.

        Args:
            pretrained_path: Hub repo ID or local directory.
            config: Optional configuration overriding the stored one.

        Returns:
            The model. Keys starting with ``lm_head.`` are dropped from the
            checkpoint and the remainder is loaded with ``strict=False``. If no
            weights file exists at the local path, a warning is issued and the
            model keeps its random initialisation.
        """
        config = config or LookingGlassConfig.from_pretrained(pretrained_path)
        model = cls(config)

        if _is_hf_hub_id(pretrained_path):
            model_path = _download_from_hub(pretrained_path, "pytorch_model.bin")
        else:
            model_path = os.path.join(pretrained_path, "pytorch_model.bin")

        if os.path.exists(model_path):
            # NOTE(security): torch.load unpickles the file. Only load
            # checkpoints from sources you trust.
            state_dict = torch.load(model_path, map_location='cpu')
            # A checkpoint saved from LookingGlassLM also carries the LM head.
            encoder_state_dict = {k: v for k, v in state_dict.items()
                                  if not k.startswith('lm_head.')}
            model.load_state_dict(encoder_state_dict, strict=False)
        else:
            warnings.warn(
                f"No weights file found at '{model_path}'; "
                "returning a randomly initialised model.",
                stacklevel=2,
            )

        return model
| |
|
| |
|
class LookingGlassLM(nn.Module):
    """
    LookingGlass with language modeling head.

    Full model for next-token prediction. Can also extract embeddings.

    Example:
        >>> model = LookingGlassLM.from_pretrained('lookingglass-v1')
        >>> tokenizer = LookingGlassTokenizer()
        >>> inputs = tokenizer("GATTACA", return_tensors=True)
        >>> logits = model(inputs['input_ids'])  # (1, 8, 8)
        >>> embeddings = model.get_embeddings(inputs['input_ids'])  # (1, 104)
    """

    config_class = LookingGlassConfig

    def __init__(self, config: Optional[LookingGlassConfig] = None):
        """
        Args:
            config: Model configuration; the default configuration is used
                when omitted.
        """
        super().__init__()
        self.config = config or LookingGlassConfig()
        self.encoder = _AWDLSTMEncoder(self.config)
        self.lm_head = _LMHead(
            self.config,
            embed_tokens=self.encoder.embed_tokens if self.config.tie_weights else None
        )

    def reset(self):
        """Reset hidden states."""
        self.encoder.reset()

    def forward(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor:
        """
        Forward pass. Returns logits for next-token prediction.

        Unlike :meth:`get_embeddings`, this does NOT reset the encoder, so the
        LSTM state carries over from the previous call; call :meth:`reset`
        between unrelated batches.

        Args:
            input_ids: Token indices (batch, seq_len)
            **kwargs: Accepted and ignored.

        Returns:
            Logits (batch, seq_len, vocab_size)
        """
        hidden = self.encoder(input_ids)
        return self.lm_head(hidden)

    @staticmethod
    def _pool_last_token(hidden: torch.Tensor,
                         attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Pick one hidden vector per sequence.

        Without a mask this is ``hidden[:, -1]``. With a mask it is the hidden
        state at the last position whose mask entry is non-zero, which works
        for both right- and left-padded batches.
        """
        if attention_mask is None:
            return hidden[:, -1]
        real = (torch.as_tensor(attention_mask, device=hidden.device) != 0).long()
        # Weight each real position by its 1-based index; the argmax is then
        # the last real position of the row.
        positions = torch.arange(1, hidden.size(1) + 1, device=hidden.device)
        last_idx = (real * positions).argmax(dim=1)
        rows = torch.arange(hidden.size(0), device=hidden.device)
        return hidden[rows, last_idx]

    def get_embeddings(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Get sequence embeddings using last-token pooling.

        Resets hidden state before encoding for deterministic results.

        Args:
            input_ids: Token indices (batch, seq_len)
            attention_mask: Optional (batch, seq_len) mask, non-zero for real
                tokens. Pass it for padded batches; otherwise the embedding of
                a right-padded sequence is read at a padding position.

        Returns:
            Embeddings (batch, hidden_size)
        """
        self.encoder.reset()
        hidden = self.encoder(input_ids)
        return self._pool_last_token(hidden, attention_mask)

    def get_hidden_states(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """
        Get hidden states for all positions.

        Resets hidden state before encoding for deterministic results.

        Args:
            input_ids: Token indices (batch, seq_len)

        Returns:
            Hidden states (batch, seq_len, hidden_size)
        """
        self.encoder.reset()
        return self.encoder(input_ids)

    def save_pretrained(self, save_directory: str):
        """Write ``config.json`` and ``pytorch_model.bin`` to ``save_directory``."""
        os.makedirs(save_directory, exist_ok=True)
        self.config.save_pretrained(save_directory)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_path: str, config: Optional[LookingGlassConfig] = None) -> "LookingGlassLM":
        """Build a model and load ``pytorch_model.bin`` from the Hub or a directory.

        Args:
            pretrained_path: Hub repo ID or local directory.
            config: Optional configuration overriding the stored one.

        Returns:
            The model, with the checkpoint loaded using ``strict=False``. If no
            weights file exists at the local path, a warning is issued and the
            model keeps its random initialisation.
        """
        config = config or LookingGlassConfig.from_pretrained(pretrained_path)
        model = cls(config)

        if _is_hf_hub_id(pretrained_path):
            model_path = _download_from_hub(pretrained_path, "pytorch_model.bin")
        else:
            model_path = os.path.join(pretrained_path, "pytorch_model.bin")

        if os.path.exists(model_path):
            # NOTE(security): torch.load unpickles the file. Only load
            # checkpoints from sources you trust.
            state_dict = torch.load(model_path, map_location='cpu')
            model.load_state_dict(state_dict, strict=False)
        else:
            warnings.warn(
                f"No weights file found at '{model_path}'; "
                "returning a randomly initialised model.",
                stacklevel=2,
            )

        return model
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Substring renames from the original checkpoint's encoder keys to this
# module's names, applied in this order.
_ENCODER_KEY_RENAMES = (
    ('encoder.', 'embed_tokens.'),
    ('encoder_dp.emb.', 'embed_tokens.'),
    ('rnns.', 'layers.'),
    ('hidden_dps.', 'hidden_dropout.'),
    ('input_dp.', 'input_dropout.'),
)


def _remap_encoder_key(key: str) -> str:
    """Translate one original encoder key into this module's ``encoder.*`` key."""
    for old, new in _ENCODER_KEY_RENAMES:
        key = key.replace(old, new)
    return 'encoder.' + key


def load_original_weights(model: Union[LookingGlass, LookingGlassLM], weights_path: str) -> None:
    """
    Load weights from original fastai-trained LookingGlass checkpoint.

    Keys prefixed with ``0.`` are treated as encoder weights and keys prefixed
    with ``1.`` as LM-head weights; keys without either prefix are treated as
    encoder weights. LM-head keys are only used when ``model`` is a
    ``LookingGlassLM``. The remapped weights are loaded with ``strict=False``,
    so keys that do not match the model are ignored.

    Args:
        model: Model to load weights into
        weights_path: Path to LookingGlass.pth or LookingGlass_enc.pth
    """
    # NOTE(security): torch.load unpickles the file. Only load checkpoints from
    # sources you trust.
    checkpoint = torch.load(weights_path, map_location='cpu')
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

    is_lm_model = isinstance(model, LookingGlassLM)

    new_state_dict = {}
    for k, v in state_dict.items():
        # The dropped-out copy of the recurrent weight is rebuilt from the
        # '*_raw' parameter on every forward pass, so it is not loaded.
        if '.module.weight_hh_l0' in k:
            continue

        if k.startswith('0.'):
            new_state_dict[_remap_encoder_key(k[2:])] = v
        elif k.startswith('1.'):
            # Decoder weights only make sense for the LM model; for the
            # encoder-only model they are skipped rather than renamed into
            # meaningless 'encoder.1.*' entries.
            if is_lm_model:
                head_key = k[2:].replace('output_dp.', 'output_dropout.')
                new_state_dict['lm_head.' + head_key] = v
        else:
            new_state_dict[_remap_encoder_key(k)] = v

    model.load_state_dict(new_state_dict, strict=False)
| |
|
| |
|
def convert_checkpoint(input_path: str, output_dir: str) -> None:
    """Convert an original checkpoint into this module's on-disk format.

    Builds a default-configured ``LookingGlassLM``, loads the original weights
    into it, and saves the model together with a default tokenizer.

    Args:
        input_path: Path of the original checkpoint file.
        output_dir: Directory that receives the converted model and tokenizer
            files.
    """
    lm = LookingGlassLM(LookingGlassConfig())
    load_original_weights(lm, input_path)
    lm.save_pretrained(output_dir)

    LookingGlassTokenizer().save_pretrained(output_dir)
    print(f"Saved to {output_dir}")
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Command-line entry point: convert an original checkpoint, run a smoke test
# on randomly initialised models, or print the usage text.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='LookingGlass DNA Language Model')
    parser.add_argument('--convert', type=str, help='Convert original weights')
    parser.add_argument('--output', type=str, default='./lookingglass-v1', help='Output directory')
    parser.add_argument('--test', action='store_true', help='Run tests')
    args = parser.parse_args()

    if args.convert:
        convert_checkpoint(args.convert, args.output)

    elif not args.test:
        parser.print_help()

    else:
        print("Testing LookingGlass...\n")

        # Tokenizer round trip.
        tokenizer = LookingGlassTokenizer()
        print(f"Vocab: {tokenizer.vocab}")
        print(f"BOS token added: {tokenizer.add_bos_token}")
        print(f"EOS token added: {tokenizer.add_eos_token}")

        inputs = tokenizer("GATTACA", return_tensors=True)
        token_ids = inputs['input_ids']
        print(f"\nTokenized 'GATTACA': {token_ids}")
        print(f"Decoded: {tokenizer.decode(token_ids[0])}")

        config = LookingGlassConfig()
        print(f"\nConfig: bidirectional={config.bidirectional}")

        # Encoder-only model.
        encoder = LookingGlass(config)
        encoder_params = sum(p.numel() for p in encoder.parameters())
        print(f"\nLookingGlass params: {encoder_params:,}")

        encoder.eval()
        with torch.no_grad():
            emb = encoder.get_embeddings(token_ids)
        print(f"Embeddings shape: {emb.shape}")

        # Model with the language-modeling head.
        lm = LookingGlassLM(config)
        lm_params = sum(p.numel() for p in lm.parameters())
        print(f"\nLookingGlassLM params: {lm_params:,}")

        lm.eval()
        with torch.no_grad():
            logits = lm(token_ids)
            emb = lm.get_embeddings(token_ids)
        print(f"Logits shape: {logits.shape}")
        print(f"Embeddings shape: {emb.shape}")

        print("\nAll tests passed!")
| |
|