From f9df2e06843e9adc473af8f280361a94a28ff169 Mon Sep 17 00:00:00 2001
From: lijian
Date: Mon, 13 Oct 2025 10:28:08 +0800
Subject: [PATCH] add transcribe.py for long audio

---
 ACL_PyTorch/built-in/audio/whisperx/README.md |  20 ++-
 .../built-in/audio/whisperx/requirements.txt  |   3 +-
 .../built-in/audio/whisperx/transcribe.py     | 119 ++++++++++++++++++
 3 files changed, 139 insertions(+), 3 deletions(-)
 create mode 100644 ACL_PyTorch/built-in/audio/whisperx/transcribe.py

diff --git a/ACL_PyTorch/built-in/audio/whisperx/README.md b/ACL_PyTorch/built-in/audio/whisperx/README.md
index 798f0f81ad..49d22e555c 100644
--- a/ACL_PyTorch/built-in/audio/whisperx/README.md
+++ b/ACL_PyTorch/built-in/audio/whisperx/README.md
@@ -38,7 +38,7 @@ cd weight
 * `large-v3.pt`: [download link](https://modelscope.cn/models/iic/Whisper-large-v3/files)
 * `large-v3-turbo.pt`: [download link](https://modelscope.cn/models/iic/Whisper-large-v3-turbo/files)
 * HuggingFace whisper-large-v3: [download link](https://huggingface.co/openai/whisper-large-v3/tree/main)
-* HuggingFace whisper-large-v3: [download link](https://huggingface.co/openai/whisper-large-v3-turbo/tree/main)
+* HuggingFace whisper-large-v3-turbo: [download link](https://huggingface.co/openai/whisper-large-v3-turbo/tree/main)
 * speech_fsmn_vad_zh-cn-16k-common-pytorch: [download link](https://huggingface.co/alextomcat/speech_fsmn_vad_zh-cn-16k-common-pytorch/tree/main)
 ```
 cd ..
@@ -93,6 +93,10 @@ cd ..
 ```
 ## Model Inference
+**Script overview**:
+- infer.py is mainly used to transcribe short audio (<30s) and to validate performance on the LibriSpeech dataset
+- transcribe.py is used to transcribe long audio, e.g. generating subtitles for smart-classroom recordings
+
 1. Set up the environment variables
    ```SHELL
    source /usr/local/Ascend/ascend-toolkit/set_env.sh # adjust the path to your installation
    ```
@@ -127,7 +131,7 @@ cd ..
    >>>>Device 0 corresponds to NUMA node: 6, NUMA node6 CPU(s): 192-223
    ...
    ```
-5. Run inference, binding cores according to the ranges queried above, e.g.
+5. Short-audio and LibriSpeech dataset inference, binding cores according to the ranges queried above, e.g.
    ```SHELL
    taskset -c 192-223 python3 infer.py --whisper_model_path ./weight/Whisper-large-v3/large-v3.pt
    ```
@@ -145,6 +149,18 @@ infer.py inference arguments:
 * --batch_size: batch size, default 16
 * --warmup: number of warm-up runs, default 4; the model is compiled into a graph during the first warm-up
+6. Long-audio transcription:
+   ```SHELL
+   taskset -c 192-223 python3 transcribe.py --whisper_model_path ./weight/Whisper-large-v3/large-v3.pt --audio_path {audio_file}
+   ```
+transcribe.py arguments:
+* --whisper_model_path: path to the whisper model weights, default "./weight/Whisper-large-v3/large-v3.pt"
+* --language: output language, default Chinese ('zh')
+* --sample_audio: audio used in the warm-up phase, default "./audio.mp3"
+* --audio_path: path to the long audio file, required
+* --device: NPU device id, default 0
+* --warmup: number of warm-up runs, default 3; the model is compiled into a graph during the first warm-up
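+
+transcribe.py prints the result and also writes it to result.txt; as a rough sketch of the shape (field values here are illustrative only):
+```
+{"text": "...", "segment_lines": [{"index": 1, "startTime": 0.0, "endTime": 3.2, "sentence": "..."}]}
+```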
+
 
 ## Performance Data
 
 infer.py transcribes a subset of audio from the LibriSpeech dev-clean dataset; performance is as follows

diff --git a/ACL_PyTorch/built-in/audio/whisperx/requirements.txt b/ACL_PyTorch/built-in/audio/whisperx/requirements.txt
index 2566be8f59..207aeafe28 100644
--- a/ACL_PyTorch/built-in/audio/whisperx/requirements.txt
+++ b/ACL_PyTorch/built-in/audio/whisperx/requirements.txt
@@ -77,4 +77,5 @@ urllib3==1.26.20
 wcwidth==0.2.13
 XlsxWriter==3.2.3
 xxhash==3.5.0
-yarl==1.20.0
\ No newline at end of file
+yarl==1.20.0
+zhconv==1.4.3
\ No newline at end of file

diff --git a/ACL_PyTorch/built-in/audio/whisperx/transcribe.py b/ACL_PyTorch/built-in/audio/whisperx/transcribe.py
new file mode 100644
index 0000000000..891aedf88a
--- /dev/null
+++ b/ACL_PyTorch/built-in/audio/whisperx/transcribe.py
@@ -0,0 +1,119 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import string
+
+import zhconv
+import torch
+import torch_npu
+import torchair as tng
+from torchair.configs.compiler_config import CompilerConfig
+import whisper
+
+from modeling_whisper import get_whisper_model
+from pipeline import load_audio, SAMPLE_RATE
+
+
+def parse_args():
+    parser = argparse.ArgumentParser("Whisper transcribe")
+    parser.add_argument("--whisper_model_path", type=str, default="./weight/Whisper-large-v3/large-v3.pt",
+                        help="whisper model checkpoint file path")
+    parser.add_argument("--language", type=str, default='zh', help="output language")
+    parser.add_argument("--sample_audio", type=str, default="./audio.mp3",
+                        help="sample audio for warm up and compilation")
+    parser.add_argument("--audio_path", type=str, required=True, help="target audio to transcribe")
+    parser.add_argument("--device", type=int, default=0, help="npu device id")
+    parser.add_argument("--warmup", type=int, default=3, help="warm up times")
+    return parser.parse_args()
+
+
+def generate_subtitles(model, file_path, language='zh'):
+    audio = load_audio(file_path, SAMPLE_RATE)
+
+    print("start transcribe...")
+    output = model.transcribe(audio, language=language)
+    segments_list = output['segments']
+
+    text = ''
+    index = 1
+    last_sentence = None
+    segment_lines = []
+    comma = ', '
+    full_stop = '.'
+    if language == 'zh':
+        comma = ','
+        full_stop = '。'
+
+    for segment in segments_list:
+        # segments decoded at high temperature are likely low-confidence
+        if segment['temperature'] > 0.8:
+            print(f"Low confidence segment detected, skipping... {segment['text']}")
+            continue
+        sentence = segment['text']
+        # skip empty or punctuation-only segments
+        if sentence.strip() in {'', '.', ',', ',', '。'}:
+            continue
+        if language in ('zh', 'jw'):
+            # convert traditional Chinese characters to simplified
+            sentence = zhconv.convert(sentence, 'zh-cn')
+        # skip consecutive duplicate sentences
+        if last_sentence is not None and last_sentence == sentence:
+            continue
+        last_sentence = sentence
+
+        segment_lines.append({
+            'index': index,
+            'startTime': float(segment['start']),
+            'endTime': float(segment['end']),
+            'sentence': sentence
+        })
+        index += 1
+
+        # append a separator when the sentence has no trailing punctuation
+        if sentence[-1] not in (string.punctuation + ",。?!;:"):
+            sentence += comma
+        text += sentence
+
+    # replace the trailing comma with a full stop
+    if text.endswith(comma):
+        text = text[:-len(comma)] + full_stop
+
+    result = {"text": text, "segment_lines": segment_lines}
+    print(result)
+    with open("result.txt", "w", encoding="utf-8") as f:
+        f.write(f"{result}")
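+
+
+# Illustrative helper (not called anywhere in this script): format segment_lines
+# as SRT subtitles, assuming startTime/endTime are in seconds; the helper name
+# and output layout are an assumption, sketched for the subtitle use case.
+def to_srt(segment_lines):
+    def fmt(seconds):
+        # SRT timestamps use the form HH:MM:SS,mmm
+        millis = int(round(seconds * 1000))
+        hours, rest = divmod(millis, 3600000)
+        minutes, rest = divmod(rest, 60000)
+        secs, millis = divmod(rest, 1000)
+        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
+
+    blocks = []
+    for line in segment_lines:
+        blocks.append(f"{line['index']}\n"
+                      f"{fmt(line['startTime'])} --> {fmt(line['endTime'])}\n"
+                      f"{line['sentence']}\n")
+    return "\n".join(blocks)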
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    device = torch.device('npu:{}'.format(args.device))
+    whisper_decode_options = whisper.DecodingOptions(without_timestamps=True, fp16=True)
+    whisper_model = get_whisper_model(args.whisper_model_path, whisper_decode_options, device)
+
+    torch_npu.npu.set_compile_mode(jit_compile=False)
+    config = CompilerConfig()
+    config.experimental_config.frozen_parameter = True
+    config.experimental_config.tiling_schedule_optimize = True  # enable full tiling sink optimization
+    npu_backend = tng.get_npu_backend(compiler_config=config)
+
+    print("compile model...")
+    whisper_model.encoder.forward = torch.compile(whisper_model.encoder.forward, dynamic=False, fullgraph=True, backend=npu_backend)
+    whisper_model.prefill_decoder.forward = torch.compile(whisper_model.prefill_decoder.forward, dynamic=True, fullgraph=True, backend=npu_backend)
+    whisper_model.decode_decoder.forward = torch.compile(whisper_model.decode_decoder.forward, dynamic=True, fullgraph=True, backend=npu_backend)
+
+    sample_audio = load_audio(args.sample_audio, SAMPLE_RATE)
+    print("start warm up...")
+    for i in range(args.warmup):
+        result = whisper_model.transcribe(sample_audio, language='zh')
+        print(f"warm up {i + 1}/{args.warmup} {result['text']}")
+    print("warm up done")
+
+    generate_subtitles(whisper_model, args.audio_path, args.language)
\ No newline at end of file
--
Gitee