Gyuseong commited on
Commit
c4255f6
·
verified ·
1 Parent(s): 9923956

Upload snudeid.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. snudeid.py +19 -0
snudeid.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from transformers import DebertaV2ForTokenClassification
4
+
5
+ from transformers.models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
6
+
7
+ class SnuDeID(DebertaV2ForTokenClassification):
8
+ def __init__(self, config):
9
+ super().__init__(config)
10
+ if hasattr(self.deberta.encoder, "rel_embeddings"):
11
+ self.deberta.encoder.rel_embeddings = torch.nn.Embedding(
12
+ config.max_position_embeddings,
13
+ config.hidden_size
14
+ )
15
+ else:
16
+ raise ValueError("Expected DebertaV2 encoder to have rel_embeddings attribute")
17
+
18
+
19
+ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.register("snu-deid", SnuDeID)