1 Star 0 Fork 0

嗜雪的蚂蚁/asr_timestamp_insert_text_grid

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
SpeechrepeScore.py 8.43 KB
一键复制 编辑 原始数据 按行查看 历史
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# -- coding: utf-8 --
'''
# @Time : 2023/10/16 21:31
# @Author: from https://github.com/wenet-e2e/wenet/blob/main/tools/compute-wer.py
# @Modified By Shiyu He
# @University : Xinjiang University
'''
import os, re, sys
import json
import random
# Punctuation symbols that are stripped from text before it is scored.
puncts = [
    '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '『', '』',
    '《', '》', '.', ';', ':', '(', ')', '[', ']', '{', '}', '"', "'", '...',
    '—', '-', '/', '\\', '%', '¥', '$', '·', '`', '‘', '’', '“', '”', '~', '@', '#',
    '&', '*', '_', '+', '=', '|', '<', '>', '^', '【', '】'
]


def remove_punctuation(sentence: str):
    """Return ``sentence`` with every character found in ``puncts`` removed.

    All entries of ``puncts`` are joined, escaped and wrapped in a regex
    character class, so multi-character entries such as ``'...'`` act
    through their individual characters.

    If the regex engine reports an error, a message is printed and the
    input is returned unchanged.
    """
    try:
        escaped_chars = re.escape(''.join(puncts))
        punct_class = re.compile('[{}]'.format(escaped_chars))
        return punct_class.sub('', sentence)
    except re.error as re_error:
        print(f"Error in remove_punctuation: {re_error}")
        return sentence
# Edit-distance calculator used when scoring a spoken repetition.
class RepetitionEditDist:
    """Token-level edit distance between a reference and a recognition result.

    Attributes:
        data: per-token counters (``all``/``cor``/``sub``/``ins``/``del``)
            accumulated over every call to :meth:`get_edit_dist`.
        space: reusable dynamic-programming table; each cell is a dict with
            a ``dist`` (cost so far) and an ``error`` (operation that led to
            the cell) entry.  It only ever grows between calls.
        cost: cost of each operation (``cor`` 0, ``sub``/``del``/``ins`` 1).
    """

    def __init__(self):
        self.data = {}
        self.space = []
        self.cost = {'cor': 0, 'sub': 1, 'del': 1, 'ins': 1}

    # lab: reference (label) tokens    rec: recognition-result tokens
    def get_edit_dist(self, lab: list, rec: list):
        """Align ``rec`` against ``lab`` and count the edit operations.

        Args:
            lab: reference tokens.
            rec: recognised tokens.

        Returns:
            A dict with the keys ``all``, ``cor``, ``sub``, ``ins`` and
            ``del``; ``all`` counts the reference tokens on the alignment
            path (``cor + sub + del``).  If anything goes wrong the error is
            printed and ``None`` is returned.

        The caller's lists are not modified; ``self.data`` is updated with
        per-token statistics as a side effect.
        """
        try:
            # Work on copies: the empty sentinel at index 0 must not leak
            # into the caller's lists.
            lab = ['', *lab]
            rec = ['', *rec]
            self._reset_space(len(lab), len(rec))
            self._register_tokens(lab)
            self._register_tokens(rec)
            self._fill_space(lab, rec)
            return self._trace_back(lab, rec)
        except Exception as e:
            print(f"Error in get_edit_dist: {e}")
            return None

    def _reset_space(self, n_rows, n_cols):
        """Grow the table to at least n_rows x n_cols and reset its cells."""
        while len(self.space) < n_rows:
            self.space.append([])
        for row in self.space:
            for cell in row:
                cell['dist'] = 0
                cell['error'] = 'non'
            while len(row) < n_cols:
                row.append({'dist': 0, 'error': 'non'})
        # First column: only deletions; first row: only insertions.
        for i in range(n_rows):
            self.space[i][0]['dist'] = i
            self.space[i][0]['error'] = 'del'
        for j in range(n_cols):
            self.space[0][j]['dist'] = j
            self.space[0][j]['error'] = 'ins'
        # The origin terminates the trace-back.
        self.space[0][0]['error'] = 'non'

    def _register_tokens(self, tokens):
        """Create zeroed counters in ``self.data`` for unseen non-empty tokens."""
        for token in tokens:
            if token not in self.data and len(token) > 0:
                self.data[token] = {
                    'all': 0,
                    'cor': 0,
                    'sub': 0,
                    'ins': 0,
                    'del': 0,
                }

    def _fill_space(self, lab, rec):
        """Fill the table with the cheapest operation reaching each cell."""
        space = self.space
        cost = self.cost
        for i in range(1, len(lab)):
            for j in range(1, len(rec)):
                diag_error = 'cor' if lab[i] == rec[j] else 'sub'
                candidates = (
                    (space[i - 1][j]['dist'] + cost['del'], 'del'),
                    (space[i][j - 1]['dist'] + cost['ins'], 'ins'),
                    (space[i - 1][j - 1]['dist'] + cost[diag_error], diag_error),
                )
                # min() keeps the first minimal candidate, so ties resolve
                # in the order deletion, insertion, diagonal.
                best_dist, best_error = min(candidates, key=lambda c: c[0])
                space[i][j]['dist'] = best_dist
                space[i][j]['error'] = best_error

    def _trace_back(self, lab, rec):
        """Walk from the last cell to the origin, counting operations."""
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        i = len(lab) - 1
        j = len(rec) - 1
        while True:
            error = self.space[i][j]['error']
            if error == 'non':  # starting point
                break
            if error in ('cor', 'sub', 'del'):
                # These operations consume one reference token.
                token = lab[i]
                if len(token) > 0:
                    self.data[token]['all'] += 1
                    self.data[token][error] += 1
                result['all'] += 1
                result[error] += 1
                i -= 1
                if error != 'del':
                    j -= 1
            elif error == 'ins':
                token = rec[j]
                if len(token) > 0:
                    self.data[token]['ins'] += 1
                result['ins'] += 1
                j -= 1
            else:
                # Neither index would advance here, so bail out instead of
                # spinning forever.
                raise RuntimeError(
                    'this should not happen , i = {i} , j = {j} , error = {error}'
                    .format(i=i, j=j, error=error))
        return result
class SpeechRepetitionScore:
    """Turn edit-operation counts into a repetition score.

    Attributes:
        Full: the full-mark value the penalty is subtracted from.
        level: exam level (1, 2 or 3) selecting the penalty weight.
    """

    def __init__(self, Full=10, level=3):
        self.level = level
        self.Full = Full

    def level_to_weight(self):
        """Return the penalty weight for ``self.level``.

        Raises:
            ValueError: if ``self.level`` is not 1, 2 or 3.
        """
        if self.level == 1:
            return 1.0
        if self.level == 2:
            return 0.8
        if self.level == 3:
            return 0.7
        raise ValueError("等级(level)必须是1、2或3。")

    def computer_score(self, result: dict):
        """Compute the score from an edit-distance ``result`` dict.

        Args:
            result: dict with the counts ``all``, ``sub``, ``ins`` and
                ``del``; ``all`` is used as the reference length.

        Returns:
            ``0`` when the reference length is zero, otherwise ``Full`` minus
            the weighted, length-normalised penalty.  When the penalty is 4
            or more the score is floored at 0.  ``None`` is returned (after
            printing the error) if ``result`` is malformed or the level is
            invalid.
        """
        try:
            level_weight = self.level_to_weight()
            remove_score_del = result['del'] * 0.6 * level_weight
            remove_score_sub = result['sub'] * 1 * level_weight
            remove_score_ins = result['ins'] * 0.6 * level_weight
            # The reference length comes from the result itself, so the
            # method does not depend on any module-level variable.
            ref_len = result['all']
            if ref_len <= 0:
                return 0
            remove_score = (remove_score_del + remove_score_sub + remove_score_ins) / ref_len * self.Full
            score = self.Full - remove_score
            if remove_score >= 4:
                # TODO: bring in semantic-similarity analysis for heavily
                # altered repetitions.  Until then fall back to the
                # edit-distance score, floored at 0.
                score = max(0, score)
            return score
        except (KeyError, TypeError, ValueError) as e:
            print(f"Error in computer_score: {e}")
            return None
if __name__ == '__main__':
    # Demo: score one recognised repetition against its reference sentence.
    try:
        FULL_SCORE = 10
        LEVEL = 2
        # Strip punctuation first so that it is neither compared nor counted
        # as a repetition error.
        lab = remove_punctuation("今天不是个好日子")
        rec = remove_punctuation("今天是个好日子")
        if lab == rec:
            # Identical text earns the full mark without any alignment.
            score = FULL_SCORE
        else:
            SpeechRepetition = RepetitionEditDist()
            editdist_result = SpeechRepetition.get_edit_dist(list(lab), list(rec))
            # Build the scorer from the full-mark value and the exam level.
            speechrepetitionscore = SpeechRepetitionScore(FULL_SCORE, LEVEL)
            score = speechrepetitionscore.computer_score(editdist_result)
        print("score:", score)
    except ValueError as ve:
        print(f"错误:{ve}")
    except Exception as e:
        print(f"发生异常:{e}")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mayi123/asr_timestamp_insert_text_grid.git
git@gitee.com:mayi123/asr_timestamp_insert_text_grid.git
mayi123
asr_timestamp_insert_text_grid
asr_timestamp_insert_text_grid
master

搜索帮助