46 Star 498 Fork 1.3K

Ascend/ModelZoo-PyTorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
贡献代码
同步代码
取消
提示: 由于 Git 不支持空文件夾,创建文件夹后会生成空的 .keep 文件
Loading...
README

Wenet Conformer for PyTorch

简介

模型介绍

Wenet是一款开源的、面向工业落地应用的语音识别工具包,主要特点是小而精,它不仅采用了现阶段最先进的网络设计Conformer,还用到了U2结构实现流式与非流式框架的统一。

支持特性

本仓已支持以下模型任务类型。

模型 任务类型 是否支持
Conformer 预训练
Whisper 预训练

代码实现

  • 参考实现:

    url=https://github.com/wenet-e2e/wenet.git
    commit_id=ac9a2612e8245ac473a17f64eea600dd7afbeb20
    
  • 适配昇腾 AI 处理器的实现:

    url=https://gitee.com/ascend/ModelZoo-PyTorch.git
    code_path=PyTorch/built-in/audio
    

Wenet-Conformer

准备环境

  • 推荐使用最新的版本准备训练环境。

    表 1 版本配套表

    软件 版本 安装指南
    Driver AscendHDK 25.0.RC1.1 驱动固件安装指南
    Firmware AscendHDK 25.0.RC1.1
    CANN CANN 8.1.RC1 CANN 软件安装指南
    PyTorch 2.1.0 Ascend Extension for PyTorch 配置与安装
    torch_npu release v7.0.0-pytorch2.1.0
  • 三方库依赖如下表所示。

    表 2 三方库依赖表

    Torch_Version 三方库依赖版本
    PyTorch 2.1 torch_audio==2.1.0
  • 安装依赖。

    在模型源码包根目录下执行命令。

    pip install -r requirements_2_1.txt
    
  • 编译安装torchaudio

    在官网根据PyTorch版本获取torchaudio对应版本,解压至torchaudio文件夹,运行以下命令

    cd torchaudio
    python setup.py develop
    

准备数据集

  1. 获取数据集。

    用户自行下载 aishell-1 数据集,并将下载好的数据集放置服务器的任意目录下。该数据集包含由 400 位说话人录制的超过 170 小时的语音。数据集目录结构参考如下所示。

     aishell-1
        ├── data_aishell.tgz
        |
        └── resource_aishell.tgz
    

    说明: 该数据集的训练过程脚本只作为一种参考示例。

开始训练

训练模型

  1. 进入解压后的源码包根目录。

    cd /${模型文件夹名称} 
    
  2. 运行训练脚本。

    该模型支持单机8卡训练。

    • 单机8卡训练

      cd examples/aishell/s0/test
      bash train_full_8p.sh --stage=起始stage --stop_stage=终止stage --data_path=/data/xxx/  # 8卡精度
      bash train_performance_8p.sh --data_path=/data/xxx/  # 8卡性能
      bash train_full_8p_whisper.sh --stage=起始stage --stop_stage=终止stage --data_path=/data/xxx/  # 裁剪了CNN Module的8卡精度、性能
      

    模型训练脚本参数说明如下。

    --stage              //模型训练的起始阶段,默认为-1,即从数据下载开始启动训练。若之前数据下载、准备、特征生成等阶段已完成,可配置--stage=4开始训练。
    --stop_stage         //模型训练的终止阶段
    --data_path          //数据集路径
    

    说明:

    --stage <-1 ~ 5>、--stop_stage <-1 ~ 5>:控制模型训练的起始、终止阶段。模型包含 -1 ~ 5 训练阶段,其中 -1 ~ 3 为数据下载、准备、特征生成等阶段,4为模型训练,5为ASR任务评估。首次运行时请从 -1 开始,-1 ~ 3 阶段执行过一次之后,后续可以从stage 4 开始训练。

    --data_path参数填写数据集路径,需要写到数据集的一级目录。

    训练完成后,权重文件保存在当前路径下,并输出模型训练精度和性能信息。

训练结果展示

表 3 conformer训练结果展示表

NAME Error FPS(iters/sec) Epochs AMP_Type Torch_Version
8p-竞品A - 800.44 15 fp32 1.11
8p-Atlas 800T A2 - 526.34 15 fp32 1.11
8p-竞品A - 958.98 15 fp32 2.1
8p-Atlas 800T A2 - 830.49 15 fp32 2.1

表 4 whisper训练结果展示表

NAME Error FPS(iters/sec) Epochs AMP_Type Torch_Version
8p-竞品A - 746.39 15 fp32 1.11
8p-Atlas 800T A2 - 667.62 15 fp32 1.11
8p-竞品A - 748.85 15 fp32 2.1
8p-Atlas 800T A2 - 789.31 15 fp32 2.1

表 5 conformer result

  • Feature info: using fbank feature, dither, cmvn, online speed perturb
  • Training info: lr 0.002, batch size 18, 4 gpu, acc_grad 4, 240 epochs, dither 0.1
  • Decoding info: ctc_weight 0.5, average_num 20
decoding mode WER
ctc greedy search 4.96

说明:上表为历史数据,仅供参考。2025年5月10日更新的性能数据如下:

NAME 精度类型 FPS
Conformer 8p-竞品 FP32 958.98
Conformer 8p-Atlas 900 A2 PoDc FP32 1166.96
whisper 8p-竞品 FP16 748.85
whisper 8p-Atlas 900 A2 PoDc FP16 1532.5

公网地址说明

代码涉及公网地址参考public_address_statement.md

变更说明

2023.09.01:首次发布。

2024.03.16: 增加PyTorch2.1基线,增加FAQ。

FAQ

Q1:Pytorch2.1版本,运行时可能出现段错误

A1:问题原因是Pytorch与torchaudio版本不匹配,需要手动编译安装torchaudio,并在编译时设置变量 BUILD_SOX=0

同时对模型内部分相关代码进行修改:

1). 在tools/compute_cmvn_stats.py中修改:

# 1. 找到:
torchaudio.set_audio_backend("sox_io")
# 修改为:
# torchaudio.set_audio_backend("sox_io")

# 2. 找到:
sample_rate = torchaudio.backend.sox_io_backend.info(wav_path).sample_rate
# 修改为:
sample_rate = torchaudio.info(wav_path).sample_rate

# 3. 找到:
waveform, sample_rate = torchaudio.backend.sox_io_backend.load(
                  filepath=wav_path,
                  num_frames=end_frame - start_frame,
                  frame_offset=start_frame)
# 修改为:
waveform, sample_rate = torchaudio.load(
                  filepath=wav_path,
                  num_frames=end_frame - start_frame,
                  frame_offset=start_frame)

2). 在tools/make_shard_list.py中修改:

# 1. 找到:
import torchaudio.backend.sox_io_backend as sox
# 修改为:
# import torchaudio.backend.sox_io_backend as sox

# 2. 找到:
waveforms, sample_rate = sox.load(wav, normalize=False)
# 修改为:
waveforms, sample_rate = torchaudio.load(wav, normalize=False)

# 3. 找到:
sox.save(f, audio, resample, format="wav", bits_per_sample=16)
# 修改为:
torchaudio.save(f, audio, resample, format="wav", bits_per_sample=16)

2). 在wenet/dataset/processor.py中修改:

# 1. 找到:
torchaudio.utils.sox_utils.set_buffer_size(16500)
# 修改为:
# torchaudio.utils.sox_utils.set_buffer_size(16500)

# 2. 找到:
sample_rate = torchaudio.backend.sox_io_backend.info(
                  wav_file).sample_rate
# 修改为:
sample_rate = torchaudio.info(
                  wav_file).sample_rate

# 3. 找到:
waveform, _ = torchaudio.backend.sox_io_backend.load(
                  filepath=wav_file,
                  num_frames=end_frame - start_frame,
                  frame_offset=start_frame)
# 修改为:
waveform, _ = torchaudio.load(
                  filepath=wav_file,
                  num_frames=end_frame - start_frame,
                  frame_offset=start_frame)

# 4. 找到:
        if speed != 1.0:
          wav, _ = torchaudio.sox_effects.apply_effects_tensor(
              waveform, sample_rate,
              [['speed', str(speed)], ['rate', str(sample_rate)]])
          sample['wav'] = wav
# 修改为:
        # if speed != 1.0:
        #   wav, _ = torchaudio.sox_effects.apply_effects_tensor(
        #       waveform, sample_rate,
        #       [['speed', str(speed)], ['rate', str(sample_rate)]])
        #   sample['wav'] = wav
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/ascend/ModelZoo-PyTorch.git
git@gitee.com:ascend/ModelZoo-PyTorch.git
ascend
ModelZoo-PyTorch
ModelZoo-PyTorch
master

搜索帮助