SNU_Thunder-DeID-340M / snudeid.py
Gyuseong's picture
Upload snudeid.py with huggingface_hub
c4255f6 verified
raw
history blame contribute delete
679 Bytes
import torch
from transformers import DebertaV2ForTokenClassification
from transformers.models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
class SnuDeID(DebertaV2ForTokenClassification):
    """DeBERTa-v2 token classifier with a replaced relative-embedding table.

    After the parent constructor has built the model, the encoder's
    ``rel_embeddings`` module is swapped for a new ``torch.nn.Embedding``
    with ``config.max_position_embeddings`` rows and ``config.hidden_size``
    columns.
    """

    def __init__(self, config):
        """Build the base model, then replace ``encoder.rel_embeddings``.

        Args:
            config: Model configuration; ``max_position_embeddings`` and
                ``hidden_size`` are read from it to size the new table.

        Raises:
            ValueError: If the constructed encoder has no
                ``rel_embeddings`` attribute.
        """
        super().__init__(config)

        encoder = self.deberta.encoder

        # Fail fast: the replacement below only makes sense when the parent
        # class actually created a relative-embedding table.
        if not hasattr(encoder, "rel_embeddings"):
            raise ValueError("Expected DebertaV2 encoder to have rel_embeddings attribute")

        # NOTE(review): presumably sized to match the table stored in the
        # released checkpoint — confirm against the model weights.
        encoder.rel_embeddings = torch.nn.Embedding(
            num_embeddings=config.max_position_embeddings,
            embedding_dim=config.hidden_size,
        )
# Register SnuDeID in the token-classification auto-model mapping under the
# key "snu-deid". This runs as a side effect of importing this module.
# NOTE(review): the key passed here is a string; confirm that this mapping's
# lookups resolve a string key as intended (rather than a config class).
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.register("snu-deid", SnuDeID)