File size: 679 Bytes
c4255f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

import torch
from transformers import (
    AutoConfig,
    DebertaV2Config,
    DebertaV2ForTokenClassification,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING

class SnuDeID(DebertaV2ForTokenClassification):
    """DeBERTa-v2 token classifier with a resized relative-position embedding.

    Behaves like ``DebertaV2ForTokenClassification`` except that, after the
    parent constructor runs, the encoder's ``rel_embeddings`` table is
    replaced by a new ``torch.nn.Embedding`` of shape
    ``(config.max_position_embeddings, config.hidden_size)``.

    Args:
        config: The DeBERTa-v2 configuration passed to the parent class.
            Presumably relative attention must be enabled in it so that the
            encoder defines ``rel_embeddings`` -- TODO confirm.

    Raises:
        ValueError: If ``self.deberta.encoder`` has no ``rel_embeddings``
            attribute to replace.
    """

    def __init__(self, config):
        super().__init__(config)
        encoder = self.deberta.encoder
        if not hasattr(encoder, "rel_embeddings"):
            raise ValueError(
                "Expected DebertaV2 encoder to have rel_embeddings attribute"
            )

        # NOTE(review): the attention layers appear to size their relative
        # span from the config independently of this table; confirm that
        # max_position_embeddings is large enough to cover that span.
        rel_embeddings = torch.nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
        )
        # The parent constructor has already initialised the weights it
        # created, so a module attached afterwards would keep PyTorch's
        # default N(0, 1) init. Initialise it with the model's own scheme
        # instead, so that it matches every other weight in the model.
        self._init_weights(rel_embeddings)
        encoder.rel_embeddings = rel_embeddings


class SnuDeIDConfig(DebertaV2Config):
    """DeBERTa-v2 configuration tagged with the ``"snu-deid"`` model type.

    The Transformers auto classes resolve models by configuration class.
    A dedicated subclass is therefore what lets checkpoints whose
    ``config.json`` declares ``model_type: "snu-deid"`` resolve to
    :class:`SnuDeID`.
    """

    model_type = "snu-deid"


# The token-classification auto-mapping is keyed on configuration classes.
# Auto-model lookups use ``type(config)``, so a plain string key would be
# stored but never matched. Register the model type string with AutoConfig
# first, then map the config class to the model class.
AutoConfig.register("snu-deid", SnuDeIDConfig)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.register(SnuDeIDConfig, SnuDeID)