import torch
from sentence_transformers.models import Transformer as BaseTransformer


class EmbeddingModel(BaseTransformer):
    """Wrapper model for extracting embeddings from a causal language model
    using the SentenceTransformer framework."""

    def __init__(self, model_name_or_path=None, base_model=None, tokenizer=None,
                 max_seq_length=512, pooling="last", **kwargs):
        """
        Initialize the embedding model with a base model, tokenizer, and pooling strategy.

        Args:
            model_name_or_path: HuggingFace model name or path (used by SentenceTransformer)
            base_model: Pre-initialized model (prioritized over model_name_or_path)
            tokenizer: Tokenizer to use (must be provided if base_model is provided)
            max_seq_length: Maximum sequence length
            pooling: Pooling strategy ("mean" or "last" token pooling)
        """
        # If we're given a base_model directly, use it instead of loading from model_name_or_path.
        if base_model is not None:
            if tokenizer is None:
                raise ValueError("If base_model is provided, tokenizer must also be provided")

            # Skip BaseTransformer's normal initialization and set everything up manually.
            # This call only initializes the underlying nn.Module.
            super(BaseTransformer, self).__init__()

            # Set up the model configuration manually.
            self.config = base_model.config
            self.max_seq_length = max_seq_length
            self.auto_model = base_model
            self._tokenizer = tokenizer
            self.tokenizer = tokenizer  # For compatibility with BaseTransformer
            self.do_lower_case = getattr(tokenizer, "do_lower_case", False)

            # For certain models (like Llama), ensure that padding_idx is set correctly.
            self.padding_idx = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

            # Additional attributes normally set by BaseTransformer.
            self.name = "Transformer"
            self.backend = "huggingface_transformers"  # Default backend
        else:
            # Use standard initialization from BaseTransformer.
            super().__init__(model_name_or_path=model_name_or_path,
                             max_seq_length=max_seq_length, **kwargs)

        self.pooling = pooling
        self.embedding_dim = self.auto_model.config.hidden_size

        # Remove lm_head if it exists to save memory; only hidden states are needed.
        if hasattr(self.auto_model, "lm_head"):
            delattr(self.auto_model, "lm_head")

    def forward(self, features):
        """
        Forward pass through the model to get embeddings.
        Adapted to work with SentenceTransformer's expected format.

        Args:
            features: Dictionary with 'input_ids', 'attention_mask', etc.

        Returns:
            Dictionary with embeddings
        """
        input_ids = features["input_ids"]
        attention_mask = features["attention_mask"]

        # Get the model outputs, including hidden states from every layer.
        outputs = self.auto_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # Use the last layer's hidden states for pooling.
        hidden_states = outputs.hidden_states[-1]

        # Pool the hidden states into one embedding per sequence.
        embeddings = self._get_embeddings(
            hidden_states,
            input_ids,
            self._tokenizer.eos_token_id if hasattr(self, "_tokenizer") else self.tokenizer.eos_token_id,
        )

        # Return in the format expected by SentenceTransformer.
        return {"token_embeddings": hidden_states, "sentence_embedding": embeddings}

    def _get_embeddings(self, hidden_states, input_ids, eos_token_id, pooling=None):
        """Extract embeddings from hidden states using the specified pooling strategy."""
        if pooling is None:
            pooling = self.pooling

        batch_size = input_ids.shape[0]
        hidden_dim = hidden_states.size(-1)
        embeddings = torch.zeros(batch_size, hidden_dim, device=hidden_states.device)
        tokenizer = self._tokenizer if hasattr(self, "_tokenizer") else self.tokenizer

        if pooling == "mean":
            # Mean pooling: average the hidden states over all non-padding positions.
            attention_mask = (input_ids != tokenizer.pad_token_id).float()
            sum_embeddings = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1)
            input_mask_sum = torch.sum(attention_mask, dim=1).unsqueeze(-1)
            input_mask_sum = torch.clamp(input_mask_sum, min=1e-9)
            embeddings = sum_embeddings / input_mask_sum
        else:
            # Last-token pooling: use the hidden state at the first EOS token of each sequence.
            eos_positions = (input_ids == eos_token_id).nonzero(as_tuple=True)
            batch_indices = eos_positions[0]
            token_positions = eos_positions[1]

            # Track which sequences contain an EOS token at all.
            has_eos = torch.zeros(batch_size, dtype=torch.bool, device=hidden_states.device)
            has_eos[batch_indices] = True

            unique_batch_indices = batch_indices.unique()
            for i in unique_batch_indices:
                # Position of the first EOS occurrence in sequence i.
                idx = (batch_indices == i).nonzero(as_tuple=True)[0][0]
                embeddings[i] = hidden_states[i, token_positions[idx]]

            # Fall back to the last non-padding token for sequences without an EOS token.
            non_eos_indices = (~has_eos).nonzero(as_tuple=True)[0]
            if len(non_eos_indices) > 0:
                for i in non_eos_indices:
                    mask = (input_ids[i] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
                    embeddings[i] = hidden_states[i, mask[-1]]

        return embeddings

    def get_sentence_embedding_dimension(self):
        """Return the dimension of the sentence embeddings."""
        return self.embedding_dim

    def get_word_embedding_dimension(self):
        """Return the dimension of the word/token embeddings."""
        return self.embedding_dim
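

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): shows how this
# wrapper could be plugged into a SentenceTransformer pipeline. The model name
# and example sentences below are assumptions chosen for demonstration; any
# HuggingFace causal LM with a compatible tokenizer should work the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from sentence_transformers import SentenceTransformer

    model_name = "meta-llama/Llama-2-7b-hf"  # assumed example model
    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
    if hf_tokenizer.pad_token is None:
        # Causal LMs often ship without a pad token; reuse EOS for padding.
        hf_tokenizer.pad_token = hf_tokenizer.eos_token

    causal_lm = AutoModelForCausalLM.from_pretrained(model_name)

    # "last" pools the hidden state at the first EOS token; "mean" averages
    # hidden states over non-padding positions.
    embed_module = EmbeddingModel(
        base_model=causal_lm,
        tokenizer=hf_tokenizer,
        max_seq_length=512,
        pooling="last",
    )

    # The wrapper already emits 'sentence_embedding', so it can serve as the
    # only module; tokenization is inherited from BaseTransformer.tokenize().
    st_model = SentenceTransformer(modules=[embed_module])
    vectors = st_model.encode(["an example sentence", "another one"])
    print(vectors.shape)  # (2, hidden_size)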