|
|
|
|
|
|
|
|
|
|
|
import re, sys, unicodedata |
|
|
import codecs |
|
|
|
|
|
# When true, normalize() strips <...> tags from tokens before scoring.
# The command-line option --rt= overrides it.
remove_tag = True

# Characters characterize() treats as whitespace (in addition to Unicode
# categories Zs and Cn).
spacelist = [' ', '\t', '\r', '\n']

# Punctuation characterize() drops in character mode (--char=1).
# Every entry appears once. The fullwidth forms are written as escapes because
# editors/tools that apply NFKC normalization silently turn them into their
# ASCII look-alikes. Without the fullwidth forms, common Chinese punctuation
# would be scored as tokens.
puncts = [
    '!', ',', '?', ';', ':', '(', ')', '[', ']', '{', '}',
    '\uff01',  # FULLWIDTH EXCLAMATION MARK
    '\uff0c',  # FULLWIDTH COMMA
    '\uff1b',  # FULLWIDTH SEMICOLON
    '\uff1f',  # FULLWIDTH QUESTION MARK
    '\uff1a',  # FULLWIDTH COLON
    '\uff08',  # FULLWIDTH LEFT PARENTHESIS
    '\uff09',  # FULLWIDTH RIGHT PARENTHESIS
    '、', '。', '「', '」', '︰', '『', '』',
    '《', '》', '【', '】', '〔', '〕',
    '⟨', '⟩',
]
|
|
|
|
|
|
|
|
def characterize(string):
    """Split *string* into tokens for character-level scoring.

    * Characters in ``puncts`` are skipped.
    * Whitespace is skipped: ``spacelist`` plus Unicode categories Zs and Cn.
    * Every character of category 'Lo' (other letter, e.g. CJK ideographs)
      becomes a token on its own.
    * Any other character starts a token that extends over the following
      ASCII characters. The token stops before a whitespace character or a
      non-ASCII character. A token that starts with '<' also stops at the
      next '>', which is kept, so tags such as ``<unk>`` stay in one piece.

    Returns the list of tokens in input order.
    """
    tokens = []
    pos, end = 0, len(string)
    while pos < end:
        ch = string[pos]
        if ch in puncts:
            pos += 1
            continue
        category = unicodedata.category(ch)
        if category in ('Zs', 'Cn') or ch in spacelist:
            pos += 1
            continue
        if category == 'Lo':
            tokens.append(ch)
            pos += 1
            continue
        # ASCII-ish run: '<' opens a tag terminated by '>', otherwise a word
        # terminated by a space.
        stop = '>' if ch == '<' else ' '
        nxt = pos + 1
        while nxt < end:
            c = string[nxt]
            if ord(c) >= 128 or c in spacelist or c == stop:
                break
            nxt += 1
        if nxt < end and string[nxt] == '>':
            nxt += 1  # keep the closing bracket of a tag
        tokens.append(string[pos:nxt])
        pos = nxt
    return tokens
|
|
|
|
|
|
|
|
def stripoff_tags(x):
    """Return *x* with every ``<...>`` span removed.

    A '<' with no matching '>' removes everything from it to the end of the
    string. Empty or ``None`` input yields ''.
    """
    if not x:
        return ''
    # '<', then anything except '>', then an optional closing '>'.
    return re.sub(r'<[^>]*>?', '', x)
|
|
|
|
|
|
|
|
def normalize(sentence, ignore_words, cs, split=None):
    """Turn a list of raw tokens into the tokens that are actually scored.

    Processing per token:

    1. Upper-case it unless *cs* (case-sensitive) is true.
    2. Drop it if it is in *ignore_words*. This check runs before tag and
       punctuation removal.
    3. Strip ``<...>`` tags when the module-level ``remove_tag`` flag is set.
    4. Remove ASCII punctuation and a few extra quote/bracket symbols.
    5. Drop it if it is now empty or contains any decimal digit.
    6. Replace it with ``split[token]`` (a list of sub-tokens) when *split* has
       an entry for it.

    Returns a new list; *sentence* itself is not modified.
    """
    kept = []
    for raw in sentence:
        token = raw if cs else raw.upper()
        if token in ignore_words:
            continue
        if remove_tag:
            token = stripoff_tags(token)
        token = re.sub(r'[.,!?;:()\[\]{}<>""„""«»‹›\/\\|@#$%^&*_=+~`-]', '', token)
        if not token or re.search(r'\d', token):
            continue
        if split and token in split:
            kept.extend(split[token])
        else:
            kept.append(token)
    return kept
|
|
|
|
|
|
|
|
class Calculator:
    """Align reference/hypothesis token lists and accumulate error statistics.

    Attributes:
        data: token -> counters dict with keys 'all', 'cor', 'sub', 'ins',
            'del'. The counters accumulate over every call to ``calculate``.
        space: dynamic-programming matrix of ``{'dist', 'error'}`` cells. It
            is kept between calls and only grown, so it can be reused.
        cost: edit costs used by the alignment ('cor', 'sub', 'del', 'ins').
    """

    def __init__(self):
        self.data = {}
        self.space = []
        self.cost = {'cor': 0, 'sub': 1, 'del': 1, 'ins': 1}

    @staticmethod
    def _new_counts():
        """Return a fresh, zeroed counters dict."""
        return {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}

    def _prepare_space(self, rows, cols):
        """Grow ``self.space`` to at least rows x cols and initialise it.

        Row 0 is filled with insertions and column 0 with deletions; (0, 0)
        is the 'non' cell that terminates the backtrace.
        """
        while len(self.space) < rows:
            self.space.append([])
        for row in self.space:
            for cell in row:
                cell['dist'] = 0
                cell['error'] = 'non'
            while len(row) < cols:
                row.append({'dist': 0, 'error': 'non'})
        for i in range(rows):
            self.space[i][0]['dist'] = i
            self.space[i][0]['error'] = 'del'
        for j in range(cols):
            self.space[0][j]['dist'] = j
            self.space[0][j]['error'] = 'ins'
        self.space[0][0]['error'] = 'non'

    def calculate(self, lab, rec):
        """Align *lab* (reference) against *rec* (hypothesis) and record stats.

        Args:
            lab: list of reference tokens. It is not modified.
            rec: list of hypothesis tokens. It is not modified.

        Returns:
            dict with 'lab' and 'rec' (aligned token lists of equal length,
            '' marking a gap) and the counts 'all', 'cor', 'sub', 'ins', 'del'
            for this pair. The same counts are also added per token to
            ``self.data``.
        """
        # Work on padded copies; index 0 represents the empty prefix. The
        # previous version inserted '' into the caller's lists, which
        # corrupted them for any later use (e.g. a repeated utterance id).
        lab = [''] + list(lab)
        rec = [''] + list(rec)
        self._prepare_space(len(lab), len(rec))

        for token in lab + rec:
            if token and token not in self.data:
                self.data[token] = self._new_counts()

        space = self.space
        cost = self.cost
        for i in range(1, len(lab)):
            for j in range(1, len(rec)):
                # Ties are resolved in the order del, ins, cor/sub (strict <).
                best_dist = space[i - 1][j]['dist'] + cost['del']
                best_error = 'del'
                dist = space[i][j - 1]['dist'] + cost['ins']
                if dist < best_dist:
                    best_dist, best_error = dist, 'ins'
                error = 'cor' if lab[i] == rec[j] else 'sub'
                dist = space[i - 1][j - 1]['dist'] + cost[error]
                if dist < best_dist:
                    best_dist, best_error = dist, error
                space[i][j]['dist'] = best_dist
                space[i][j]['error'] = best_error

        result = {
            'lab': [],
            'rec': [],
            'all': 0,
            'cor': 0,
            'sub': 0,
            'ins': 0,
            'del': 0,
        }
        # Backtrace from the bottom-right corner. Tokens are collected in
        # reverse and flipped once at the end (instead of O(n^2) insert(0)).
        i = len(lab) - 1
        j = len(rec) - 1
        while True:
            error = space[i][j]['error']
            if error in ('cor', 'sub'):
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] += 1
                    self.data[lab[i]][error] += 1
                    result['all'] += 1
                    result[error] += 1
                result['lab'].append(lab[i])
                result['rec'].append(rec[j])
                i -= 1
                j -= 1
            elif error == 'del':
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] += 1
                    self.data[lab[i]]['del'] += 1
                    result['all'] += 1
                    result['del'] += 1
                result['lab'].append(lab[i])
                result['rec'].append('')
                i -= 1
            elif error == 'ins':
                if len(rec[j]) > 0:
                    self.data[rec[j]]['ins'] += 1
                    result['ins'] += 1
                result['lab'].append('')
                result['rec'].append(rec[j])
                j -= 1
            elif error == 'non':
                break
            else:
                print(
                    'this should not happen , i = {i} , j = {j} , error = {error}'
                    .format(i=i, j=j, error=error))
                # The indices would never move again; stop instead of
                # looping forever.
                break
        result['lab'].reverse()
        result['rec'].reverse()
        return result

    def overall(self):
        """Return the counters summed over every token seen so far."""
        return self.cluster(self.data)

    def cluster(self, data):
        """Return the counters summed over the tokens in *data*.

        Tokens that have never been seen are skipped.
        """
        result = self._new_counts()
        for token in data:
            counts = self.data.get(token)
            if counts is None:
                continue
            for key in result:
                result[key] += counts[key]
        return result

    def keys(self):
        """Return every token seen so far, in first-seen order."""
        return list(self.data)
|
|
|
|
|
|
|
|
def width(string):
    """Return the terminal display width of *string*.

    Characters with East Asian Width property Wide ('W'), Fullwidth ('F') or
    Ambiguous ('A') count as two columns; every other character counts as one.
    """
    total = 0
    for ch in string:
        total += 2 if unicodedata.east_asian_width(ch) in ('A', 'F', 'W') else 1
    return total
|
|
|
|
|
|
|
|
def default_cluster(word):
    """Classify *word* by the Unicode names of its characters.

    Each character is mapped to one of 'Number' (DIGIT ...), 'Mandarin'
    (CJK unified/compatibility ideographs), 'English' (Latin letters) or
    'Japanese' (hiragana). A small set of ASCII-style symbols (&, ', @, DEGREE
    CELSIUS, =, ., -, _, #, +, ;) is ignored. Any other character makes the
    whole word 'Other'.

    Returns the shared group when all remaining characters agree. Returns
    'Other' when the groups are mixed or no characters remain.
    """
    groups = (
        (('DIGIT',), 'Number'),
        (('CJK UNIFIED IDEOGRAPH', 'CJK COMPATIBILITY IDEOGRAPH'), 'Mandarin'),
        (('LATIN CAPITAL LETTER', 'LATIN SMALL LETTER'), 'English'),
        (('HIRAGANA LETTER',), 'Japanese'),
    )
    ignored = (
        'AMPERSAND', 'APOSTROPHE', 'COMMERCIAL AT', 'DEGREE CELSIUS',
        'EQUALS SIGN', 'FULL STOP', 'HYPHEN-MINUS', 'LOW LINE',
        'NUMBER SIGN', 'PLUS SIGN', 'SEMICOLON',
    )
    labels = []
    for ch in word:
        try:
            name = unicodedata.name(ch)
        except ValueError:
            name = 'UNK'  # unnamed code point -> never matches a group
        for prefixes, label in groups:
            if name.startswith(prefixes):
                labels.append(label)
                break
        else:
            if not name.startswith(ignored):
                return 'Other'
    if labels and all(lbl == labels[0] for lbl in labels):
        return labels[0]
    return 'Other'
|
|
|
|
|
|
|
|
def usage():
    """Print a two-line summary of what the script does and how to call it."""
    help_lines = (
        "compute-wer.py : compute word error rate (WER) and align recognition results and references.",
        " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer",
    )
    for text in help_lines:
        print(text)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line driver: parse the options, read the hypothesis and
    # reference files, align every utterance present in both and print a
    # per-utterance and overall report.
    if len(sys.argv) == 1:
        usage()
        sys.exit(0)

    calculator = Calculator()
    cluster_file = ''
    ignore_words = set()
    tochar = False
    verbose = 1
    padding_symbol = ' '
    case_sensitive = False
    max_words_per_line = sys.maxsize
    split = None

    def _parse_bool(value):
        """Interpret a lower-cased flag value: '0' and 'false' are False.

        The previous test ``(b == 'true') or (b != '0')`` reduced to
        ``b != '0'`` and so treated 'false' as True.
        """
        return value not in ('0', 'false')

    def _wer(stats):
        """Return the error rate of *stats* in percent (0.0 with no reference tokens)."""
        if stats['all'] == 0:
            return 0.0
        errors = stats['ins'] + stats['sub'] + stats['del']
        return float(errors) * 100.0 / stats['all']

    def _counts(stats):
        """Format the N/C/S/D/I counters of *stats* for the report."""
        return 'N=%d C=%d S=%d D=%d I=%d' % (
            stats['all'], stats['cor'], stats['sub'], stats['del'], stats['ins'])

    def _print_alignment(fid, result):
        """Print the aligned lab/rec rows, padding tokens to equal display width."""
        lab_pad, rec_pad = [], []
        for lab_tok, rec_tok in zip(result['lab'], result['rec']):
            len_lab, len_rec = width(lab_tok), width(rec_tok)
            longest = max(len_lab, len_rec)
            lab_pad.append(longest - len_lab)
            rec_pad.append(longest - len_rec)
        upper_lab = len(result['lab'])
        upper_rec = len(result['rec'])
        lab1, rec1 = 0, 0
        while lab1 < upper_lab or rec1 < upper_rec:
            if verbose > 1:
                print('lab(%s):' % fid, end=' ')
            else:
                print('lab:', end=' ')
            lab2 = min(upper_lab, lab1 + max_words_per_line)
            for idx in range(lab1, lab2):
                print(result['lab'][idx] + padding_symbol * lab_pad[idx], end=' ')
            print()
            if verbose > 1:
                print('rec(%s):' % fid, end=' ')
            else:
                print('rec:', end=' ')
            rec2 = min(upper_rec, rec1 + max_words_per_line)
            for idx in range(rec1, rec2):
                print(result['rec'][idx] + padding_symbol * rec_pad[idx], end=' ')
            print('\n', end='\n')
            lab1 = lab2
            rec1 = rec2

    # Options must precede the two positional arguments (ref, hyp). Each one
    # is consumed from sys.argv until only those two remain.
    while len(sys.argv) > 3:
        arg = sys.argv[1]
        del sys.argv[1]
        if arg.startswith('--maxw='):
            max_words_per_line = int(arg[len('--maxw='):])
        elif arg.startswith('--rt='):
            remove_tag = _parse_bool(arg[len('--rt='):].lower())
        elif arg.startswith('--cs='):
            case_sensitive = _parse_bool(arg[len('--cs='):].lower())
        elif arg.startswith('--cluster='):
            cluster_file = arg[len('--cluster='):]
        elif arg.startswith('--splitfile='):
            # Each line: "<word> <part1> <part2> ..."; word is replaced by parts.
            split = {}
            with codecs.open(arg[len('--splitfile='):], 'r', 'utf-8') as fh:
                for line in fh:
                    words = line.strip().split()
                    if len(words) >= 2:
                        split[words[0]] = words[1:]
        elif arg.startswith('--ig='):
            with codecs.open(arg[len('--ig='):], 'r', 'utf-8') as fh:
                for line in fh:
                    line = line.strip()
                    if line:
                        ignore_words.add(line)
        elif arg.startswith('--char='):
            tochar = _parse_bool(arg[len('--char='):].lower())
        elif arg.startswith('--v='):
            value = arg[len('--v='):].lower()
            try:
                verbose = int(value)
            except ValueError:
                verbose = 1 if _parse_bool(value) else 0
        elif arg.startswith('--padding-symbol='):
            value = arg[len('--padding-symbol='):].lower()
            if value == 'space':
                padding_symbol = ' '
            elif value == 'underline':
                padding_symbol = '_'
        else:
            # Still skipped as before, but reported: a typo in an option name
            # would otherwise go unnoticed.
            print('warning: ignoring unrecognized argument: %s' % arg,
                  file=sys.stderr)

    if len(sys.argv) < 3:
        usage()
        sys.exit(1)

    # A non-positive --maxw would make the wrapping loop in _print_alignment
    # spin forever; treat it as "no limit".
    if max_words_per_line <= 0:
        max_words_per_line = sys.maxsize

    if not case_sensitive:
        ignore_words = {w.upper() for w in ignore_words}

    default_clusters = {}
    default_words = {}

    ref_file = sys.argv[1]
    hyp_file = sys.argv[2]
    rec_set = {}
    if split and not case_sensitive:
        split = {w.upper(): [part.upper() for part in parts]
                 for w, parts in split.items()}

    # Hypotheses: "<utt-id> token token ..." per line.
    with codecs.open(hyp_file, 'r', 'utf-8') as fh:
        for line in fh:
            array = characterize(line) if tochar else line.strip().split()
            if not array:
                continue
            fid = array[0]
            rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
                                     split)

    # References: only utterances that also have a hypothesis are scored.
    with open(ref_file, 'r', encoding='utf-8') as ref_fh:
        for line in ref_fh:
            array = characterize(line) if tochar else line.rstrip('\n').split()
            if not array:
                continue
            fid = array[0]
            if fid not in rec_set:
                continue
            lab = normalize(array[1:], ignore_words, case_sensitive, split)
            rec = rec_set[fid]
            if verbose:
                print('\nutt: %s' % fid)
            for word in rec + lab:
                if word not in default_words:
                    cluster_name = default_cluster(word)
                    default_clusters.setdefault(cluster_name, {})[word] = 1
                    default_words[word] = cluster_name
            result = calculator.calculate(lab, rec)
            if verbose:
                print('WER: %4.2f %%' % _wer(result), end=' ')
                print(_counts(result))
                _print_alignment(fid, result)

    if verbose:
        print(
            '==========================================================================='
        )
        print()

    result = calculator.overall()
    print('Overall -> %4.2f %%' % _wer(result), end=' ')
    print(_counts(result))
    if not verbose:
        print()

    if verbose:
        for cluster_id in default_clusters:
            result = calculator.cluster(list(default_clusters[cluster_id]))
            print('%s -> %4.2f %%' % (cluster_id, _wer(result)), end=' ')
            print(_counts(result))
        if len(cluster_file) > 0:
            # Cluster file format: "<name> tok tok ... </name>", possibly
            # spanning several lines. WER is reported per named cluster.
            cluster_id = ''
            cluster = []
            with open(cluster_file, 'r', encoding='utf-8') as cluster_fh:
                for line in cluster_fh:
                    # Lines are already str; the former line.decode('utf-8')
                    # raised AttributeError on Python 3.
                    for token in line.rstrip('\n').split():
                        if (token[0:2] == '</' and token[-1] == '>'
                                and token.lstrip('</').rstrip('>') == cluster_id):
                            result = calculator.cluster(cluster)
                            print('%s -> %4.2f %%' % (cluster_id, _wer(result)),
                                  end=' ')
                            print(_counts(result))
                            cluster_id = ''
                            cluster = []
                        elif (token[0] == '<' and token[-1] == '>'
                              and cluster_id == ''):
                            cluster_id = token.lstrip('<').rstrip('>')
                            cluster = []
                        else:
                            cluster.append(token)
        print()
        print(
            '==========================================================================='
        )
|
|
|