import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Qwen3PreTrainedModel, Qwen3Config, Qwen3Model
from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP


class TokenCompressor(nn.Module):
    """
    Adaptive token compression module.

    Sequences longer than `length_threshold` are compressed with
    F.adaptive_avg_pool1d; the compressed length is
    length_threshold + (real_length - length_threshold) * compression_ratio.
    """

    def __init__(self, length_threshold: int = 512, compression_ratio: float = 0.3):
        super().__init__()
        self.length_threshold = length_threshold
        self.compression_ratio = compression_ratio

    def forward(
        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Adaptively compress a batch of token embeddings.

        Args:
            token_embeddings: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, seq_len]
        Returns:
            compressed_embeddings: compressed embeddings
            compressed_mask: compressed attention mask
        """
        # Infer the padding side: a zero in the last mask position of any
        # sequence implies right padding; otherwise assume left padding.
        padding_side = 'right' if (attention_mask[:, -1] == 0).any() else 'left'

        compressed_embeddings_list = []
        compressed_masks_list = []
        for text_idx in range(token_embeddings.shape[0]):
            real_length = int(attention_mask[text_idx].sum().item())
            if real_length <= self.length_threshold:
                # Short sequence: keep only the valid (non-padding) tokens.
                if padding_side == 'left':
                    valid_embeddings = token_embeddings[text_idx:text_idx + 1, -real_length:, :]
                else:
                    valid_embeddings = token_embeddings[text_idx:text_idx + 1, :real_length, :]
                compressed_embeddings_list.append(valid_embeddings)
                compressed_masks_list.append([1] * real_length)
            else:
                # Long sequence: pool the valid tokens down to the target length.
                target_length = int(
                    self.length_threshold + (real_length - self.length_threshold) * self.compression_ratio
                )
                if padding_side == 'left':
                    valid_embeddings = token_embeddings[text_idx:text_idx + 1, -real_length:, :]
                else:
                    valid_embeddings = token_embeddings[text_idx:text_idx + 1, :real_length, :]

                # adaptive_avg_pool1d pools over the last dimension, so swap the
                # sequence and hidden dimensions before pooling and swap them back.
                compressed_embeddings_list.append(
                    F.adaptive_avg_pool1d(
                        valid_embeddings.transpose(1, 2), target_length
                    ).transpose(1, 2)
                )
                compressed_masks_list.append([1] * target_length)

        # Re-pad the compressed masks to the longest sequence in the batch.
        new_seq_len = max(len(_mask) for _mask in compressed_masks_list)
        new_attention_mask = torch.tensor(
            [
                _mask + [0] * (new_seq_len - len(_mask))
                if padding_side == "right"
                else [0] * (new_seq_len - len(_mask)) + _mask
                for _mask in compressed_masks_list
            ],
            dtype=torch.long,
            device=token_embeddings.device,
        )

        # Scatter each compressed sequence into a zero-padded batch tensor.
        batch_size = token_embeddings.shape[0]
        hidden_size = token_embeddings.shape[2]
        new_token_embeddings = torch.zeros(
            batch_size, new_seq_len, hidden_size,
            dtype=token_embeddings.dtype,
            device=token_embeddings.device,
        )

        for idx, compressed_emb in enumerate(compressed_embeddings_list):
            seq_len = compressed_emb.shape[1]
            if padding_side == "right":
                new_token_embeddings[idx, :seq_len, :] = compressed_emb.squeeze(0)
            else:
                new_token_embeddings[idx, -seq_len:, :] = compressed_emb.squeeze(0)

        return new_token_embeddings, new_attention_mask
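

# A minimal usage sketch of TokenCompressor (hypothetical shapes and values,
# shown only to illustrate the compression arithmetic):
#
#     compressor = TokenCompressor(length_threshold=3, compression_ratio=0.5)
#     emb = torch.randn(2, 6, 4)                   # [batch, seq, hidden]
#     mask = torch.tensor([[1, 1, 1, 0, 0, 0],     # 3 valid tokens, right padding
#                          [1, 1, 1, 1, 1, 1]])    # 6 valid tokens
#     out_emb, out_mask = compressor(emb, mask)
#     # The 6-token sequence is pooled to int(3 + 3 * 0.5) = 4 tokens, so
#     # out_emb.shape == (2, 4, 4) and out_mask == [[1, 1, 1, 0], [1, 1, 1, 1]].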


class JasperV2Encoder(Qwen3PreTrainedModel):
    """Qwen3-based text encoder: embeds tokens, passes them through an extra
    Qwen3MLP, compresses long sequences with TokenCompressor, runs the Qwen3
    backbone, mean-pools valid positions, and projects to 2048 dimensions."""

    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        self.jasper_mlp = Qwen3MLP(config=config)
        self.linear_1 = nn.Linear(in_features=config.hidden_size, out_features=2048, bias=True)
        self.token_compressor = TokenCompressor(length_threshold=80, compression_ratio=0.5)
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        # Embed the tokens and pass the embeddings through the extra MLP
        # before compression.
        token_embeddings = self.model.embed_tokens(input_ids)
        token_embeddings = self.jasper_mlp(token_embeddings)

        # Callers may override the compression ratio per forward pass via kwargs.
        self.token_compressor.compression_ratio = kwargs.get(
            "compression_ratio",
            self.token_compressor.compression_ratio
        )
        compressed_token_embeddings, attention_mask = self.token_compressor(token_embeddings, attention_mask)
        compressed_token_embeddings = self.model(
            inputs_embeds=compressed_token_embeddings, attention_mask=attention_mask
        )["last_hidden_state"]

        # Mean-pool the hidden states over valid (non-padding) positions, then
        # project the pooled vector to the final embedding size.
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(compressed_token_embeddings.size()).to(
                compressed_token_embeddings.dtype)
        )
        sum_embeddings = torch.sum(compressed_token_embeddings * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        vector = sum_embeddings / sum_mask
        return self.linear_1(vector)
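

# A minimal smoke-test sketch: the config values below are tiny, hypothetical
# ones chosen only so the block runs quickly; a real checkpoint supplies its
# own Qwen3Config and tokenizer.
if __name__ == "__main__":
    config = Qwen3Config(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
    )
    encoder = JasperV2Encoder(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (2, 100))
    attention_mask = torch.ones(2, 100, dtype=torch.long)
    with torch.no_grad():
        embeddings = encoder(input_ids, attention_mask, compression_ratio=0.5)
    # 100 tokens exceed the 80-token threshold, so each sequence is compressed
    # to int(80 + 20 * 0.5) = 90 positions before encoding; output is (2, 2048).
    print(embeddings.shape)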