# respected_architecture.py
from transformers import AutoModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from torch import nn
import torch


# Custom config
class RespectedConfig(PretrainedConfig):
    model_type = "respected"

    def __init__(
        self,
        base_model_name_or_path: str = "answerdotai/ModernBERT-base",
        hidden_size: int = 512,
        num_labels: int = 3,
        dropout_rate: float = 0.4,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name_or_path = base_model_name_or_path
        self.hidden_size = hidden_size
        self.num_labels = num_labels
        self.dropout_rate = dropout_rate


# Main model
class RespectedArchitecture(nn.Module):
    """
    ModernBERT-base → global mean-pool → BiLSTM → token-attention → MLP → classification
    """

    def __init__(self, config: RespectedConfig):
        super().__init__()
        self.config = config

        # 1) ModernBERT backbone (loaded with remote code)
        self.transformer = AutoModel.from_pretrained(
            config.base_model_name_or_path,
            trust_remote_code=True,
        )

        # 2) Global mean-pool (reduces L → 1)
        self.pool = nn.AdaptiveAvgPool1d(1)

        # 3) Bidirectional LSTM (1 step, H → 2H)
        self.bilstm = nn.LSTM(
            input_size=self.transformer.config.hidden_size,
            hidden_size=config.hidden_size,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # 4) Attention over the BiLSTM output
        self.attn_score = nn.Linear(2 * config.hidden_size, 1)
        self.softmax = nn.Softmax(dim=1)

        # 5) MLP
        self.mlp = nn.Sequential(
            nn.Linear(2 * config.hidden_size, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
        )

        # 6) Final classifier
        self.classifier = nn.Linear(128, config.num_labels)

    # optional: also return the "context" vector if requested (for LIME/UMAP)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        return_context: bool = False,
    ):
        # a) ModernBERT encoder
        last_hidden = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state                                              # (B, L, H)

        # b) Global mean-pool
        x = self.pool(last_hidden.transpose(1, 2)).squeeze(-1).unsqueeze(1)  # (B, 1, H)

        # c) BiLSTM
        lstm_out, _ = self.bilstm(x)                                     # (B, 1, 2H)

        # d) Attention over L=1 → "context"
        weights = self.softmax(self.attn_score(lstm_out).squeeze(-1))    # (B, 1)
        context = torch.sum(weights.unsqueeze(-1) * lstm_out, dim=1)     # (B, 2H)

        # e) MLP head
        h = self.mlp(context)                                            # (B, 128)

        # f) Final logits
        logits = self.classifier(h)                                      # (B, num_labels)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        if return_context:
            return logits, context

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
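

# --- Usage sketch (not part of the original file; a minimal, hedged example) ---
# Shows how the config/model pair above could be exercised end to end.
# Assumptions: the tokenizer checkpoint matches the default backbone
# ("answerdotai/ModernBERT-base"), and the two input sentences are illustrative only.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = RespectedConfig()
    model = RespectedArchitecture(config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    batch = tokenizer(
        ["example sentence one", "example sentence two"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        out = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

    # With the default config this prints torch.Size([2, 3]): batch of 2, num_labels=3.
    print(out.logits.shape)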