Commit
·
211acd5
1
Parent(s):
9e65ae3
algorithm improved
Browse files
ud.py
CHANGED
|
@@ -12,10 +12,13 @@ class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
|
|
| 12 |
x=self.model.config.label2id
|
| 13 |
y=[k for k in x if k.find("|")<0 and not k.startswith("I-")]
|
| 14 |
self.transition=numpy.full((len(x),len(x)),-numpy.inf)
|
|
|
|
| 15 |
for k,v in x.items():
|
| 16 |
if k.find("|")<0:
|
| 17 |
for j in ["I-"+k[2:]] if k.startswith("B-") else [k]+y if k.startswith("I-") else y:
|
| 18 |
self.transition[v,x[j]]=0
|
|
|
|
|
|
|
| 19 |
def check_model_type(self,supported_models):
|
| 20 |
pass
|
| 21 |
def postprocess(self,model_outputs,**kwargs):
|
|
@@ -24,14 +27,18 @@ class BellmanFordTokenClassificationPipeline(TokenClassificationPipeline):
|
|
| 24 |
return self.bellman_ford_token_classification(model_outputs,**kwargs)
|
| 25 |
def bellman_ford_token_classification(self,model_outputs,**kwargs):
|
| 26 |
m=model_outputs["logits"][0].numpy()
|
|
|
|
| 27 |
e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
|
| 28 |
z=e/e.sum(axis=-1,keepdims=True)
|
| 29 |
for i in range(m.shape[0]-1,0,-1):
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
| 31 |
k=[numpy.argmax(m[0]+self.transition[0])]
|
| 32 |
for i in range(1,m.shape[0]):
|
| 33 |
k.append(numpy.argmax(m[i]+self.transition[k[-1]]))
|
| 34 |
-
w=[{"entity":self.model.config.id2label[j],"start":s,"end":e,"score":z[i,j]} for i,((s,e),j) in enumerate(zip(
|
| 35 |
if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
|
| 36 |
for i,t in reversed(list(enumerate(w))):
|
| 37 |
p=t.pop("entity")
|
|
|
|
| 12 |
x=self.model.config.label2id
|
| 13 |
y=[k for k in x if k.find("|")<0 and not k.startswith("I-")]
|
| 14 |
self.transition=numpy.full((len(x),len(x)),-numpy.inf)
|
| 15 |
+
self.sympos=numpy.full((len(x)),-numpy.inf)
|
| 16 |
for k,v in x.items():
|
| 17 |
if k.find("|")<0:
|
| 18 |
for j in ["I-"+k[2:]] if k.startswith("B-") else [k]+y if k.startswith("I-") else y:
|
| 19 |
self.transition[v,x[j]]=0
|
| 20 |
+
if k.startswith("SYM"):
|
| 21 |
+
self.sympos[v]=0
|
| 22 |
def check_model_type(self,supported_models):
|
| 23 |
pass
|
| 24 |
def postprocess(self,model_outputs,**kwargs):
|
|
|
|
| 27 |
return self.bellman_ford_token_classification(model_outputs,**kwargs)
|
| 28 |
def bellman_ford_token_classification(self,model_outputs,**kwargs):
|
| 29 |
m=model_outputs["logits"][0].numpy()
|
| 30 |
+
v=model_outputs["offset_mapping"][0].tolist()
|
| 31 |
e=numpy.exp(m-numpy.max(m,axis=-1,keepdims=True))
|
| 32 |
z=e/e.sum(axis=-1,keepdims=True)
|
| 33 |
for i in range(m.shape[0]-1,0,-1):
|
| 34 |
+
if v[i-1][0]<v[i-1][1]:
|
| 35 |
+
m[i-1]+=numpy.max(m[i]+self.transition,axis=1)
|
| 36 |
+
else:
|
| 37 |
+
m[i-1]+=self.sympos
|
| 38 |
k=[numpy.argmax(m[0]+self.transition[0])]
|
| 39 |
for i in range(1,m.shape[0]):
|
| 40 |
k.append(numpy.argmax(m[i]+self.transition[k[-1]]))
|
| 41 |
+
w=[{"entity":self.model.config.id2label[j],"start":s,"end":e,"score":z[i,j]} for i,((s,e),j) in enumerate(zip(v,k)) if s<e]
|
| 42 |
if "aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none":
|
| 43 |
for i,t in reversed(list(enumerate(w))):
|
| 44 |
p=t.pop("entity")
|