| import torch | |
| from transformers import DebertaV2ForTokenClassification | |
| from transformers.models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING | |
class SnuDeID(DebertaV2ForTokenClassification):
    """DeBERTa-v2 token classifier whose relative-position embedding table is rebuilt.

    The stock ``DebertaV2ForTokenClassification`` is constructed first. Its
    encoder's ``rel_embeddings`` module is then replaced by a fresh
    ``torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)``.
    """

    def __init__(self, config):
        """Build the stock model, then swap in a new ``rel_embeddings`` table.

        Args:
            config: Model configuration. This method reads
                ``max_position_embeddings`` and ``hidden_size``. The parent's
                ``_init_weights`` hook is expected to read
                ``initializer_range``. Presumably a ``DebertaV2Config``; confirm
                against callers.

        Raises:
            ValueError: If the encoder has no ``rel_embeddings`` attribute, or
                if the replacement table would have fewer rows than the table
                it replaces.
        """
        super().__init__(config)
        encoder = self.deberta.encoder

        if not hasattr(encoder, "rel_embeddings"):
            raise ValueError("Expected DebertaV2 encoder to have rel_embeddings attribute")

        # NOTE(review): in transformers' DebertaV2 implementation, the attention
        # appears to index the relative-position table up to the stock table's
        # size. A smaller replacement would then only fail at forward time with
        # an index error, so reject it here. Confirm against the installed
        # transformers version.
        stock_rows = getattr(encoder.rel_embeddings, "num_embeddings", None)
        if stock_rows is not None and config.max_position_embeddings < stock_rows:
            raise ValueError(
                f"config.max_position_embeddings ({config.max_position_embeddings}) is smaller "
                f"than the encoder's relative-position table ({stock_rows} rows); "
                "the replacement rel_embeddings would be too small for relative attention"
            )

        replacement = torch.nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
        )
        # The parent constructor has already run its weight initialisation, so a
        # module created afterwards would keep PyTorch's default N(0, 1) init.
        # Apply the model's own initializer so the new table matches the rest of
        # the network. from_pretrained() still overwrites it when the checkpoint
        # contains these weights.
        self._init_weights(replacement)
        encoder.rel_embeddings = replacement
# Register SnuDeID in transformers' token-classification model mapping under
# the key "snu-deid". This runs as a side effect of importing the module.
# NOTE(review): transformers' auto-model mappings appear to be keyed by *config
# classes*, with lookups resolved from the config object's type. A plain string
# key is therefore likely stored but never matched by
# AutoModelForTokenClassification.
# Proper auto-registration would seem to need:
#   - a dedicated config subclass with model_type = "snu-deid", registered via
#     AutoConfig.register;
#   - AutoModelForTokenClassification.register(<that config>, SnuDeID), which
#     in turn requires SnuDeID.config_class to point at that config.
# Confirm against the installed transformers version before relying on
# auto-loading.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.register("snu-deid", SnuDeID)