# Snapshot metadata: commit c4255f6, file size 679 bytes.
import torch
from transformers import DebertaV2ForTokenClassification
from transformers.models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
class SnuDeID(DebertaV2ForTokenClassification):
    """DeBERTa-v2 token classifier with a resized relative-position embedding table.

    After the stock ``DebertaV2ForTokenClassification`` is built, the encoder's
    ``rel_embeddings`` module is swapped for a fresh ``torch.nn.Embedding`` of
    shape ``(config.max_position_embeddings, config.hidden_size)``.

    NOTE(review): presumably this matches the table shape stored in the
    ``snu-deid`` checkpoint rather than the size the stock encoder derives
    from its relative-position settings -- confirm against the saved weights.
    """

    def __init__(self, config):
        """Build the base model, then replace ``deberta.encoder.rel_embeddings``.

        Args:
            config: A DebertaV2-style config. ``max_position_embeddings`` and
                ``hidden_size`` are read here to size the new embedding table.

        Raises:
            ValueError: If the encoder has no ``rel_embeddings`` attribute.
                (In stock transformers the encoder appears to create it only
                when relative attention is enabled -- verify for the installed
                version.)
        """
        super().__init__(config)

        encoder = self.deberta.encoder
        if not hasattr(encoder, "rel_embeddings"):
            raise ValueError("Expected DebertaV2 encoder to have rel_embeddings attribute")

        # Allocate the replacement with the same device/dtype as the table it
        # replaces, so the swap respects how the model was constructed
        # (e.g. a non-default dtype or a meta-device/empty-weights context).
        reference_weight = encoder.rel_embeddings.weight
        rel_embeddings = torch.nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            device=reference_weight.device,
            dtype=reference_weight.dtype,
        )
        # The base constructor's weight initialization ran before this module
        # existed. Without this call the table would keep PyTorch's default
        # N(0, 1) init instead of the model's own initializer. Weights loaded
        # from a checkpoint later overwrite this as usual.
        self._init_weights(rel_embeddings)
        encoder.rel_embeddings = rel_embeddings
# Register SnuDeID in the token-classification auto mapping under "snu-deid".
# NOTE(review): in transformers, ``_LazyAutoMapping.register`` appears to key
# entries by *config class*, and the Auto* factories look models up by
# ``type(config)``. A plain string key like "snu-deid" is therefore stored but
# presumably never consulted by ``AutoModelForTokenClassification`` -- confirm
# against the installed version. The documented pattern is a config subclass
# with ``model_type = "snu-deid"`` registered via ``AutoConfig.register``, plus
# ``AutoModelForTokenClassification.register(<config class>, SnuDeID)``.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.register("snu-deid", SnuDeID)