import torch
import torch.nn.functional as F

from sentence_transformers.models import Transformer as BaseTransformer


class EmbeddingModel(BaseTransformer):
    """Wrapper model for extracting embeddings from a causal language model using the SentenceTransformer framework."""

    def __init__(self, model_name_or_path=None, base_model=None, tokenizer=None, max_seq_length=512, pooling="last", **kwargs):
        """
        Initialize the embedding model with a base model, tokenizer, and pooling strategy.

        Args:
            model_name_or_path: Hugging Face model name or path (used by SentenceTransformer).
            base_model: Pre-initialized model (takes priority over ``model_name_or_path``).
            tokenizer: Tokenizer to use (required when ``base_model`` is provided).
            max_seq_length: Maximum sequence length.
            pooling: Pooling strategy, either "mean" or "last" (last-token pooling).
        """
        if base_model is not None:
            if tokenizer is None:
                raise ValueError("If base_model is provided, tokenizer must also be provided")

            # Skip BaseTransformer.__init__ (which would load a model from
            # model_name_or_path) and initialize its parent class directly.
            super(BaseTransformer, self).__init__()

            self.config = base_model.config
            self.max_seq_length = max_seq_length
            self.auto_model = base_model
            self._tokenizer = tokenizer
            self.tokenizer = tokenizer
            self.do_lower_case = getattr(tokenizer, "do_lower_case", False)
            self.padding_idx = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
            self.name = 'Transformer'
            self.backend = "huggingface_transformers"
        else:
            super().__init__(model_name_or_path=model_name_or_path, max_seq_length=max_seq_length, **kwargs)

        self.pooling = pooling
        self.embedding_dim = self.auto_model.config.hidden_size

        # The language-modeling head is not needed (only hidden states are used); drop it.
        if hasattr(self.auto_model, "lm_head"):
            delattr(self.auto_model, "lm_head")

    def forward(self, features):
        """
        Forward pass through the model to get embeddings.

        Adapted to work with SentenceTransformer's expected format.

        Args:
            features: Dictionary with 'input_ids', 'attention_mask', etc.

        Returns:
            Dictionary with 'token_embeddings' and the pooled 'sentence_embedding'.
        """
        input_ids = features['input_ids']
        attention_mask = features['attention_mask']

        outputs = self.auto_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # Hidden states of the final transformer layer.
        hidden_states = outputs.hidden_states[-1]

        embeddings = self._get_embeddings(
            hidden_states,
            input_ids,
            self._tokenizer.eos_token_id if hasattr(self, '_tokenizer') else self.tokenizer.eos_token_id,
        )

        return {'token_embeddings': hidden_states, 'sentence_embedding': embeddings}

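    # Shape sketch for forward() (batch_size, seq_len, and hidden_size are
    # placeholders): SentenceTransformer's tokenize step supplies
    #     features = {"input_ids": LongTensor[batch_size, seq_len],
    #                 "attention_mask": LongTensor[batch_size, seq_len]}
    # and forward() returns "token_embeddings" of shape
    # [batch_size, seq_len, hidden_size] plus a pooled "sentence_embedding"
    # of shape [batch_size, hidden_size].
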
    def _get_embeddings(self, hidden_states, input_ids, eos_token_id, pooling=None):
        """Extract embeddings from hidden states using the specified pooling strategy."""
        if pooling is None:
            pooling = self.pooling

        batch_size = input_ids.shape[0]
        hidden_dim = hidden_states.size(-1)
        embeddings = torch.zeros(batch_size, hidden_dim, device=hidden_states.device)

        tokenizer = self._tokenizer if hasattr(self, '_tokenizer') else self.tokenizer

        if pooling == "mean":
            # Mean pooling: average the hidden states of all non-padding tokens.
            attention_mask = (input_ids != tokenizer.pad_token_id).float()
            sum_embeddings = torch.sum(
                hidden_states * attention_mask.unsqueeze(-1), dim=1
            )
            input_mask_sum = torch.sum(attention_mask, dim=1).unsqueeze(-1)
            input_mask_sum = torch.clamp(input_mask_sum, min=1e-9)
            embeddings = sum_embeddings / input_mask_sum
        else:
            # Last-token pooling: use the hidden state at each sequence's first EOS token.
            eos_positions = (input_ids == eos_token_id).nonzero(as_tuple=True)
            batch_indices = eos_positions[0]
            token_positions = eos_positions[1]
            has_eos = torch.zeros(
                batch_size, dtype=torch.bool, device=hidden_states.device
            )
            has_eos[batch_indices] = True
            unique_batch_indices = batch_indices.unique()
            for i in unique_batch_indices:
                idx = (batch_indices == i).nonzero(as_tuple=True)[0][0]
                embeddings[i] = hidden_states[i, token_positions[idx]]

            # Sequences without an EOS token fall back to their last non-padding token.
            non_eos_indices = (~has_eos).nonzero(as_tuple=True)[0]
            if len(non_eos_indices) > 0:
                for i in non_eos_indices:
                    mask = (input_ids[i] != tokenizer.pad_token_id).nonzero(
                        as_tuple=True
                    )[0]
                    embeddings[i] = hidden_states[i, mask[-1]]

        return embeddings

    def get_sentence_embedding_dimension(self):
        """Return the dimension of the sentence embeddings."""
        return self.embedding_dim

    def get_word_embedding_dimension(self):
        """Return the dimension of the word/token embeddings."""
        return self.embedding_dim
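

# Minimal usage sketch: wrap a pre-loaded causal LM as the embedding module of a
# SentenceTransformer. The model name "gpt2", the Normalize module, and the
# "mean" pooling choice are illustrative assumptions, not requirements of
# EmbeddingModel; any causal LM plus a tokenizer with a pad token should work.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import Normalize
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        # GPT-2 has no pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    base_model = AutoModelForCausalLM.from_pretrained("gpt2")

    embedding_module = EmbeddingModel(
        base_model=base_model,
        tokenizer=tokenizer,
        max_seq_length=128,
        pooling="mean",
    )

    # Assemble a SentenceTransformer pipeline; EmbeddingModel already returns a
    # pooled "sentence_embedding", so no separate Pooling module is needed.
    model = SentenceTransformer(modules=[embedding_module, Normalize()])

    embeddings = model.encode(["An example sentence.", "Another one."])
    print(embeddings.shape)  # (2, hidden_size), e.g. (2, 768) for GPT-2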