1 Star 0 Fork 0

嗜雪的蚂蚁/asr_timestamp_insert_text_grid

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
label_alignment_1.py 24.87 KB
一键复制 编辑 原始数据 按行查看 历史
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# -- coding: utf-8 --
'''
# @Time : 2023/10/16 21:31
# @Author: from https://github.com/wenet-e2e/wenet/blob/main/tools/compute-wer.py
# @Modified By Shiyu He
# @University : Xinjiang University
'''
import os, re, sys, unicodedata
import codecs
import json
import pdb
import copy
from pypinyin import pinyin, Style
# 使用 phrase-pinyin-data 项目中 cc_cedict.txt 文件中的拼音数据优化结果
from pypinyin_dict.phrase_pinyin_data import cc_cedict
cc_cedict.load()
# Global alignment result, keyed by utterance id; filled by Calculator.calculate()
# and dumped to alignment.json by the __main__ driver.
result = {}
# When True, <tag>-style markers (e.g. <unk>, <noise>) are stripped in normalize().
remove_tag = True
# Characters treated as whitespace separators by characterize().
spacelist = [' ', '\t', '\r', '\n']
# Punctuation (ASCII and full-width) dropped by characterize().
puncts = [
    '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
    '《', '》', '“'
]
def Check_Unknown_Rec(word):
    """Return *word* unchanged if it contains at least one CJK ideograph
    (U+4E00..U+9FFF); otherwise return the placeholder string "?"."""
    contains_cjk = any('\u4e00' <= ch <= '\u9fff' for ch in word)
    return word if contains_cjk else str("?")
def GetAllPinyin(chiese_char):
    """Pinyin (numeric-tone TONE3 style) of a single Chinese character.

    Returns " " for empty input, the pinyin string for a one-character
    input, and "" for longer input (the internal length error is raised
    and swallowed on purpose, keeping the call best-effort).
    """
    pinyin_text = ''
    try:
        if len(chiese_char) == 0:
            pinyin_text = " "
        elif len(chiese_char) == 1:
            pinyin_text = pinyin(str(chiese_char), Style.TONE3, heteronym=False,
                                 errors='default', strict=True)[0][0]
        else:
            raise ValueError("长度超限,请传入单个汉字!当前传入字符为:{}".format(chiese_char))
    except ValueError:
        # deliberately best-effort: multi-character input yields ''
        pass
    return pinyin_text
# Convert a whole sentence at once: phrase context lets pypinyin pick the
# right reading for polyphonic characters.
def GetSentencePinyin(chiese_sentence):
    """Return a list with the TONE3 pinyin of every character/word unit of
    *chiese_sentence* (an iterable of strings or a string)."""
    # from pypinyin_dict.pinyin_data import kMadarin
    # kMadarin.load()
    joined = "".join(chiese_sentence)
    per_unit = pinyin(str(joined), Style.TONE3, heteronym=False,
                      errors='default', strict=True)
    # pypinyin returns one list per unit; keep only the first candidate
    return [unit[0] for unit in per_unit]
def characterize(string):
    """Tokenize *string* for character-level scoring.

    Each 'letter other' character (e.g. a CJK ideograph) becomes its own
    token; ASCII runs stay together; a '<'-opened marker like <unk> is kept
    as a single token including the closing '>'.  Punctuation, whitespace
    and unassigned code points are dropped.
    """
    tokens = []
    pos = 0
    total = len(string)
    while pos < total:
        ch = string[pos]
        # drop punctuation outright
        if ch in puncts:
            pos += 1
            continue
        cat = unicodedata.category(ch)
        # https://unicodebook.readthedocs.io/unicode.html#unicode-categories
        # drop spaces and unassigned code points
        if cat == 'Zs' or cat == 'Cn' or ch in spacelist:
            pos += 1
            continue
        if cat == 'Lo':
            # letter-other is a token by itself
            tokens.append(ch)
            pos += 1
            continue
        # anything else starts an ASCII run; input like <unk><noise> must
        # split into two tokens, so '<' opens a run terminated by '>'
        terminator = '>' if ch == '<' else ' '
        end = pos + 1
        while end < total:
            nxt = string[end]
            if ord(nxt) >= 128 or (nxt in spacelist) or (nxt == terminator):
                break
            end += 1
        if end < total and string[end] == '>':
            end += 1  # keep the closing '>' inside the tag token
        tokens.append(string[pos:end])
        pos = end
    return tokens
def stripoff_tags(x):
    """Return *x* with every <...> span removed; '' for falsy input.

    An unterminated '<' swallows the rest of the string, matching the
    original scan-forward behaviour.
    """
    if not x:
        return ''
    kept = []
    pos, total = 0, len(x)
    while pos < total:
        if x[pos] == '<':
            # skip forward to the matching '>' (or the end of the string)
            while pos < total and x[pos] != '>':
                pos += 1
            pos += 1  # step past the '>' itself
        else:
            kept.append(x[pos])
            pos += 1
    return ''.join(kept)
def normalize(sentence, ignore_words, cs, split=None):
    """Normalize a token sequence for scoring.

    *sentence* and *ignore_words* are both unicode.  Uppercases tokens when
    *cs* (case-sensitive) is false, drops tokens found in *ignore_words*,
    strips <tag> markers when the module-level ``remove_tag`` flag is set,
    and expands tokens through the optional *split* mapping.
    """
    normalized = []
    for token in sentence:
        word = token if cs else token.upper()
        if word in ignore_words:
            continue
        if remove_tag:
            word = stripoff_tags(word)
        if not word:
            continue
        if split and word in split:
            normalized += split[word]
        else:
            normalized.append(word)
    return normalized
class Calculator:
    """Token-level edit-distance aligner with accumulated WER statistics.

    Per-token hit/substitution/insertion/deletion counts are accumulated in
    ``self.data`` across calls to :meth:`calculate`, while each utterance's
    full alignment (including pinyin for both sides) is recorded in the
    module-level ``result`` dict.
    """

    def __init__(self, ):
        # per-token statistics: token -> {'all', 'cor', 'sub', 'ins', 'del'}
        self.data = {}
        # DP matrix for edit distance; grown lazily and reused between calls
        self.space = []
        # unit cost of each edit operation
        self.cost = {}
        self.cost['cor'] = 0
        self.cost['sub'] = 1
        self.cost['del'] = 1
        self.cost['ins'] = 1

    def calculate(self, utt, lab, rec):
        """Align reference tokens *lab* against hypothesis tokens *rec*.

        *utt* is the utterance id used as the key into the global ``result``
        dict.  NOTE: *lab* and *rec* are mutated in place (a sentinel '' is
        inserted at index 0).  Returns the global ``result`` dict.
        """
        # Initialization
        # Generate pinyin for the whole lab/rec sequence up front; sentence
        # context reduces wrong readings of polyphonic characters.
        lab_pinyin_list = GetSentencePinyin(lab)
        # print("lab_pinyin_list:", lab_pinyin_list)
        rec_pinyin_list = GetSentencePinyin(rec)
        # print("rec_pinyin_list:", rec_pinyin_list)
        # Sentinel '' so row/column 0 of the DP matrix is the empty prefix.
        lab.insert(0, '')
        rec.insert(0, '')
        # Grow (and reset) the reusable DP matrix to len(lab) x len(rec).
        while len(self.space) < len(lab):
            self.space.append([])
        for row in self.space:
            for element in row:
                element['dist'] = 0
                element['error'] = 'non'
            while len(row) < len(rec):
                row.append({'dist': 0, 'error': 'non'})
        # Base cases: first column = all deletions, first row = all insertions.
        for i in range(len(lab)):
            self.space[i][0]['dist'] = i
            self.space[i][0]['error'] = 'del'
        for j in range(len(rec)):
            self.space[0][j]['dist'] = j
            self.space[0][j]['error'] = 'ins'
        self.space[0][0]['error'] = 'non'
        # Ensure every token has a statistics entry.
        for token in lab:
            if token not in self.data and len(token) > 0:
                self.data[token] = {
                    'all': 0,
                    'cor': 0,
                    'sub': 0,
                    'ins': 0,
                    'del': 0
                }
        for token in rec:
            if token not in self.data and len(token) > 0:
                self.data[token] = {
                    'all': 0,
                    'cor': 0,
                    'sub': 0,
                    'ins': 0,
                    'del': 0
                }
        # Computing edit distance
        for i, lab_token in enumerate(lab):
            for j, rec_token in enumerate(rec):
                if i == 0 or j == 0:
                    continue
                min_dist = sys.maxsize
                min_error = 'none'
                dist = self.space[i - 1][j]['dist'] + self.cost['del']
                error = 'del'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                dist = self.space[i][j - 1]['dist'] + self.cost['ins']
                error = 'ins'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                if lab_token == rec_token:
                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
                    error = 'cor'
                else:
                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
                    error = 'sub'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                self.space[i][j]['dist'] = min_dist
                self.space[i][j]['error'] = min_error
        # Tracing back
        global result
        result_in = {
            'lab': [],
            'rec': [],
            'error_mark': [],
            'py_rec': [],
            'py_label': [],
            'all': 0,
            'cor': 0,
            'sub': 0,
            'ins': 0,
            'del': 0
        }
        result.update({utt: copy.deepcopy(result_in)})
        i = len(lab) - 1
        j = len(rec) - 1
        k = i if len(lab) > len(rec) else j
        # l/m index the pinyin lists, which carry no leading sentinel, so
        # they trail lab[i]/rec[j] by one position.
        l = len(lab_pinyin_list) - 1
        m = len(rec_pinyin_list) - 1
        while True:
            if self.space[i][j]['error'] == 'cor':  # correct
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
                    result[utt]['all'] = result[utt]['all'] + 1
                    result[utt]['cor'] = result[utt]['cor'] + 1
                result[utt]['lab'].insert(0, lab[i])
                # result[utt]['py_label'].insert(0, GetAllPinyin(lab[i]))
                result[utt]['py_label'].insert(0, lab_pinyin_list[l])
                checked_rec = Check_Unknown_Rec(rec[j])
                result[utt]['rec'].insert(0, checked_rec)
                # result[utt]['py_rec'].insert(0, GetAllPinyin(rec[j]))
                result[utt]['py_rec'].insert(0, rec_pinyin_list[m])
                result[utt]['error_mark'].insert(0, '对')
                i = i - 1
                j = j - 1
                k = k - 1
                l = l - 1
                m = m - 1
            elif self.space[i][j]['error'] == 'sub':  # substitution
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
                    result[utt]['all'] = result[utt]['all'] + 1
                    result[utt]['sub'] = result[utt]['sub'] + 1
                result[utt]['lab'].insert(0, lab[i])
                # result[utt]['py_label'].insert(0, GetAllPinyin(lab[i]))
                result[utt]['py_label'].insert(0, lab_pinyin_list[l])
                checked_rec = Check_Unknown_Rec(rec[j])
                result[utt]['rec'].insert(0, checked_rec)
                # result[utt]['py_rec'].insert(0, GetAllPinyin(rec[j]))
                rec_pinyin = rec_pinyin_list[m]
                print(f'当前的rec_pinyin:{rec_pinyin}')
                result[utt]['py_rec'].insert(0, rec_pinyin)
                result[utt]['error_mark'].insert(0, '换')
                i = i - 1
                j = j - 1
                k = k - 1
                l = l - 1
                m = m - 1
            elif self.space[i][j]['error'] == 'del':  # deletion
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
                    result[utt]['all'] = result[utt]['all'] + 1
                    result[utt]['del'] = result[utt]['del'] + 1
                result[utt]['lab'].insert(0, lab[i])
                # result[utt]['py_label'].insert(0, GetAllPinyin(lab[i]))
                result[utt]['py_label'].insert(0, lab_pinyin_list[l])
                # deleted token: hypothesis side is marked as silence
                result[utt]['rec'].insert(0, "sil")
                result[utt]['error_mark'].insert(0, '删')
                result[utt]['py_rec'].insert(0, "sil")
                i = i - 1
                k = k - 1
                l = l - 1
            elif self.space[i][j]['error'] == 'ins':  # insertion
                if len(rec[j]) > 0:
                    self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
                    result[utt]['ins'] = result[utt]['ins'] + 1
                # inserted token: reference side is marked as silence
                result[utt]['lab'].insert(0, "sil")
                result[utt]['py_label'].insert(0, "sil")
                checked_rec = Check_Unknown_Rec(rec[j])
                result[utt]['rec'].insert(0, checked_rec)
                result[utt]['error_mark'].insert(0, '插')
                # result[utt]['py_rec'].insert(0, GetAllPinyin(rec[j]))
                result[utt]['py_rec'].insert(0, rec_pinyin_list[m])
                j = j - 1
                k = k - 1
                m = m - 1
            elif self.space[i][j]['error'] == 'non':  # starting point
                break
            else:  # shouldn't reach here
                print(
                    'this should not happen , i = {i} , j = {j} , error = {error}'
                    .format(i=i, j=j, error=self.space[i][j]['error']))
        return result

    def overall(self):
        """Aggregate counts over every token seen so far."""
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        for token in self.data:
            result['all'] = result['all'] + self.data[token]['all']
            result['cor'] = result['cor'] + self.data[token]['cor']
            result['sub'] = result['sub'] + self.data[token]['sub']
            result['ins'] = result['ins'] + self.data[token]['ins']
            result['del'] = result['del'] + self.data[token]['del']
        return result

    def cluster(self, data):
        """Aggregate counts over only the tokens listed in *data*."""
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        for token in data:
            if token in self.data:
                result['all'] = result['all'] + self.data[token]['all']
                result['cor'] = result['cor'] + self.data[token]['cor']
                result['sub'] = result['sub'] + self.data[token]['sub']
                result['ins'] = result['ins'] + self.data[token]['ins']
                result['del'] = result['del'] + self.data[token]['del']
        return result

    def keys(self):
        """All tokens that currently have statistics."""
        return list(self.data.keys())
def width(string):
    """Display width of *string*: wide/full/ambiguous East-Asian glyphs
    count as 2 columns, everything else as 1."""
    total = 0
    for ch in string:
        total += 2 if unicodedata.east_asian_width(ch) in "AFW" else 1
    return total
def default_cluster(word):
    """Classify *word* as 'Number', 'Mandarin', 'English', 'Japanese' or
    'Other' based on the Unicode names of its characters.

    A small set of punctuation-like characters (& ' @ ℃ = . - _ # + ;) is
    ignored; a word mixing two classes, or containing any unrecognised
    character, is 'Other'.
    """
    ignored_prefixes = ('AMPERSAND', 'APOSTROPHE', 'COMMERCIAL AT',
                        'DEGREE CELSIUS', 'EQUALS SIGN', 'FULL STOP',
                        'HYPHEN-MINUS', 'LOW LINE', 'NUMBER SIGN',
                        'PLUS SIGN', 'SEMICOLON')
    # resolve every name first (raises ValueError on unnamed code points,
    # matching the original behaviour)
    names = [unicodedata.name(char) for char in word]
    classes = []
    for name in names:
        if name.startswith('DIGIT'):  # 1
            classes.append('Number')  # 'DIGIT'
        elif name.startswith(('CJK UNIFIED IDEOGRAPH',
                              'CJK COMPATIBILITY IDEOGRAPH')):
            # 明 / 郎
            classes.append('Mandarin')  # 'CJK IDEOGRAPH'
        elif name.startswith(('LATIN CAPITAL LETTER', 'LATIN SMALL LETTER')):
            # A / a
            classes.append('English')  # 'LATIN LETTER'
        elif name.startswith('HIRAGANA LETTER'):  # は こ め
            classes.append('Japanese')  # 'GANA LETTER'
        elif name.startswith(ignored_prefixes):
            # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
            continue
        else:
            return 'Other'
    if not classes:
        return 'Other'
    first = classes[0]
    if any(cls != first for cls in classes[1:]):
        return 'Other'
    return first
def usage():
    """Print the command-line help for this script to stdout."""
    for line in (
        "compute-wer.py : compute word error rate (WER) and align recognition results and references.",
        " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer",
    ):
        print(line)
if __name__ == '__main__':
    # Command line: [--switch=value ...] ref_file hyp_file
    if len(sys.argv) == 1:
        usage()
        sys.exit(0)
    calculator = Calculator()
    cluster_file = ''              # optional word-cluster definition file
    ignore_words = set()           # tokens excluded from scoring
    tochar = False                 # split lines into characters instead of words
    verbose = 1
    padding_symbol = ' '           # filler used when aligning printed columns
    case_sensitive = False
    max_words_per_line = sys.maxsize
    split = None                   # optional token -> [tokens] expansion map
    # Consume leading --xxx= switches until only the two file paths remain.
    while len(sys.argv) > 3:
        a = '--maxw='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):]
            del sys.argv[1]
            max_words_per_line = int(b)
            continue
        a = '--rt='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            # NOTE(review): this evaluates truthy for any value except '0'
            remove_tag = (b == 'true') or (b != '0')
            continue
        a = '--cs='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            case_sensitive = (b == 'true') or (b != '0')
            continue
        a = '--cluster='
        if sys.argv[1].startswith(a):
            cluster_file = sys.argv[1][len(a):]
            del sys.argv[1]
            continue
        a = '--splitfile='
        if sys.argv[1].startswith(a):
            split_file = sys.argv[1][len(a):]
            del sys.argv[1]
            split = dict()
            with codecs.open(split_file, 'r', 'utf-8') as fh:
                for line in fh:  # line in unicode
                    words = line.strip().split()
                    if len(words) >= 2:
                        # the first field expands to the remaining fields
                        split[words[0]] = words[1:]
            continue
        a = '--ig='
        if sys.argv[1].startswith(a):
            ignore_file = sys.argv[1][len(a):]
            del sys.argv[1]
            with codecs.open(ignore_file, 'r', 'utf-8') as fh:
                for line in fh:  # line in unicode
                    line = line.strip()
                    if len(line) > 0:
                        ignore_words.add(line)
            continue
        a = '--char='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            tochar = (b == 'true') or (b != '0')
            continue
        a = '--v='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            verbose = 0
            try:
                verbose = int(b)
            except ValueError:  # was a bare except; int(str) only raises ValueError
                if b == 'true' or b != '0':
                    verbose = 1
            continue
        a = '--padding-symbol='
        if sys.argv[1].startswith(a):
            b = sys.argv[1][len(a):].lower()
            del sys.argv[1]
            if b == 'space':
                padding_symbol = ' '
            elif b == 'underline':
                padding_symbol = '_'
            continue
        if True or sys.argv[1].startswith('-'):
            # ignore invalid switch
            del sys.argv[1]
            continue
    if not case_sensitive:
        ig = set([w.upper() for w in ignore_words])
        ignore_words = ig
    default_clusters = {}
    default_words = {}
    ref_file = sys.argv[1]
    hyp_file = sys.argv[2]
    rec_set = {}
    # Case-fold the split map too when scoring case-insensitively.
    if split and not case_sensitive:
        newsplit = dict()
        for w in split:
            words = split[w]
            for i in range(len(words)):
                words[i] = words[i].upper()
            newsplit[w.upper()] = words
        split = newsplit
    # Load the hypothesis file: fid -> normalized token list.
    with codecs.open(hyp_file, 'r', 'utf-8') as fh:
        for line in fh:
            if tochar:
                array = characterize(line)
            else:
                array = line.strip().split()
            if len(array) == 0: continue
            fid = array[0]
            rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
                                     split)
    # compute error rate on the intersection of reference file and hyp file
    with open(ref_file, 'r', encoding='utf-8') as ref_fh:
        for line in ref_fh:
            if tochar:
                array = characterize(line)
            else:
                array = line.rstrip('\n').split()
            if len(array) == 0: continue
            fid = array[0]
            if fid not in rec_set:
                continue
            # print("array:", array)
            lab = normalize(array[1:], ignore_words, case_sensitive, split)
            rec = rec_set[fid]
            # print("lab:", lab)
            # print("rec:", rec)
            if verbose:
                print('\nutt: %s' % fid)
            # Register each token's script class for the per-cluster summary.
            for word in rec + lab:
                if word not in default_words:
                    default_cluster_name = default_cluster(word)
                    if default_cluster_name not in default_clusters:
                        default_clusters[default_cluster_name] = {}
                    if word not in default_clusters[default_cluster_name]:
                        default_clusters[default_cluster_name][word] = 1
                    default_words[word] = default_cluster_name
            # Align this utterance; fills the global result dict under fid.
            result = calculator.calculate(fid, lab, rec)
            if verbose:
                if result[fid]['all'] != 0:
                    wer = float(result[fid]['ins'] + result[fid]['sub'] +
                                result[fid]['del']) * 100.0 / result[fid]['all']
                else:
                    wer = 0.0
                # print('WER: %4.2f %%' % wer, end=' ')
                # print('N=%d C=%d S=%d D=%d I=%d' %
                #       (result[fid]['all'], result[fid]['cor'], result[fid]['sub'], result[fid]['del'],
                #        result[fid]['ins']))
                # Align result[fid]['lab'], result[fid]['rec'] and the error
                # marks column by column: compute each column's display-width
                # deficit so every cell can be padded to the same width.
                space = {}
                space['lab'] = []
                space['rec'] = []
                space['error_mark'] = []
                for idx in range(len(result[fid]['lab'])):
                    len_lab = width(result[fid]['lab'][idx])
                    len_rec = width(result[fid]['rec'][idx])
                    len_error_mark = width(result[fid]['error_mark'][idx])
                    length = max(len_lab, len_rec, len_error_mark)
                    space['lab'].append(length - len_lab)
                    space['rec'].append(length - len_rec)
                    space['error_mark'].append(length - len_error_mark)
                upper_lab = len(result[fid]['lab'])
                upper_rec = len(result[fid]['rec'])
                upper_error_mark = len(result[fid]['error_mark'])
                lab1, rec1, error_mark1 = 0, 0, 0
                # Print lab/rec/err rows in chunks of max_words_per_line.
                while lab1 < upper_lab or rec1 < upper_rec or error_mark1 < upper_error_mark:
                    if verbose > 1:
                        print('lab(%s):' % fid.encode('utf-8'), end=' ')
                    else:
                        print('lab:', end=' ')
                    lab2 = min(upper_lab, lab1 + max_words_per_line)
                    for idx in range(lab1, lab2):
                        token = result[fid]['lab'][idx]
                        print('{token}'.format(token=token), end='')
                        for n in range(space['lab'][idx]):
                            print(padding_symbol, end='')
                        print(' ', end='')
                    print()
                    if verbose > 1:
                        print('rec(%s):' % fid.encode('utf-8'), end=' ')
                    else:
                        print('rec:', end=' ')
                    rec2 = min(upper_rec, rec1 + max_words_per_line)
                    for idx in range(rec1, rec2):
                        token = result[fid]['rec'][idx]
                        print('{token}'.format(token=token), end='')
                        for n in range(space['rec'][idx]):
                            print(padding_symbol, end='')
                        print(' ', end='')
                    print()
                    if verbose > 1:
                        print('err(%s):' % fid.encode('utf-8'), end=' ')
                    else:
                        print('err:', end=' ')
                    error_mark2 = min(upper_error_mark, error_mark1 + max_words_per_line)
                    for idx in range(error_mark1, error_mark2):
                        token = result[fid]['error_mark'][idx]
                        print('{token}'.format(token=token), end='')
                        for n in range(space['error_mark'][idx]):
                            print(padding_symbol, end='')
                        print(' ', end='')
                    print()
                    print('\n', end='\n')
                    lab1 = lab2
                    rec1 = rec2
                    error_mark1 = error_mark2
    # Save the accumulated alignment result dict as a JSON file next to the
    # hypothesis file (must happen before `result` is reused below).
    with open(os.path.join(os.path.dirname(hyp_file), 'alignment.json'), 'w', encoding='utf-8') as file:
        json.dump(result, file, ensure_ascii=False)
    if verbose:
        # Per script-class (Mandarin/English/...) WER summary.
        for cluster_id in default_clusters:
            result = calculator.cluster(
                [k for k in default_clusters[cluster_id]])
            if result['all'] != 0:
                wer = float(result['ins'] + result['sub'] +
                            result['del']) * 100.0 / result['all']
            else:
                wer = 0.0
            print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
            print('N=%d C=%d S=%d D=%d I=%d' %
                  (result['all'], result['cor'], result['sub'], result['del'],
                   result['ins']))
    if len(cluster_file) > 0:  # compute separated WERs for word clusters
        cluster_id = ''
        cluster = []
        with open(cluster_file, 'r', encoding='utf-8') as cluster_fh:
            for line in cluster_fh:
                # FIX: the file is opened in text mode with UTF-8 decoding, so
                # each line is already `str`; the previous `line.decode('utf-8')`
                # raised AttributeError under Python 3.
                for token in line.rstrip('\n').split():
                    # end of cluster reached, like </Keyword>
                    if token[0:2] == '</' and token[len(token) - 1] == '>' and \
                            token.lstrip('</').rstrip('>') == cluster_id:
                        result = calculator.cluster(cluster)
                        if result['all'] != 0:
                            wer = float(result['ins'] + result['sub'] +
                                        result['del']) * 100.0 / result['all']
                        else:
                            wer = 0.0
                        print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
                        print('N=%d C=%d S=%d D=%d I=%d' %
                              (result['all'], result['cor'], result['sub'],
                               result['del'], result['ins']))
                        cluster_id = ''
                        cluster = []
                    # begin of cluster reached, like <Keyword>
                    elif token[0] == '<' and token[len(token) - 1] == '>' and \
                            cluster_id == '':
                        cluster_id = token.lstrip('<').rstrip('>')
                        cluster = []
                    # general terms, like WEATHER / CAR / ...
                    else:
                        cluster.append(token)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mayi123/asr_timestamp_insert_text_grid.git
git@gitee.com:mayi123/asr_timestamp_insert_text_grid.git
mayi123
asr_timestamp_insert_text_grid
asr_timestamp_insert_text_grid
master

搜索帮助