# armembed/EmbeddingModel.py
import torch
import torch.nn.functional as F
from sentence_transformers.models import Transformer as BaseTransformer


class EmbeddingModel(BaseTransformer):
    """Wrapper model for extracting embeddings from a causal language model using the SentenceTransformer framework."""
def __init__(self, model_name_or_path=None, base_model=None, tokenizer=None, max_seq_length=512, pooling="last", **kwargs):
"""
Initialize the embedding model with a base model, tokenizer, and pooling strategy.
Args:
model_name_or_path: HuggingFace model name or path (used by SentenceTransformer)
            base_model: Pre-initialized model (takes precedence over model_name_or_path)
tokenizer: Tokenizer to use (must be provided if base_model is provided)
max_seq_length: Maximum sequence length
            pooling: Pooling strategy, either "mean" or "last" (last-token pooling)
"""
# If we're given a base_model directly, use that instead of loading from model_name_or_path
if base_model is not None:
if tokenizer is None:
raise ValueError("If base_model is provided, tokenizer must also be provided")
            # Skip Transformer's normal initialization and initialize only the underlying
            # nn.Module; the required attributes are set up manually below.
            super(BaseTransformer, self).__init__()
# Set up the model configuration manually
self.config = base_model.config
self.max_seq_length = max_seq_length
self.auto_model = base_model
self._tokenizer = tokenizer
self.tokenizer = tokenizer # For compatibility
self.do_lower_case = getattr(tokenizer, "do_lower_case", False)
# For certain models (like Llama), ensure that padding_idx is set correctly
self.padding_idx = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
# Additional attributes from BaseTransformer
self.name = 'Transformer'
self.backend = "huggingface_transformers" # Default backend
else:
# Use standard initialization from BaseTransformer
super().__init__(model_name_or_path=model_name_or_path, max_seq_length=max_seq_length, **kwargs)
self.pooling = pooling
self.embedding_dim = self.auto_model.config.hidden_size
# Remove lm_head if it exists to save memory
if hasattr(self.auto_model, "lm_head"):
delattr(self.auto_model, "lm_head")

    def forward(self, features):
"""
Forward pass through the model to get embeddings.
Adapted to work with SentenceTransformer's expected format.
Args:
features: Dictionary with 'input_ids', 'attention_mask', etc.
        Returns:
            Dictionary with 'token_embeddings' and 'sentence_embedding'
"""
input_ids = features['input_ids']
attention_mask = features['attention_mask']
# Get the model outputs
outputs = self.auto_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True
)
# Extract embeddings using the specified pooling
hidden_states = outputs.hidden_states[-1]
# Get embeddings using our pooling method
embeddings = self._get_embeddings(
hidden_states,
input_ids,
self._tokenizer.eos_token_id if hasattr(self, '_tokenizer') else self.tokenizer.eos_token_id
)
# Return in the format expected by SentenceTransformer
return {'token_embeddings': hidden_states, 'sentence_embedding': embeddings}

    def _get_embeddings(self, hidden_states, input_ids, eos_token_id, pooling=None):
"""Extract embeddings from hidden states using the specified pooling strategy."""
# Extract embeddings using the pooling strategy
if pooling is None:
pooling = self.pooling
batch_size = input_ids.shape[0]
hidden_dim = hidden_states.size(-1)
embeddings = torch.zeros(batch_size, hidden_dim, device=hidden_states.device)
tokenizer = self._tokenizer if hasattr(self, '_tokenizer') else self.tokenizer
if pooling == "mean":
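            # Mean pooling: average the hidden states over non-padding positions
            # (the mask is rebuilt from the tokenizer's pad_token_id).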
attention_mask = (input_ids != tokenizer.pad_token_id).float()
sum_embeddings = torch.sum(
hidden_states * attention_mask.unsqueeze(-1), dim=1
)
input_mask_sum = torch.sum(attention_mask, dim=1).unsqueeze(-1)
input_mask_sum = torch.clamp(input_mask_sum, min=1e-9)
embeddings = sum_embeddings / input_mask_sum
else:
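            # "last"-token pooling (default): use the hidden state at the first
            # EOS token found in each sequence.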
eos_positions = (input_ids == eos_token_id).nonzero(as_tuple=True)
batch_indices = eos_positions[0]
token_positions = eos_positions[1]
has_eos = torch.zeros(
batch_size, dtype=torch.bool, device=hidden_states.device
)
has_eos[batch_indices] = True
unique_batch_indices = batch_indices.unique()
for i in unique_batch_indices:
idx = (batch_indices == i).nonzero(as_tuple=True)[0][0]
embeddings[i] = hidden_states[i, token_positions[idx]]
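            # Sequences without an EOS token fall back to their last non-padding position.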
non_eos_indices = (~has_eos).nonzero(as_tuple=True)[0]
if len(non_eos_indices) > 0:
for i in non_eos_indices:
mask = (input_ids[i] != tokenizer.pad_token_id).nonzero(
as_tuple=True
)[0]
embeddings[i] = hidden_states[i, mask[-1]]
return embeddings

    def get_sentence_embedding_dimension(self):
"""Return the dimension of the sentence embeddings."""
return self.embedding_dim

    def get_word_embedding_dimension(self):
"""Return the dimension of the word/token embeddings."""
return self.embedding_dim
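

# Usage sketch (illustrative only, not part of the original module). The checkpoint
# name below is a placeholder, and reusing the EOS token as padding is a common
# assumption for causal-LM tokenizers that lack a pad token; adjust both for your
# own model. The headless backbone from AutoModel is used here.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    model_id = "gpt2"  # placeholder checkpoint; any decoder-only backbone should work
    base_model = AutoModel.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token  # assumed fallback: reuse EOS as padding

    embedder = EmbeddingModel(base_model=base_model, tokenizer=tokenizer, pooling="mean")

    # Call the module directly on tokenized inputs ...
    features = tokenizer(["an example sentence", "another one"], return_tensors="pt", padding=True)
    with torch.no_grad():
        output = embedder(features)
    print(output["sentence_embedding"].shape)  # (2, hidden_size)

    # ... or plug it into SentenceTransformer as the sole module.
    from sentence_transformers import SentenceTransformer

    st_model = SentenceTransformer(modules=[embedder])
    print(st_model.encode(["an example sentence"]).shape)  # (1, hidden_size)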