"""SnuDeID: a DeBERTa-v2 token-classification model whose encoder
relative-position embedding table is replaced, registered with the
``transformers`` Auto* factories under the model type ``"snu-deid"``."""

import torch
from transformers import (
    AutoConfig,
    DebertaV2Config,
    DebertaV2ForTokenClassification,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


class SnuDeIDConfig(DebertaV2Config):
    """DeBERTa-v2 configuration tagged with the ``"snu-deid"`` model type.

    A distinct config class is needed because the Auto* model mappings are
    keyed by config *class*. ``DebertaV2Config`` is already bound to the stock
    DeBERTa-v2 models and cannot be re-registered for ``SnuDeID``.
    """

    model_type = "snu-deid"


class SnuDeID(DebertaV2ForTokenClassification):
    """DeBERTa-v2 token classifier with a replaced ``rel_embeddings`` table.

    The encoder's ``rel_embeddings`` module is swapped for a fresh
    ``torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)``.

    Raises:
        ValueError: if the DeBERTa-v2 encoder has no ``rel_embeddings`` attribute.
    """

    config_class = SnuDeIDConfig

    def __init__(self, config):
        super().__init__(config)
        encoder = self.deberta.encoder

        # NOTE(review): the stock DeBERTa-v2 encoder appears to create
        # rel_embeddings only when config.relative_attention is enabled;
        # confirm against the installed transformers version.
        if not hasattr(encoder, "rel_embeddings"):
            raise ValueError("Expected DebertaV2 encoder to have rel_embeddings attribute")

        # NOTE(review): the stock table is sized from the relative-position span
        # (max_relative_positions / position_buckets), not max_position_embeddings.
        # Presumably this size matches the checkpoint being loaded; verify.
        encoder.rel_embeddings = torch.nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        # super().__init__() has already run the model's weight initialisation,
        # so the new module would otherwise keep torch's default N(0, 1) init
        # instead of the model's initializer_range-based scheme.
        self._init_weights(encoder.rel_embeddings)


# The Auto* model mapping is keyed by config classes, so a plain string key
# ("snu-deid") could never be matched during lookup. Register the model type
# with AutoConfig, then bind the config class to the model.
AutoConfig.register("snu-deid", SnuDeIDConfig)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.register(SnuDeIDConfig, SnuDeID)