# magic-bert-50m-classification / modeling_magic_bert.py
# Commit 967e04b by mjbommar: "Upload magic-bert-50m-classification model files" (13.1 kB)
"""MagicBERT model implementation for HuggingFace transformers.
This module provides HuggingFace-compatible implementations of MagicBERT,
a BERT-style model trained for binary file type understanding.
"""
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    SequenceClassifierOutput,
)

try:
    from .configuration_magic_bert import MagicBERTConfig
except ImportError:
    from configuration_magic_bert import MagicBERTConfig
class MagicBERTEmbeddings(nn.Module):
    """Input embedding layer: token lookup plus learned absolute positions.

    The two embeddings are summed, then LayerNorm and dropout are applied.
    This model defines no token-type (segment) embedding.
    """

    def __init__(self, config: MagicBERTConfig):
        super().__init__()
        hidden = config.hidden_size
        max_positions = config.max_position_embeddings
        # Submodules are created in the same order as before: that order fixes
        # both the state_dict layout and RNG consumption during initialisation.
        self.token_embeddings = nn.Embedding(
            config.vocab_size, hidden, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(max_positions, hidden)
        self.layer_norm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Cached row of position indices 0..max_positions-1; persistent=False
        # keeps it out of the state_dict.
        self.register_buffer(
            "position_ids",
            torch.arange(max_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed a 2-D ``(batch, seq)`` tensor of token IDs.

        NOTE(review): a ``seq`` longer than ``max_position_embeddings`` makes
        the position slice shorter than the token tensor, so the sum fails.
        """
        _, seq_len = input_ids.shape
        token_part = self.token_embeddings(input_ids)
        position_part = self.position_embeddings(self.position_ids[:, :seq_len])
        return self.dropout(self.layer_norm(token_part + position_part))
class MagicBERTAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Only the Q/K/V projections live here; the output dense layer, residual
    and LayerNorm are applied by the enclosing layer.
    """

    def __init__(self, config: MagicBERTConfig):
        super().__init__()
        # Without this check a non-divisible config silently produces
        # all_head_size < hidden_size and only fails later on a shape mismatch.
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})"
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``(batch, seq, all_head_size)`` to ``(batch, heads, seq, head_size)``."""
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(new_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over ``hidden_states``.

        Args:
            hidden_states: ``(batch, seq, hidden_size)`` input.
            attention_mask: Optional 2-D ``(batch, seq)`` key mask; positions
                with value 1 are attended, positions with value 0 receive a
                -10000 additive penalty. Integer, float and bool masks are
                accepted.

        Returns:
            Context tensor of shape ``(batch, seq, all_head_size)``.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Broadcast over heads and query positions. Cast to the scores'
            # dtype first: a float32 (or int-derived float32) mask term would
            # otherwise promote half/bfloat16 scores to float32 and break the
            # matmul with value_layer below. Casting also makes bool masks work.
            extended_mask = attention_mask[:, None, None, :].to(dtype=attention_scores.dtype)
            attention_scores = attention_scores + (1.0 - extended_mask) * -10000.0

        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context = torch.matmul(attention_probs, value_layer)
        context = context.permute(0, 2, 1, 3).contiguous()
        new_shape = context.size()[:-2] + (self.all_head_size,)
        return context.view(new_shape)
class MagicBERTLayer(nn.Module):
    """One post-LayerNorm transformer block.

    Self-attention, then a GELU feed-forward network; each sub-block ends in
    dropout, a residual add, and LayerNorm.
    """

    def __init__(self, config: MagicBERTConfig):
        super().__init__()
        hidden = config.hidden_size
        eps = config.layer_norm_eps
        p_drop = config.hidden_dropout_prob
        # Attention sub-block. Creation order is unchanged, so parameter
        # initialisation and state_dict layout match the previous version.
        self.attention = MagicBERTAttention(config)
        self.attention_output = nn.Linear(hidden, hidden)
        self.attention_norm = nn.LayerNorm(hidden, eps=eps)
        self.attention_dropout = nn.Dropout(p_drop)
        # Feed-forward sub-block.
        self.intermediate = nn.Linear(hidden, config.intermediate_size)
        self.output = nn.Linear(config.intermediate_size, hidden)
        self.output_norm = nn.LayerNorm(hidden, eps=eps)
        self.output_dropout = nn.Dropout(p_drop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the attention sub-block followed by the feed-forward sub-block."""
        attended = self._attention_block(hidden_states, attention_mask)
        return self._feed_forward_block(attended)

    def _attention_block(
        self, residual: torch.Tensor, attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        # attention -> dense -> dropout -> add residual -> LayerNorm
        projected = self.attention_output(self.attention(residual, attention_mask))
        return self.attention_norm(residual + self.attention_dropout(projected))

    def _feed_forward_block(self, residual: torch.Tensor) -> torch.Tensor:
        # expand -> GELU -> contract -> dropout -> add residual -> LayerNorm
        expanded = F.gelu(self.intermediate(residual))
        return self.output_norm(residual + self.output_dropout(self.output(expanded)))
class MagicBERTEncoder(nn.Module):
    """``config.num_hidden_layers`` MagicBERTLayer blocks applied in sequence."""

    def __init__(self, config: MagicBERTConfig):
        super().__init__()
        blocks = [MagicBERTLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Feed ``hidden_states`` through every layer, sharing one attention mask."""
        output = hidden_states
        for block in self.layers:
            output = block(output, attention_mask)
        return output
class MagicBERTPreTrainedModel(PreTrainedModel):
    """Base class for MagicBERT models: config binding and weight initialisation."""

    config_class = MagicBERTConfig
    base_model_prefix = "magic_bert"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialise one submodule in place.

        - ``nn.Linear`` / ``nn.Embedding`` weights ~ N(0, std), Linear biases zeroed,
          and the embedding's ``padding_idx`` row (if any) zeroed.
        - ``nn.LayerNorm`` weight set to 1 and bias to 0 (each only if present).

        ``std`` is ``config.initializer_range`` when the config defines it.
        Otherwise it is 0.02, the value previously hard-coded here.
        """
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Non-affine LayerNorms have weight/bias set to None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
class MagicBERTModel(MagicBERTPreTrainedModel):
    """MagicBERT encoder without a task head.

    Produces per-token hidden states plus a pooled vector. The pooled vector is
    the final-layer state at position 0; no pooler dense layer is applied.
    """

    def __init__(self, config: MagicBERTConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = MagicBERTEmbeddings(config)
        self.encoder = MagicBERTEncoder(config)
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,  # Ignored, for tokenizer compatibility
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], BaseModelOutputWithPooling]:
        """Encode ``input_ids``.

        Args:
            input_ids: ``(batch, seq)`` token IDs.
            attention_mask: Optional ``(batch, seq)`` mask passed to every layer
                (1 = attend, 0 = masked).
            token_type_ids: Accepted for tokenizer compatibility and ignored.
            return_dict: Overrides ``config.use_return_dict`` when not None.

        Returns:
            ``(sequence_output, pooled_output)`` when ``return_dict`` is false.
            Otherwise a ``BaseModelOutputWithPooling`` with the same two tensors
            as ``last_hidden_state`` and ``pooler_output``. Previously the
            dict form dropped the pooled tensor.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        hidden_states = self.embeddings(input_ids)
        sequence_output = self.encoder(hidden_states, attention_mask)
        pooled_output = sequence_output[:, 0, :]
        if not return_dict:
            return (sequence_output, pooled_output)
        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=None,
            attentions=None,
        )
class MagicBERTForMaskedLM(MagicBERTPreTrainedModel):
    """MagicBERT with a linear vocabulary head for masked language modeling (fill-mask task)."""

    def __init__(self, config: MagicBERTConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = MagicBERTEmbeddings(config)
        self.encoder = MagicBERTEncoder(config)
        self.mlm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.post_init()

    def _encode(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Run embeddings and encoder; return per-token final hidden states."""
        return self.encoder(self.embeddings(input_ids), attention_mask)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,  # Ignored, for tokenizer compatibility
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], MaskedLMOutput]:
        """Compute vocabulary logits for every position.

        Args:
            input_ids: ``(batch, seq)`` token IDs.
            attention_mask: Optional ``(batch, seq)`` mask (1 = attend, 0 = masked).
            token_type_ids: Accepted for tokenizer compatibility and ignored.
            labels: Optional target IDs with the same number of elements as
                ``input_ids``. Positions labelled -100 are excluded from the loss.
            return_dict: Overrides ``config.use_return_dict`` when not None.

        Returns:
            ``(loss, logits)`` or ``(logits,)`` when ``return_dict`` is false,
            otherwise a ``MaskedLMOutput``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        sequence_output = self._encode(input_ids, attention_mask)
        logits = self.mlm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            # reshape rather than view: labels handed in by callers need not be contiguous.
            loss = loss_fct(
                logits.reshape(-1, self.config.vocab_size), labels.reshape(-1)
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    def get_embeddings(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        pooling: str = "cls",
    ) -> torch.Tensor:
        """Get pooled embeddings for downstream tasks.

        Args:
            input_ids: Input token IDs.
            attention_mask: Attention mask. With ``"mean"`` pooling it also
                weights the average, so masked positions are excluded.
            pooling: ``"cls"`` (hidden state at position 0) or ``"mean"``.

        Returns:
            Pooled embeddings ``[batch_size, hidden_size]``.

        Raises:
            ValueError: ``pooling`` is not ``"cls"`` or ``"mean"``. This is
                checked before running the encoder.
        """
        if pooling not in ("cls", "mean"):
            raise ValueError(f"Unknown pooling: {pooling}")
        sequence_output = self._encode(input_ids, attention_mask)
        if pooling == "cls":
            return sequence_output[:, 0, :]
        if attention_mask is None:
            return sequence_output.mean(dim=1)
        # float32 mask keeps the clamp floor representable even for half-precision outputs.
        mask = attention_mask.unsqueeze(-1).float()
        return (sequence_output * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
class MagicBERTForSequenceClassification(MagicBERTPreTrainedModel):
    """MagicBERT for sequence classification (file type classification).

    Pipeline: encoder → position-0 hidden state → projection MLP →
    L2 normalisation → linear classifier.
    """

    def __init__(self, config: MagicBERTConfig):
        super().__init__(config)
        self.config = config
        # Falls back to 106 classes only if the config has no num_labels attribute.
        self.num_labels = getattr(config, "num_labels", 106)
        self.embeddings = MagicBERTEmbeddings(config)
        self.encoder = MagicBERTEncoder(config)
        # Projection head (for contrastive learning compatibility)
        projection_dim = getattr(config, "contrastive_projection_dim", 256)
        self.projection = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, projection_dim),
        )
        self.classifier = nn.Linear(projection_dim, self.num_labels)
        self.post_init()

    def _project(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Encode, take the position-0 state, project, and L2-normalise per row."""
        sequence_output = self.encoder(self.embeddings(input_ids), attention_mask)
        projections = self.projection(sequence_output[:, 0, :])
        return F.normalize(projections, p=2, dim=1)

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Pick the loss from ``config.problem_type``.

        - ``"regression"``: MSE (squeezing the single output when ``num_labels == 1``).
        - ``"multi_label_classification"``: BCE-with-logits on multi-hot labels.
        - Anything else, including ``None`` and ``"single_label_classification"``:
          ``CrossEntropyLoss`` exactly as before.
        """
        problem_type = getattr(self.config, "problem_type", None)
        if problem_type == "regression":
            loss_fct = nn.MSELoss()
            if self.num_labels == 1:
                return loss_fct(logits.squeeze(-1), labels.reshape(-1).to(logits.dtype))
            return loss_fct(logits, labels.to(logits.dtype))
        if problem_type == "multi_label_classification":
            return nn.BCEWithLogitsLoss()(logits, labels.to(logits.dtype))
        return nn.CrossEntropyLoss()(logits, labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,  # Ignored, for tokenizer compatibility
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutput]:
        """Classify each sequence.

        Args:
            input_ids: ``(batch, seq)`` token IDs.
            attention_mask: Optional ``(batch, seq)`` mask (1 = attend, 0 = masked).
            token_type_ids: Accepted for tokenizer compatibility and ignored.
            labels: Optional targets; the loss used depends on ``config.problem_type``.
            return_dict: Overrides ``config.use_return_dict`` when not None.

        Returns:
            ``(loss, logits)`` or ``(logits,)`` when ``return_dict`` is false,
            otherwise a ``SequenceClassifierOutput``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        projections = self._project(input_ids, attention_mask)
        logits = self.classifier(projections)

        loss = self._compute_loss(logits, labels) if labels is not None else None

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    def get_embeddings(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Get normalized projection embeddings for similarity search.

        These are the same unit-norm vectors that ``forward`` feeds to the classifier.
        """
        return self._project(input_ids, attention_mask)