From 176cf1b165625e474ec2a3d9c07e37f24711e783 Mon Sep 17 00:00:00 2001
From: han_yifeng <hanyifeng2@huawei.com>
Date: Wed, 7 Feb 2024 19:10:44 +0800
Subject: [PATCH 1/2] =?UTF-8?q?zipformer=E6=A8=A1=E5=9E=8B=E9=9D=9E?=
 =?UTF-8?q?=E6=B5=81=E5=BC=8Freadme=E4=B8=8E=E8=84=9A=E6=9C=AC=E4=B8=8A?=
 =?UTF-8?q?=E5=BA=93?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/Zipformer/README.md        |  221 ++
 .../built-in/audio/Zipformer/export-onnx.py   |  625 +++++
 .../icefall_pt/export_torch_aie_ts_dec.py     |   62 +
 .../icefall_pt/export_torch_aie_ts_enc.py     |   65 +
 .../icefall_pt/export_torch_aie_ts_join.py    |   67 +
 .../Zipformer/icefall_pt/model_pt_dec.py      |   50 +
 .../Zipformer/icefall_pt/model_pt_enc.py      |   51 +
 .../Zipformer/icefall_pt/model_pt_join.py     |   51 +
 .../audio/Zipformer/icefall_pt/pt_val_dec.py  |  202 ++
 .../audio/Zipformer/icefall_pt/pt_val_enc.py  |  202 ++
 .../audio/Zipformer/icefall_pt/pt_val_join.py |  203 ++
 .../audio/Zipformer/modify_decoder.py         |   25 +
 .../built-in/audio/Zipformer/zipformer.py     | 2438 +++++++++++++++++
 13 files changed, 4262 insertions(+)
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/README.md
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py

diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md b/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md
new file mode 100644
index 0000000000..7eee8c3e05
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md
@@ -0,0 +1,221 @@
+# Zipformer Model Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+- [Inference Environment](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+- [Inference Performance and Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+
+# Overview<a name="ZH-CN_TOPIC_0000001172161501"></a>
+
+(From the paper abstract) The Conformer has become the most popular encoder model for automatic speech recognition (ASR). It adds convolution modules to a Transformer to learn both local and global dependencies.
+In this work we describe a faster, more memory-efficient and better-performing Transformer, called Zipformer. The modeling changes include: 1) a U-Net-like encoder structure in which the middle stacks operate at lower frame rates;
+2) a reorganized block structure with more modules, in which attention weights are reused for efficiency; 3) a modified form of LayerNorm called BiasNorm, which allows some length information to be retained; 4) new activation functions, SwooshR and SwooshL, which work better than Swish.
+We also propose a new optimizer, called ScaledAdam, which scales the update of each tensor by its current scale so that the relative change stays roughly the same, and which also explicitly learns the parameter scale. It achieves faster convergence and better performance than Adam.
+Extensive experiments on the LibriSpeech, Aishell-1 and WenetSpeech datasets demonstrate the effectiveness of the proposed Zipformer over other state-of-the-art ASR models.
+
+
+# Inference Environment \[All Versions\]<a name="ZH-CN_TOPIC_0000001126281702"></a>
+
+- The model requires the following dependencies.
+
+  **Table 1** Version compatibility
+
+| Component             | Version     |
+|-----------------------|-------------|
+| CANN                  | 7.0.RC1     |
+| Python                | 3.9.11      |
+| torch                 | 2.0.1       |
+| Ascend-cann-torch-aie | -           |
+| Ascend-cann-aie       | -           |
+| SoC                   | Ascend310P3 |
+
+# Quick Start<a name="ZH-CN_TOPIC_0000001126281700"></a>
+
+## Environment Setup
+
+1. Install k2
+   1. (NPU) x86 environment
+      ```shell
+      wget https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.4.dev20231220+cpu.torch2.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+      pip install k2-1.24.4.dev20231220+cpu.torch2.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+      ```
+   2. (NPU/GPU) arm environment: build from source.
+      ```shell
+      git clone https://github.com/k2-fsa/k2.git
+      cd k2
+      export K2_MAKE_ARGS="-j6"
+      python3 setup.py install
+      ```
+      If the commands above fail, refer to [this link](https://k2-fsa.github.io/k2/installation/from_source.html).
+   3. (GPU) x86 environment: download the wheel matching your CUDA version from [this link](https://k2-fsa.github.io/k2/cuda.html) and install it with pip.
+   4. Verify that k2 is installed correctly
+      ```shell
+      python3 -m k2.version
+      ```
+2. Install the other dependencies
+   ```shell
+   pip install lhotse
+   pip install kaldifeat
+   ```
+3. Install icefall
+   ```shell
+   git clone https://github.com/k2-fsa/icefall.git
+   cd icefall
+   git reset --hard e2fcb42f5f176d9e39eb38506ab99d0a3adaf202
+   pip install -r requirements.txt
+   ```
+4. Add icefall to the environment. Replace "/path/to/icefall" with the directory where icefall is located.
+   **This step is important; otherwise you will get an error that icefall cannot be found.**
+   ```shell
+   export PYTHONPATH=/path/to/icefall:$PYTHONPATH
+   ```
+
+## Model Download
+1. Install git lfs
+   ```shell
+   curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
+
+   sudo apt-get install git-lfs
+   git lfs install --skip-repo
+   ```
+2. Download the model
+   ```shell
+   git clone https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
+   ```
+   If the download fails, try downloading the files manually from the link above. Only the following files are needed for model conversion and inference:
+   - data/lang_char/tokens.txt
+   - exp/epoch-12.pt
+
+## Model Inference
+1. Replace the source files
+   Replace egs/librispeech/ASR/zipformer/zipformer.py with the zipformer.py in this directory,
+   and replace egs/librispeech/ASR/zipformer/export-onnx.py with the export-onnx.py in this directory.
+
+2. Copy the icefall_pt directory of this repository into the root directory of the icefall project (./icefall).
+
+3. Export the ONNX models (they are used for the accuracy test).
+   ```shell
+   cd icefall/egs/librispeech/ASR/zipformer
+   # Note: replace "icefall-asr-zipformer-streaming-wenetspeech-20230615" with the actual path
+   # Note: some of the arguments below are currently not used by the export script
+   python ./export-onnx.py \
+     --tokens icefall-asr-zipformer-streaming-wenetspeech-20230615/data/lang_char/tokens.txt \
+     --use-averaged-model 0 \
+     --epoch 12 \
+     --avg 1 \
+     --exp-dir icefall-asr-zipformer-streaming-wenetspeech-20230615/exp \
+     --num-encoder-layers "2,2,3,4,3,2" \
+     --downsampling-factor "1,2,4,8,4,2" \
+     --feedforward-dim "512,768,1024,1536,1024,768" \
+     --num-heads "4,4,4,8,4,4" \
+     --encoder-dim "192,256,384,512,384,256" \
+     --query-head-dim 32 \
+     --value-head-dim 12 \
+     --pos-head-dim 4 \
+     --pos-dim 48 \
+     --encoder-unmasked-dim "192,192,256,256,256,192" \
+     --cnn-module-kernel "31,31,15,15,15,31" \
+     --decoder-dim 512 \
+     --joiner-dim 512 \
+     --causal True \
+     --chunk-size 16 \
+     --left-context-frames 128
+   ```
+   After it finishes, three ONNX files and three TorchScript files are generated under "icefall-asr-zipformer-streaming-wenetspeech-20230615/exp":
+   - encoder-epoch-12-avg-1.onnx
+   - decoder-epoch-12-avg-1.onnx
+   - joiner-epoch-12-avg-1.onnx
+   - encoder-epoch-12-avg-1.pt
+   - decoder-epoch-12-avg-1.pt
+   - joiner-epoch-12-avg-1.pt
+4. Compile the TorchScript models.
+   ```shell
+   cd icefall/icefall_pt
+
+   # Note: replace "icefall-asr-zipformer-streaming-wenetspeech-20230615" with the actual path
+   python ./export_torch_aie_ts_enc.py
+   python ./export_torch_aie_ts_dec.py
+   python ./export_torch_aie_ts_join.py
+   ```
+   After they finish, three compiled TorchScript files are generated under "icefall_pt/pt_compiled_model":
+   - encoder-epoch-12-avg-1_torch_aie_bs1.pt
+   - decoder-epoch-12-avg-1_torch_aie_bs1.pt
+   - joiner-epoch-12-avg-1_torch_aie_bs1.pt
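+   The compiled models can also be loaded and run directly with `torch.jit.load` and the Torch-AIE runtime.
+   The following is a minimal sketch of a single encoder forward pass; it mirrors the inference helpers in
+   `icefall_pt` (model_pt_enc.py / pt_val_enc.py), and the path and input shapes are the defaults used by the
+   scripts above, so adjust them to your environment:
+   ```python
+   import torch
+   import torch_aie
+
+   torch_aie.set_device(0)
+
+   # Path assumes the default output of export_torch_aie_ts_enc.py.
+   model = torch.jit.load("pt_compiled_model/encoder-epoch-12-avg-1_torch_aie_bs1.pt")
+   model.eval()
+
+   # Dummy inputs matching the compiled shapes: (1, 100, 80) features and a single length of 100.
+   x = torch.zeros(1, 100, 80, dtype=torch.float32).to("npu:0")
+   x_lens = torch.tensor([100], dtype=torch.int64).to("npu:0")
+
+   stream = torch_aie.npu.Stream("npu:0")
+   with torch_aie.npu.stream(stream):
+       encoder_out, encoder_out_lens = model(x, x_lens)
+       stream.synchronize()
+
+   print(encoder_out.shape, encoder_out_lens.to("cpu"))
+   ```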
+
+5. Run the inference samples
+   1. Download the sample audio (the current test uses all-zero inputs instead)
+      ```shell
+      cd icefall/egs/librispeech/ASR/zipformer
+      wget https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
+      ```
+   2. Run inference
+      ```shell
+      # Note: replace "icefall-asr-zipformer-streaming-wenetspeech-20230615" with the actual path
+      cd icefall/icefall_pt
+      python ./pt_val_enc.py
+      python ./pt_val_dec.py
+      python ./pt_val_join.py
+      ```
+      After they finish, the inference results of the three models are written under icefall_pt/result,
+      and the latency/throughput of each compiled model is printed to the console.
+
+
+6. Accuracy test
+   Modify pretrained.py so that the encoder, decoder and joiner receive all-zero inputs matching the shapes used above
+   (encoder: (1, 100, 80) plus the constant length 100; decoder: (1, 2); joiner: (1, 512) and (1, 512)), and save the three outputs.
+   Then run:
+   ```shell
+   cd icefall/egs/librispeech/ASR/zipformer
+   python ./pretrained.py
+   ```
+   Compare each of the three reference outputs with the corresponding Torch-AIE output using cosine similarity. The expected results are:
+   ```shell
+   encoder: 0.97+
+   decoder: 0.99+
+   joiner:  0.99+
+   ```
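+   The cosine similarity itself can be computed with a small helper like the sketch below. The `.npy` file
+   names are placeholders for however the reference outputs of `pretrained.py` and the Torch-AIE outputs were
+   saved; adapt the loading code accordingly:
+   ```python
+   import numpy as np
+
+   def cosine_similarity(a, b):
+       """Cosine similarity between two arrays, flattened to 1-D."""
+       a = np.asarray(a, dtype=np.float32).reshape(-1)
+       b = np.asarray(b, dtype=np.float32).reshape(-1)
+       return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))
+
+   for name in ("encoder", "decoder", "joiner"):
+       ref = np.load(f"{name}_ref.npy")  # output of pretrained.py with all-zero inputs (placeholder file name)
+       aie = np.load(f"{name}_aie.npy")  # output of the compiled Torch-AIE model (placeholder file name)
+       print(name, cosine_similarity(ref, aie))
+   ```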
+7. Performance test
+   1. AIE model performance
+      The latency and throughput of the compiled models are printed when running pt_val_enc.py / pt_val_dec.py / pt_val_join.py (see step 5).
+   2. ONNX model performance
+      1. (Optional) When testing on a GPU, make sure CUDA and the GPU build of PyTorch are installed, and install onnxruntime-gpu as shown below:
+         ```shell
+         pip install onnxruntime-gpu
+         ```
+      After the ONNX performance test finishes, the performance of the three models is printed to the console, for example:
+      ```shell
+      Encoder latency: 964.9530 ms
+      Encoder throughput: 1.0363 fps
+      Decoder latency: 0.4806 ms
+      Decoder throughput: 2080.7143 fps
+      Joiner latency: 0.4994 ms
+      Joiner throughput: 2002.3092 fps
+      ```
+   3. OM performance
+      Starting from the ONNX models above, simplify them with onnxsim, adjust the decoder, and benchmark the converted OM models with ais_bench:
+      ```shell
+      onnxsim encoder-epoch-12-avg-1.onnx encoder-epoch-12-avg-1_sim.onnx
+      onnxsim decoder-epoch-12-avg-1.onnx decoder-epoch-12-avg-1_sim.onnx
+      python modify_decoder.py --onnx <path to the simplified decoder ONNX>
+      python3 -m ais_bench --model encoder_linux_aarch64.om --loop=2000
+      python3 -m ais_bench --model decoder.om --loop=2000
+      python3 -m ais_bench --model joiner.om --loop=2000
+      ```
+# Inference Performance and Accuracy<a name="ZH-CN_TOPIC_0000001172201573"></a>
+
+The Zipformer model consists of three sub-models (encoder, decoder and joiner); their performance is shown in the table below:
+
+| Model      | Torch-AIE plugin - 310P (latency / throughput) | T4 (latency / throughput) | A10 (latency / throughput) |
+|------------|------------------------------------------------|---------------------------|----------------------------|
+| encoder    | 20.4 ms / 49 fps                               | 24.7 ms / 40 fps          | 19 ms / 52 fps             |
+| decoder    | 0.19 ms / 5156 fps                             | 0.59 ms / 1684 fps        | 0.13 ms / 7604 fps         |
+| joiner     | 0.22 ms / 4448 fps                             | 0.13 ms / 7645 fps        | 0.11 ms / 9224 fps         |
+| end-to-end | 20.81 ms / 48 fps                              | 25.42 ms / 39 fps         | 19.24 ms / 52 fps          |
+
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py
new file mode 100644
index 0000000000..805f7405c8
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py
@@ -0,0 +1,625 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)

+"""
+This script exports a transducer model from PyTorch to ONNX.
+
+We use the pre-trained model from
+https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+2. 
Export the model to ONNX + +python3 ./egs/librispeech/ASR/zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ + --downsampling-factor "1,2,4,8,4,2" \ + --feedforward-dim "512,768,1024,1536,1024,768" \ + --num-heads "4,4,4,8,4,4" \ + --encoder-dim "192,256,384,512,384,256" \ + --query-head-dim 32 \ + --value-head-dim 12 \ + --pos-head-dim 4 \ + --pos-dim 48 \ + --encoder-unmasked-dim "192,192,256,256,256,192" \ + --cnn-module-kernel "31,31,15,15,15,31" \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --causal False \ + --chunk-size "16,32,64,-1" \ + --left-context-frames "64,128,256,-1" + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import k2 +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_model, get_params +from zipformer import Zipformer2 + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=12, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="/home/devkit/hanyifeng/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="/home/devkit/hanyifeng/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/data/lang_char/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__( + self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear + ): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of Zipformer.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + x_lens: + A 1-D tensor of shape (N,). Its dtype is torch.int64 + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - encoder_out_lens, A 1-D tensor of shape (N,) + """ + x, x_lens = self.encoder_embed(x, x_lens) + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, encoder_out_lens + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has two inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + + and it has two outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - encoder_out_lens, a tensor of shape (N,) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. 
+ opset_version: + The opset version to use. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + + encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) + encoder_model.save(str(encoder_filename).replace("onnx", "pt")) + torch.onnx.export( + encoder_model, + (x, x_lens), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["encoder_out", "encoder_out_lens"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "zipformer2", + "version": "1", + "model_author": "k2-fsa", + "comment": "non-streaming zipformer2", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + ts_decoder_model = torch.jit.trace(decoder_model, y) + ts_decoder_model.save(str(decoder_filename).replace("onnx", "pt")) + decoder_model = torch.jit.script(decoder_model) + + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + ts_joiner_model = torch.jit.trace(joiner_model, (projected_encoder_out, projected_decoder_out)) + ts_joiner_model.save(str(joiner_filename).replace("onnx", "pt")) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table["<blk>"] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + 
filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_embed=model.encoder_embed, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul", "Gather"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py new file mode 100644 index 
0000000000..bb341fca41 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py @@ -0,0 +1,62 @@ +# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse + +import torch +import torch_aie +from torch_aie import _enums + +def export_torch_aie(opt): + trace_model = torch.jit.load(opt.torch_script_path) + trace_model.eval() + + torch_aie.set_device(0) + inputs = [] + # inputs.append(torch_aie.Input([10, 2], dtype = torch_aie.dtype.INT64)) + inputs.append(torch_aie.Input([opt.batch_size, 2], dtype = torch_aie.dtype.INT64)) + + torchaie_model = torch_aie.compile( + trace_model, + inputs=inputs, + precision_policy=_enums.PrecisionPolicy.FP16, + truncate_long_and_double=True, + require_full_compilation=False, + allow_tensor_replace_int=False, + min_block_size=3, + torch_executed_ops=[], + soc_version='Ascend310P3', + optimization_level=0) + suffix = os.path.splitext(opt.torch_script_path)[-1] + saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt.batch_size}" + suffix + torchaie_model.save(os.path.join(opt.save_path, saved_name)) + print("torch aie tdnn compiled done. saved model is ", os.path.join(opt.save_path, saved_name)) + + +def parse_opt(): + parser = argparse.ArgumentParser() + parser.add_argument('--torch_script_path', type=str, default='../egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/decoder-epoch-12-avg-1.pt', help='trace model path') + parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version') + parser.add_argument('--batch_size', type=int, default=1, help='batch size') + parser.add_argument('--save_path', type=str, default='./pt_compiled_model/', help='compiled model path') + opt = parser.parse_args() + return opt + +def main(opt): + export_torch_aie(opt) + +if __name__ == '__main__': + opt = parse_opt() + main(opt) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py new file mode 100644 index 0000000000..05c9e64e06 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py @@ -0,0 +1,65 @@ +# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
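+
+# Compile the traced Zipformer encoder TorchScript model with Torch-AIE for Ascend NPU.
+# The model is loaded from --torch_script_path, compiled with FP16 precision for fixed input
+# shapes of (1, 100, 80) float32 features and a (1,) int64 length (note: the shapes are fixed
+# at batch size 1 here regardless of --batch_size), and saved to --save_path with a
+# "_torch_aie_bs{batch_size}" suffix.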
+ +import os +import argparse + +import torch +import torch_aie +from torch_aie import _enums + +def export_torch_aie(opt): + trace_model = torch.jit.load(opt.torch_script_path) + trace_model.eval() + + torch_aie.set_device(0) + inputs = [] + # x = torch.zeros(1, 100, 80, dtype=torch.float32) + # x_lens = torch.tensor([100], dtype=torch.int64) + inputs.append(torch_aie.Input([1, 100, 80], dtype = torch_aie.dtype.FLOAT)) + inputs.append(torch_aie.Input([1], dtype = torch_aie.dtype.INT64)) + + torchaie_model = torch_aie.compile( + trace_model, + inputs=inputs, + precision_policy=_enums.PrecisionPolicy.FP16, + # truncate_long_and_double=True, + # require_full_compilation=False, + # allow_tensor_replace_int=False, + # min_block_size=3, + # torch_executed_ops=[], + soc_version='Ascend310P3', + optimization_level=0 + ) + suffix = os.path.splitext(opt.torch_script_path)[-1] + saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt.batch_size}" + suffix + torchaie_model.save(os.path.join(opt.save_path, saved_name)) + print("torch aie tdnn compiled done. saved model is ", os.path.join(opt.save_path, saved_name)) + + +def parse_opt(): + parser = argparse.ArgumentParser() + parser.add_argument('--torch_script_path', type=str, default='../egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/encoder-epoch-12-avg-1.pt', help='trace model path') + parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version') + parser.add_argument('--batch_size', type=int, default=1, help='batch size') + parser.add_argument('--save_path', type=str, default='./pt_compiled_model/', help='compiled model path') + opt = parser.parse_args() + return opt + +def main(opt): + export_torch_aie(opt) + +if __name__ == '__main__': + opt = parse_opt() + main(opt) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py new file mode 100644 index 0000000000..ae0b1e5b47 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py @@ -0,0 +1,67 @@ +# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
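+
+# Compile the traced Zipformer joiner TorchScript model with Torch-AIE for Ascend NPU.
+# The joiner takes two (batch_size, 512) float32 inputs (projected encoder and decoder outputs);
+# the compiled model is saved to --save_path with a "_torch_aie_bs{batch_size}" suffix.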
+ +import os +import argparse + +import torch +import torch_aie +from torch_aie import _enums + +def export_torch_aie(opt): + trace_model = torch.jit.load(opt.torch_script_path) + trace_model.eval() + + torch_aie.set_device(0) + inputs = [] + # inputs.append(torch_aie.Input([11, 512], dtype = torch_aie.dtype.FLOAT)) + # inputs.append(torch_aie.Input([11, 512], dtype = torch_aie.dtype.FLOAT)) + inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT)) + inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT)) + # projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + # projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + # inputs.append(torch_aie.Input([10], dtype = torch_aie.dtype.INT64)) + + torchaie_model = torch_aie.compile( + trace_model, + inputs=inputs, + precision_policy=_enums.PrecisionPolicy.FP16, + truncate_long_and_double=True, + require_full_compilation=False, + allow_tensor_replace_int=False, + min_block_size=3, + torch_executed_ops=[], + soc_version='Ascend310P3', + optimization_level=0) + suffix = os.path.splitext(opt.torch_script_path)[-1] + saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt.batch_size}" + suffix + torchaie_model.save(os.path.join(opt.save_path, saved_name)) + print("torch aie tdnn compiled done. saved model is ", os.path.join(opt.save_path, saved_name)) + + +def parse_opt(): + parser = argparse.ArgumentParser() + parser.add_argument('--torch_script_path', type=str, default='../egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/joiner-epoch-12-avg-1.pt', help='trace model path') + parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version') + parser.add_argument('--batch_size', type=int, default=1, help='batch size') + parser.add_argument('--save_path', type=str, default='./pt_compiled_model/', help='compiled model path') + opt = parser.parse_args() + return opt + +def main(opt): + export_torch_aie(opt) + +if __name__ == '__main__': + opt = parse_opt() + main(opt) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py new file mode 100644 index 0000000000..d8934b5f29 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py @@ -0,0 +1,50 @@ +# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
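+
+# Inference helpers for the compiled decoder model: forward_infer() iterates the dataloader,
+# runs each (N, 2) int64 token-context batch on the NPU through pt_infer(), skips the first
+# 5 iterations as warm-up, and prints the average latency (ms) and throughput (fps).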
+ +import time +from tqdm import tqdm + +import torch +import torch_aie + + +def forward_infer(model, dataloader, batchsize, device_id): + pred_results = [] + inference_time = [] + loop_num = 0 + for snd in tqdm(dataloader): + result, inference_time = pt_infer(model, snd[0].to(torch.int64), device_id, loop_num, inference_time) + pred_results.append(result) + loop_num += 1 + + avg_inf_time = sum(inference_time) / len(inference_time) / batchsize * 1000 + print('performance(ms):', avg_inf_time) + print("throughput(fps): ", 1000 / avg_inf_time) + + return pred_results + +def pt_infer(model, input_li_1, device_id, loop_num, inference_time): + + input_npu_li_1 = input_li_1.to("npu:" + str(device_id)) + stream = torch_aie.npu.Stream("npu:" + str(device_id)) + with torch_aie.npu.stream(stream): + inf_start = time.time() + output_npu = model.forward(input_npu_li_1) + stream.synchronize() + inf_end = time.time() + inf = inf_end - inf_start + if loop_num >= 5: # use 5 step to warmup + inference_time.append(inf) + results = [output_npu.to("cpu")] + return results, inference_time diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py new file mode 100644 index 0000000000..61396050b6 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py @@ -0,0 +1,51 @@ +# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
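+
+# Inference helpers for the compiled encoder model: forward_infer() feeds (features, lengths)
+# pairs to the NPU through pt_infer(), skips the first 5 iterations as warm-up, prints the
+# average latency (ms) and throughput (fps), and copies the encoder outputs back to the CPU.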
+ +import time +from tqdm import tqdm +import numpy as np +import torch +import torch_aie + + +def forward_infer(model, dataloader, batchsize, device_id): + pred_results = [] + inference_time = [] + loop_num = 0 + for snd in tqdm(dataloader): + result, inference_time = pt_infer(model, snd[0].to(torch.float32), snd[1].to(torch.int64), device_id, loop_num, inference_time) + pred_results.append(result) + loop_num += 1 + + avg_inf_time = sum(inference_time) / len(inference_time) / batchsize * 1000 + print('performance(ms):', avg_inf_time) + print("throughput(fps): ", 1000 / avg_inf_time) + + return pred_results + +def pt_infer(model, input_li_1, input_li_2, device_id, loop_num, inference_time): + + input_npu_li_1 = input_li_1.to("npu:" + str(device_id)) + input_npu_li_2 = input_li_2.to("npu:" + str(device_id)) + stream = torch_aie.npu.Stream("npu:" + str(device_id)) + with torch_aie.npu.stream(stream): + inf_start = time.time() + output_npu = model.forward(input_npu_li_1, input_npu_li_2) + stream.synchronize() + inf_end = time.time() + inf = inf_end - inf_start + if loop_num >= 5: # use 5 step to warmup + inference_time.append(inf) + results = [output_npu[0].to("cpu"), output_npu[1].to("cpu")] + return results, inference_time diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py new file mode 100644 index 0000000000..45b8ec0f93 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py @@ -0,0 +1,51 @@ +# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
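+
+# Inference helpers for the compiled joiner model: forward_infer() feeds pairs of projected
+# encoder/decoder outputs to the NPU through pt_infer(), skips the first 5 iterations as
+# warm-up, and prints the average latency (ms) and throughput (fps).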
+ +import time +from tqdm import tqdm + +import torch +import torch_aie + + +def forward_infer(model, dataloader, batchsize, device_id): + pred_results = [] + inference_time = [] + loop_num = 0 + for snd in tqdm(dataloader): + result, inference_time = pt_infer(model, snd[0].to(torch.float32), snd[1].to(torch.float32), device_id, loop_num, inference_time) + pred_results.append(result) + loop_num += 1 + + avg_inf_time = sum(inference_time) / len(inference_time) / batchsize * 1000 + print('performance(ms):', avg_inf_time) + print("throughput(fps): ", 1000 / avg_inf_time) + + return pred_results + +def pt_infer(model, input_li_1, input_li_2, device_id, loop_num, inference_time): + + input_npu_li_1 = input_li_1.to("npu:" + str(device_id)) + input_npu_li_2 = input_li_2.to("npu:" + str(device_id)) + stream = torch_aie.npu.Stream("npu:" + str(device_id)) + with torch_aie.npu.stream(stream): + inf_start = time.time() + output_npu = model.forward(input_npu_li_1, input_npu_li_2) + stream.synchronize() + inf_end = time.time() + inf = inf_end - inf_start + if loop_num >= 5: # use 5 step to warmup + inference_time.append(inf) + results = [output_npu.to("cpu")] + return results, inference_time diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py new file mode 100644 index 0000000000..3106753822 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py @@ -0,0 +1,202 @@ +# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
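+
+# Benchmark/validation entry for the compiled decoder: builds a dummy dataset of all-zero
+# (2,) int64 token contexts, optionally recompiles the TorchScript model with Torch-AIE
+# (--need_compile), runs inference through model_pt_dec.forward_infer(), and, when
+# batch_size == 1 and multi == 1, saves the outputs as text files under --result_path.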
+ +import os +import copy +import argparse + +import torch +import torch_aie +import numpy as np +from torch_aie import _enums +from torch.utils.data import dataloader + +from model_pt_dec import forward_infer + + +class InfiniteDataLoader(dataloader.DataLoader): + """ Dataloader that reuses workers + + Uses same syntax as vanilla DataLoader + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) + self.iterator = super().__iter__() + + def __len__(self): + return len(self.batch_sampler.sampler) + + def __iter__(self): + for i in range(len(self)): + yield next(self.iterator) + + +class _RepeatSampler: + """ Sampler that repeats forever + + Args: + sampler (Sampler) + """ + + def __init__(self, sampler): + self.sampler = sampler + + def __iter__(self): + while True: + yield from iter(self.sampler) + +# def collate_fn(batch): +# """ +# data preprocessing +# """ +# def func(p): +# """ +# data size +# """ +# return p[0].size(1) + +# batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True) +# longest_sample = max(batch, key=func)[0] +# freq_size = longest_sample.size(0) +# minibatch_size = len(batch) +# max_seqlength = longest_sample.size(1) +# inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength) +# input_percentages = torch.FloatTensor(minibatch_size) +# for x in range(minibatch_size): +# sample = batch[x] +# tensor = sample[0] +# seq_length = tensor.size(1) +# inputs[x][0].narrow(1, 0, seq_length).copy_(tensor) +# input_percentages[x] = seq_length / float(max_seqlength) +# return inputs, input_percentages, [] + + +def get_dataloader(opt): + # with open(to_absolute_path(opt.label_file)) as label_file: + # labels = json.load(label_file) + + # dataset = SpectrogramDataset( + # audio_conf=DataConfig.spect, + # input_path=opt.data_file, + # labels=labels, + # normalize=True, + # aug_cfg=DataConfig.augmentation + # ) + # inputs, input_percentages, _ = collate_fn(dataset) + # input_sizes = input_percentages.mul_(int(inputs.size(3))).int() + # print(inputs[0]) + # print(input_sizes.tolist()) + + # datasets = [[inputs[i], input_sizes[i]] for i in range(len(input_sizes))] + x = torch.zeros(2, dtype=torch.int64) + # x_lens = torch.tensor([100], dtype=torch.int64) + # x_lens = 100 + datasets = [] + for i in range(20): + datasets.append([copy.deepcopy(x)]) + # print(datasets) + while len(datasets) % opt.batch_size != 0: + datasets.append(datasets[-1]) + m = 1 + datasets_orig = copy.deepcopy(datasets) + while m < opt.multi: + datasets += datasets_orig + m += 1 + + loader = InfiniteDataLoader # only DataLoader allows for attribute updates + print("OPT_BATCHSIZE: ", opt.batch_size) + return loader(datasets, + batch_size=opt.batch_size, + shuffle=False, + num_workers=1, + sampler=None, + pin_memory=True) + + +def save_tensor_arr_to_file(arr, file_path): + write_sen = "" + for m in arr: + for l in m: + for c in l: + write_sen += str(c) + " " + write_sen += "\n" + with open(file_path, "w", encoding='utf-8') as f: + f.write(write_sen) + +def save_size_to_file(size, file_path): + write_sen = "" + str(size) + " " + with open(file_path, "w", encoding='utf-8') as f: + f.write(write_sen) + +def main(opt): + # load model + model = torch.jit.load(opt.model) + batch_size = opt.batch_size + torch_aie.set_device(opt.device_id) + if opt.need_compile: + inputs = [] + inputs.append(torch_aie.Input([10, 2], dtype=torch_aie.dtype.INT64)) + # inputs.append(torch_aie.Input([opt.batch_size], 
dtype=torch_aie.dtype.INT32)) + + model = torch_aie.compile( + model, + inputs=inputs, + precision_policy=_enums.PrecisionPolicy.FP16, + truncate_long_and_double=True, + require_full_compilation=False, + allow_tensor_replace_int=False, + min_block_size=3, + torch_executed_ops=[], + soc_version='Ascend310P3', + optimization_level=0) + + dataloader = get_dataloader(opt) + pred_results = forward_infer(model, dataloader, batch_size, opt.device_id) + # for index, res in enumerate(pred_results): + # print(index, " ", res) + # print(res[0].shape) + + if opt.batch_size == 1 and opt.multi == 1: + result_path = opt.result_path + if(os.path.exists(result_path) == False): + os.makedirs(result_path) + for index, res in enumerate(pred_results): + # for i in range(batch_size): + # result_fname_0 = 'data' + str(index * batch_size + i + 1) + '_0.txt' + # result_fname_1 = 'data' + str(index * batch_size + i + 1) + '_1.txt' + result_fname_0 = 'data' + str(index) + '_0.txt' + # result_fname_1 = 'data' + str(index) + '_1.txt' + # res = np.array(res) + # save_tensor_arr_to_file(np.array(res[0][i]), os.path.join(result_path, result_fname_0)) + # save_size_to_file(res[1].numpy()[i], os.path.join(result_path, result_fname_1)) + save_tensor_arr_to_file(np.array(res), os.path.join(result_path, result_fname_0)) + # save_size_to_file(res[1].numpy()[0], os.path.join(result_path, result_fname_1)) + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='DeepSpeech2 offline model inference.') + parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version') + parser.add_argument('--model', type=str, default="./pt_compiled_model/decoder-epoch-12-avg-1_torch_aie_bs1.pt", help='ts model path') + parser.add_argument('--need_compile', action="store_true", help='if the loaded model needs to be compiled or not') + parser.add_argument('--batch_size', type=int, default=1, help='batch size') + parser.add_argument('--device_id', type=int, default=0, help='device id') + parser.add_argument('--data_file', default='./deepspeech.pytorch/data/an4_test_manifest.json') + parser.add_argument('--label_file', default='./deepspeech.pytorch/labels.json') + parser.add_argument('--result_path', default='result/decoder') + parser.add_argument('--multi', type=int, default=1, help='multiples of dataset replication for enough infer loop. if multi != 1, the pred result will not be stored.') + opt = parser.parse_args() + main(opt) diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py new file mode 100644 index 0000000000..496e767fe0 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py @@ -0,0 +1,202 @@ +# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
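+
+# Benchmark/validation entry for the compiled encoder: builds a dummy dataset of all-zero
+# (100, 80) float32 feature frames with length 100, optionally recompiles the model with
+# Torch-AIE (--need_compile), runs inference through model_pt_enc.forward_infer(), and,
+# when batch_size == 1 and multi == 1, saves the outputs as text files under --result_path.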
+ +import os +import copy +import argparse + +import torch +import torch_aie +import numpy as np +from torch_aie import _enums +from torch.utils.data import dataloader + +from model_pt_enc import forward_infer + + +class InfiniteDataLoader(dataloader.DataLoader): + """ Dataloader that reuses workers + + Uses same syntax as vanilla DataLoader + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) + self.iterator = super().__iter__() + + def __len__(self): + return len(self.batch_sampler.sampler) + + def __iter__(self): + for i in range(len(self)): + yield next(self.iterator) + + +class _RepeatSampler: + """ Sampler that repeats forever + + Args: + sampler (Sampler) + """ + + def __init__(self, sampler): + self.sampler = sampler + + def __iter__(self): + while True: + yield from iter(self.sampler) + +# def collate_fn(batch): +# """ +# data preprocessing +# """ +# def func(p): +# """ +# data size +# """ +# return p[0].size(1) + +# batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True) +# longest_sample = max(batch, key=func)[0] +# freq_size = longest_sample.size(0) +# minibatch_size = len(batch) +# max_seqlength = longest_sample.size(1) +# inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength) +# input_percentages = torch.FloatTensor(minibatch_size) +# for x in range(minibatch_size): +# sample = batch[x] +# tensor = sample[0] +# seq_length = tensor.size(1) +# inputs[x][0].narrow(1, 0, seq_length).copy_(tensor) +# input_percentages[x] = seq_length / float(max_seqlength) +# return inputs, input_percentages, [] + + +def get_dataloader(opt): + # with open(to_absolute_path(opt.label_file)) as label_file: + # labels = json.load(label_file) + + # dataset = SpectrogramDataset( + # audio_conf=DataConfig.spect, + # input_path=opt.data_file, + # labels=labels, + # normalize=True, + # aug_cfg=DataConfig.augmentation + # ) + # inputs, input_percentages, _ = collate_fn(dataset) + # input_sizes = input_percentages.mul_(int(inputs.size(3))).int() + # print(inputs[0]) + # print(input_sizes.tolist()) + + # datasets = [[inputs[i], input_sizes[i]] for i in range(len(input_sizes))] + x = torch.zeros(100, 80, dtype=torch.float32) + # x_lens = torch.tensor([100], dtype=torch.int64) + x_lens = 100 + datasets = [] + for i in range(20): + datasets.append([copy.deepcopy(x), copy.deepcopy(x_lens)]) + # print(datasets) + while len(datasets) % opt.batch_size != 0: + datasets.append(datasets[-1]) + m = 1 + datasets_orig = copy.deepcopy(datasets) + while m < opt.multi: + datasets += datasets_orig + m += 1 + + loader = InfiniteDataLoader # only DataLoader allows for attribute updates + print("OPT_BATCHSIZE: ", opt.batch_size) + return loader(datasets, + batch_size=opt.batch_size, + shuffle=False, + num_workers=1, + sampler=None, + pin_memory=True) + + +def save_tensor_arr_to_file(arr, file_path): + write_sen = "" + for m in arr: + for l in m: + for c in l: + write_sen += str(c) + " " + write_sen += "\n" + with open(file_path, "w", encoding='utf-8') as f: + f.write(write_sen) + +def save_size_to_file(size, file_path): + write_sen = "" + str(size) + " " + with open(file_path, "w", encoding='utf-8') as f: + f.write(write_sen) + +def main(opt): + # load model + model = torch.jit.load(opt.model) + batch_size = opt.batch_size + torch_aie.set_device(opt.device_id) + if opt.need_compile: + inputs = [] + inputs.append(torch_aie.Input([opt.batch_size, 100, 80], dtype=torch_aie.dtype.FLOAT)) + 
inputs.append(torch_aie.Input([opt.batch_size], dtype=torch_aie.dtype.INT32)) + + model = torch_aie.compile( + model, + inputs=inputs, + precision_policy=_enums.PrecisionPolicy.FP16, + truncate_long_and_double=True, + require_full_compilation=False, + allow_tensor_replace_int=False, + min_block_size=3, + torch_executed_ops=[], + soc_version='Ascend310P3', + optimization_level=0) + + dataloader = get_dataloader(opt) + pred_results = forward_infer(model, dataloader, batch_size, opt.device_id) + for index, res in enumerate(pred_results): + print(index, " ", res) + print(res[0].shape) + + if opt.batch_size == 1 and opt.multi == 1: + result_path = opt.result_path + if(os.path.exists(result_path) == False): + os.makedirs(result_path) + for index, res in enumerate(pred_results): + # for i in range(batch_size): + # result_fname_0 = 'data' + str(index * batch_size + i + 1) + '_0.txt' + # result_fname_1 = 'data' + str(index * batch_size + i + 1) + '_1.txt' + result_fname_0 = 'data' + str(index) + '_0.txt' + result_fname_1 = 'data' + str(index) + '_1.txt' + # res = np.array(res) + # save_tensor_arr_to_file(np.array(res[0][i]), os.path.join(result_path, result_fname_0)) + # save_size_to_file(res[1].numpy()[i], os.path.join(result_path, result_fname_1)) + save_tensor_arr_to_file(np.array(res[0]), os.path.join(result_path, result_fname_0)) + save_size_to_file(res[1].numpy()[0], os.path.join(result_path, result_fname_1)) + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='DeepSpeech2 offline model inference.') + parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version') + parser.add_argument('--model', type=str, default="./pt_compiled_model/encoder-epoch-12-avg-1_torch_aie_bs1.pt", help='ts model path') + parser.add_argument('--need_compile', action="store_true", help='if the loaded model needs to be compiled or not') + parser.add_argument('--batch_size', type=int, default=1, help='batch size') + parser.add_argument('--device_id', type=int, default=0, help='device id') + parser.add_argument('--data_file', default='./deepspeech.pytorch/data/an4_test_manifest.json') + parser.add_argument('--label_file', default='./deepspeech.pytorch/labels.json') + parser.add_argument('--result_path', default='result/encoder') + parser.add_argument('--multi', type=int, default=1, help='multiples of dataset replication for enough infer loop. if multi != 1, the pred result will not be stored.') + opt = parser.parse_args() + main(opt) diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py new file mode 100644 index 0000000000..a4fe80f4de --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py @@ -0,0 +1,203 @@ +# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
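+
+# Benchmark/validation entry for the compiled joiner: builds a dummy dataset of paired
+# all-zero (512,) float32 vectors, optionally recompiles the model with Torch-AIE
+# (--need_compile), runs inference through model_pt_join.forward_infer(), and, when
+# batch_size == 1 and multi == 1, saves the outputs as text files under --result_path.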
+ +import os +import copy +import argparse + +import torch +import torch_aie +import numpy as np +from torch_aie import _enums +from torch.utils.data import dataloader + +from model_pt_join import forward_infer + + +class InfiniteDataLoader(dataloader.DataLoader): + """ Dataloader that reuses workers + + Uses same syntax as vanilla DataLoader + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) + self.iterator = super().__iter__() + + def __len__(self): + return len(self.batch_sampler.sampler) + + def __iter__(self): + for i in range(len(self)): + yield next(self.iterator) + + +class _RepeatSampler: + """ Sampler that repeats forever + + Args: + sampler (Sampler) + """ + + def __init__(self, sampler): + self.sampler = sampler + + def __iter__(self): + while True: + yield from iter(self.sampler) + +# def collate_fn(batch): +# """ +# data preprocessing +# """ +# def func(p): +# """ +# data size +# """ +# return p[0].size(1) + +# batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True) +# longest_sample = max(batch, key=func)[0] +# freq_size = longest_sample.size(0) +# minibatch_size = len(batch) +# max_seqlength = longest_sample.size(1) +# inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength) +# input_percentages = torch.FloatTensor(minibatch_size) +# for x in range(minibatch_size): +# sample = batch[x] +# tensor = sample[0] +# seq_length = tensor.size(1) +# inputs[x][0].narrow(1, 0, seq_length).copy_(tensor) +# input_percentages[x] = seq_length / float(max_seqlength) +# return inputs, input_percentages, [] + + +def get_dataloader(opt): + # with open(to_absolute_path(opt.label_file)) as label_file: + # labels = json.load(label_file) + + # dataset = SpectrogramDataset( + # audio_conf=DataConfig.spect, + # input_path=opt.data_file, + # labels=labels, + # normalize=True, + # aug_cfg=DataConfig.augmentation + # ) + # inputs, input_percentages, _ = collate_fn(dataset) + # input_sizes = input_percentages.mul_(int(inputs.size(3))).int() + # print(inputs[0]) + # print(input_sizes.tolist()) + + # datasets = [[inputs[i], input_sizes[i]] for i in range(len(input_sizes))] + x = torch.zeros(512, dtype=torch.float32) + y = torch.zeros(512, dtype=torch.float32) + # x_lens = torch.tensor([100], dtype=torch.int64) + # x_lens = 100 + datasets = [] + for i in range(20): + datasets.append([copy.deepcopy(x), copy.deepcopy(y)]) + # print(datasets) + while len(datasets) % opt.batch_size != 0: + datasets.append(datasets[-1]) + m = 1 + datasets_orig = copy.deepcopy(datasets) + while m < opt.multi: + datasets += datasets_orig + m += 1 + + loader = InfiniteDataLoader # only DataLoader allows for attribute updates + print("OPT_BATCHSIZE: ", opt.batch_size) + return loader(datasets, + batch_size=opt.batch_size, + shuffle=False, + num_workers=1, + sampler=None, + pin_memory=True) + + +def save_tensor_arr_to_file(arr, file_path): + write_sen = "" + for m in arr: + for l in m: + for c in l: + write_sen += str(c) + " " + write_sen += "\n" + with open(file_path, "w", encoding='utf-8') as f: + f.write(write_sen) + +def save_size_to_file(size, file_path): + write_sen = "" + str(size) + " " + with open(file_path, "w", encoding='utf-8') as f: + f.write(write_sen) + +def main(opt): + # load model + model = torch.jit.load(opt.model) + batch_size = opt.batch_size + torch_aie.set_device(opt.device_id) + if opt.need_compile: + inputs = [] + inputs.append(torch_aie.Input([opt.batch_size, 512], 
dtype = torch_aie.dtype.FLOAT))
+        inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT))
+
+        model = torch_aie.compile(
+            model,
+            inputs=inputs,
+            precision_policy=_enums.PrecisionPolicy.FP16,
+            truncate_long_and_double=True,
+            require_full_compilation=False,
+            allow_tensor_replace_int=False,
+            min_block_size=3,
+            torch_executed_ops=[],
+            soc_version='Ascend310P3',
+            optimization_level=0)
+
+    dataloader = get_dataloader(opt)
+    pred_results = forward_infer(model, dataloader, batch_size, opt.device_id)
+    # for index, res in enumerate(pred_results):
+    #     print(index, " ", res)
+    #     print(res[0].shape)
+
+    if opt.batch_size == 1 and opt.multi == 1:
+        result_path = opt.result_path
+        if not os.path.exists(result_path):
+            os.makedirs(result_path)
+        for index, res in enumerate(pred_results):
+            # for i in range(batch_size):
+            #     result_fname_0 = 'data' + str(index * batch_size + i + 1) + '_0.txt'
+            #     result_fname_1 = 'data' + str(index * batch_size + i + 1) + '_1.txt'
+            result_fname_0 = 'data' + str(index) + '_0.txt'
+            # result_fname_1 = 'data' + str(index) + '_1.txt'
+            # res = np.array(res)
+            # save_tensor_arr_to_file(np.array(res[0][i]), os.path.join(result_path, result_fname_0))
+            # save_size_to_file(res[1].numpy()[i], os.path.join(result_path, result_fname_1))
+            save_tensor_arr_to_file(np.array(res), os.path.join(result_path, result_fname_0))
+            # save_size_to_file(res[1].numpy()[0], os.path.join(result_path, result_fname_1))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Zipformer joiner offline model inference.')
+    parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
+    parser.add_argument('--model', type=str, default="./pt_compiled_model/joiner-epoch-12-avg-1_torch_aie_bs1.pt", help='ts model path')
+    parser.add_argument('--need_compile', action="store_true", help='if the loaded model needs to be compiled or not')
+    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
+    parser.add_argument('--device_id', type=int, default=0, help='device id')
+    parser.add_argument('--data_file', default='./deepspeech.pytorch/data/an4_test_manifest.json')
+    parser.add_argument('--label_file', default='./deepspeech.pytorch/labels.json')
+    parser.add_argument('--result_path', default='result/joiner')
+    parser.add_argument('--multi', type=int, default=1, help='multiples of dataset replication for enough infer loop. 
if multi != 1, the pred result will not be stored.') + opt = parser.parse_args() + main(opt) diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py new file mode 100644 index 0000000000..866216820f --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py @@ -0,0 +1,25 @@ +from argparse import ArgumentParser + +from auto_optimizer import OnnxGraph + + +def main(): + parser = ArgumentParser() + parser.add_argument("--onnx", type=str, required=True) + args = parser.parse_args() + + graph = OnnxGraph.parse(args.onnx) + graph.remove("/decoder/Clip") + gather = graph["/decoder/embedding/Gather"] + gather.inputs[1] = "y" + graph.update_map() + graph.infershape() + + g_sim = graph.simplify() + save_path = args.onnx.replace(".onnx", "_modified.onnx") + g_sim.save(save_path) + print("Modified model saved to ", save_path) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py new file mode 100644 index 0000000000..163eb5f6b2 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py @@ -0,0 +1,2438 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import math +import random +import warnings +from typing import List, Optional, Tuple, Union + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. +) +from scaling import ( + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. +) +from scaling import ( + ActivationDropoutAndLinear, + Balancer, + BiasNorm, + ChunkCausalDepthwiseConv1d, + Dropout2, + FloatLike, + ScheduledFloat, + Whiten, + convert_num_channels, + limit_param_value, + penalize_abs_values_gt, + softmax, +) +from torch import Tensor, nn + + +class Zipformer2(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + output_downsampling_factor (int): how much to downsample at the output. Note: + we also downsample by a factor of 2 in the Conv2dSubsampling encoder. + You should probably leave this at 2. + downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. + Note: this is in addition to the downsampling factor of 2 that is applied in + the frontend (self.encoder_embed). + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack. 
+ num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of + the encoder stacks for purposes of per-frame dropout (recommend 256 for + now). + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + value_head_dim (int or Tuple[int]): dimension of value in each attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + + dropout (float): dropout rate + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. + causal (bool): if True, support chunkwise causal convolution. This should + not hurt WER as no modeling power is lost, but the convolution modules will be + slightly slower and use more memory. Enables use of the chunk_size and + left_context_chunks options in forward(), which simulates streaming + decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. + """ + + def __init__( + self, + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (2, 4), + encoder_dim: Union[int, Tuple[int]] = 384, + num_encoder_layers: Union[int, Tuple[int]] = 4, + encoder_unmasked_dim: Union[int, Tuple[int]] = 256, + query_head_dim: Union[int, Tuple[int]] = 24, + pos_head_dim: Union[int, Tuple[int]] = 4, + value_head_dim: Union[int, Tuple[int]] = 12, + num_heads: Union[int, Tuple[int]] = 8, + feedforward_dim: Union[int, Tuple[int]] = 1536, + cnn_module_kernel: Union[int, Tuple[int]] = 31, + pos_dim: int = 192, + dropout: FloatLike = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], + ) -> None: + super(Zipformer2, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + def _to_tuple(x): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + if isinstance(x, int): + x = (x,) + if len(x) == 1: + x = x * len(downsampling_factor) + else: + assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = downsampling_factor # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple + self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( + encoder_unmasked_dim + ) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers) + self.num_encoder_layers = num_encoder_layers + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) + pos_head_dim = _to_tuple(pos_head_dim) + 
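+        # e.g. if downsampling_factor has length 6, a scalar argument such as
+        # num_heads=4 is broadcast by _to_tuple() to (4, 4, 4, 4, 4, 4),
+        # one value per encoder stack.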
self.num_heads = num_heads = _to_tuple(num_heads) + feedforward_dim = _to_tuple(feedforward_dim) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + + for u, d in zip(encoder_unmasked_dim, encoder_dim): + assert u <= d + + # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + encoders = [] + + num_encoders = len(downsampling_factor) + for i in range(num_encoders): + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_dim, + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + pos_head_dim=pos_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_dim=feedforward_dim[i], + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[i], + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + pos_dim=pos_dim, + dropout=dropout, + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), + ) + + if downsampling_factor[i] != 1: + encoder = DownsampledZipformer2Encoder( + encoder, + dim=encoder_dim[i], + downsample=downsampling_factor[i], + dropout=dropout, + ) + + encoders.append(encoder) + + self.encoders = nn.ModuleList(encoders) + + self.downsample_output = SimpleDownsample( + max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout + ) + + def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (1, batch_size, encoder_dims0) + """ + num_encoders = len(self.encoder_dim) + if not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dim[0] == _encoder_dims0, ( + self.encoder_dim[0], + _encoder_dims0, + ) + + feature_mask_dropout_prob = 0.125 + + # mask1 shape: (1, batch_size, 1) + mask1 = ( + torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob + ).to(x.dtype) + + # mask2 has additional sequences masked, about twice the number. 
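+        # (mask2 is mask1 AND-ed with a second independent random mask, so every
+        # sequence dropped by mask1 is also dropped by mask2.)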
+ mask2 = torch.logical_and( + mask1, + ( + torch.rand(1, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype), + ) + + # dim: (1, batch_size, 2) + mask = torch.cat((mask1, mask2), dim=-1) + + feature_masks = [] + for i in range(num_encoders): + channels = self.encoder_dim[i] + feature_mask = torch.ones( + 1, batch_size, channels, dtype=x.dtype, device=x.device + ) + u1 = self.encoder_unmasked_dim[i] + u2 = u1 + (channels - u1) // 2 + + feature_mask[:, :, u1:u2] *= mask[..., 0:1] + feature_mask[:, :, u2:] *= mask[..., 1:2] + + feature_masks.append(feature_mask) + + return feature_masks + + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. + """ + if not self.causal: + return -1, -1 + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.chunk_size) == 1, self.chunk_size + chunk_size = self.chunk_size[0] + else: + chunk_size = random.choice(self.chunk_size) + + if chunk_size == -1: + left_context_chunks = -1 + else: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + assert len(self.left_context_frames) == 1, self.left_context_frames + left_context_frames = self.left_context_frames[0] + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + if left_context_chunks == 0: + left_context_chunks = 1 + + return chunk_size, left_context_chunks + + def forward( + self, + x: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + outputs = [] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + feature_masks = [1.0] * len(self.encoder_dim) + else: + feature_masks = self.get_feature_masks(x) + + chunk_size, left_context_chunks = self.get_chunk_info() + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # Not support exporting a model for simulating streaming decoding + attn_mask = None + else: + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + + for i, module in enumerate(self.encoders): + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x = module( + x, + chunk_size=chunk_size, + feature_mask=feature_masks[i], + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + attn_mask=attn_mask, + ) + outputs.append(x) + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. 
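+        # e.g. x_lens = 23 gives lengths = (23 + 1) // 2 = 12 after this final
+        # downsampling by output_downsampling_factor == 2.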
+ assert self.output_downsampling_factor == 2, self.output_downsampling_factor + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths + + def _get_attn_mask( + self, x: Tensor, chunk_size: int, left_context_chunks: int + ) -> Optional[Tensor]: + """ + Return None if chunk_size == -1, else return attention mask of shape + (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True + means a masked position. + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). + chunk_size: chunk size, must divide + """ + if chunk_size <= 0: + return None + assert all(chunk_size % d == 0 for d in self.downsampling_factor) + if left_context_chunks >= 0: + num_encoders = len(self.encoder_dim) + assert all( + chunk_size * left_context_chunks + >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + for i in range(num_encoders) + ) + else: + left_context_chunks = 1000000 + + seq_len = x.shape[0] + + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + # c is chunk index for each frame, shape (seq_len,) + if torch.jit.is_scripting() or torch.jit.is_tracing(): + c = t // chunk_size + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c = t // chunk_size + src_c = c + tgt_c = c.unsqueeze(-1) + + attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) + if __name__ == "__main__": + logging.info(f"attn_mask = {attn_mask}") + return attn_mask + + def _get_full_dim_output(self, outputs: List[Tensor]): + num_encoders = len(self.encoder_dim) + assert len(outputs) == num_encoders + output_dim = max(self.encoder_dim) + output_pieces = [outputs[-1]] + cur_dim = self.encoder_dim[-1] + for i in range(num_encoders - 2, -1, -1): + d = self.encoder_dim[i] + if d > cur_dim: + this_output = outputs[i] + output_pieces.append(this_output[..., cur_dim:d]) + cur_dim = d + assert cur_dim == output_dim + return torch.cat(output_pieces, dim=-1) + + def streaming_forward( + self, + x: Tensor, + x_lens: Tensor, + states: List[Tensor], + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + states: list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. 
+ - updated states + """ + outputs = [] + new_states = [] + layer_offset = 0 + + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + ds = self.downsampling_factor[i] + x = convert_num_channels(x, self.encoder_dim[i]) + + x, new_layer_states = module.streaming_forward( + x, + states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + left_context_len=self.left_context_frames[0] // ds, + src_key_padding_mask=src_key_padding_mask[..., ::ds], + ) + layer_offset += num_layers + outputs.append(x) + new_states += new_layer_states + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting() or torch.jit.is_tracing(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[Tensor]: + """Get initial states. + + A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + """ + states = [] + for i, module in enumerate(self.encoders): + num_layers = module.num_layers + embed_dim = self.encoder_dim[i] + ds = self.downsampling_factor[i] + num_heads = self.num_heads[i] + key_dim = self.query_head_dim[i] * num_heads + value_dim = self.value_head_dim[i] * num_heads + downsample_left = self.left_context_frames[0] // ds + nonlin_attn_head_dim = 3 * embed_dim // 4 + conv_left_pad = self.cnn_module_kernel[i] // 2 + for layer in range(num_layers): + cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( + device + ) + cached_nonlin_attn = torch.zeros( + 1, batch_size, downsample_left, nonlin_attn_head_dim + ).to(device) + cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( + device + ) + cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( + device + ) + states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + return states + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) + + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + +class Zipformer2EncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. 
+ + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + value_head_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + cnn_module_kernel: int = 31, + causal: bool = False, + attention_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + conv_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 + ), + const_attention_rate: FloatLike = ScheduledFloat( + (0.0, 0.25), (4000.0, 0.025), default=0 + ), + ff2_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + ff3_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) + ), + bypass_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.5), (4000.0, 0.02), default=0 + ), + ) -> None: + super(Zipformer2EncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass; see its default values. + self.bypass = BypassModule( + embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 + ) + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + # an additional skip probability that applies to ConvModule to stop it from + # contributing too much early on. + self.conv_skip_rate = copy.deepcopy(conv_skip_rate) + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. 
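+        # ff3_skip_rate plays the same role for the third feedforward module.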
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, + pos_dim=pos_dim, + num_heads=num_heads, + query_head_dim=query_head_dim, + pos_head_dim=pos_head_dim, + dropout=0.0, + ) + + self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) + + self.feed_forward1 = FeedforwardModule( + embed_dim, (feedforward_dim * 3) // 4, dropout + ) + + self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + + self.feed_forward3 = FeedforwardModule( + embed_dim, (feedforward_dim * 5) // 4, dropout + ) + + self.nonlin_attention = NonlinAttention( + embed_dim, hidden_channels=3 * embed_dim // 4 + ) + + self.conv_module1 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + self.conv_module2 = ConvolutionModule( + embed_dim, cnn_module_kernel, causal=causal + ) + + # TODO: remove it + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + + self.norm = BiasNorm(embed_dim) + + self.balancer1 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.2, + max_abs=4.0, + ) + + # balancer for output of NonlinAttentionModule + self.balancer_na = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), + prob=0.05, # out of concern for memory usage + ) + + # balancer for output of feedforward2, prevent it from staying too + # small. give this a very small probability, even at the start of + # training, it's to fix a rare problem and it's OK to fix it slowly. + self.balancer_ff2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + max_abs=2.0, + prob=0.05, + ) + + self.balancer_ff3 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + max_abs=4.0, + prob=0.05, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.balancer2 = Balancer( + embed_dim, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + min_abs=0.1, + max_abs=4.0, + ) + + def get_sequence_dropout_mask( + self, x: Tensor, dropout_rate: float + ) -> Optional[Tensor]: + if ( + dropout_rate == 0.0 + or not self.training + or torch.jit.is_scripting() + or torch.jit.is_tracing() + ): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + return mask + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to x. + x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + chunk_size: int = -1, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). 
+ pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + if torch.jit.is_scripting() or torch.jit.is_tracing(): + attention_skip_rate = 0.0 + else: + attention_skip_rate = ( + float(self.attention_skip_rate) if self.training else 0.0 + ) + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_mask=attn_mask, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + self_attn_dropout_mask = self.get_sequence_dropout_mask( + src, attention_skip_rate + ) + + selected_attn_weights = attn_weights[0:1] + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif not self.training and random.random() < float(self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. + # only need the mask, can just use the 1st one and expand later + selected_attn_weights = selected_attn_weights[0:1] + selected_attn_weights = (selected_attn_weights > 0.0).to( + selected_attn_weights.dtype + ) + selected_attn_weights = selected_attn_weights * ( + 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) + ) + + na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) + + src = src + ( + na if self_attn_dropout_mask is None else na * self_attn_dropout_mask + ) + + self_attn = self.self_attn1(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.conv_module1( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff2_skip_rate = 0.0 + else: + ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate + ) + + # bypass in the middle of the layer. 
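+        # bypass_mid returns src_orig + (src - src_orig) * scale, a learned
+        # per-channel interpolation between the block input and the current value
+        # (see BypassModule.forward below).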
+ src = self.bypass_mid(src_orig, src) + + self_attn = self.self_attn2(src, attn_weights) + + src = src + ( + self_attn + if self_attn_dropout_mask is None + else self_attn * self_attn_dropout_mask + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + conv_skip_rate = 0.0 + else: + conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.conv_module2( + src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + ), + conv_skip_rate, + ) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + ff3_skip_rate = 0.0 + else: + ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 + src = src + self.sequence_dropout( + self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate + ) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + return src + + def streaming_forward( + self, + src: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + cached_nonlin_attn: Tensor, + cached_val1: Tensor, + cached_val2: Tensor, + cached_conv1: Tensor, + cached_conv2: Tensor, + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """Pass the input through the encoder layer in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or + (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + cached_val1: cached left context for the first attention module, + of shape (left_context_len, batch_size, value_dim) + cached_val2: cached left context for the second attention module, + of shape (left_context_len, batch_size, value_dim) + cached_conv1: cached left context for the first convolution module, + of shape (batch_size, channels, left_pad) + cached_conv2: cached left context for the second convolution module, + of shape (batch_size, channels, left_pad) + left_context_len: number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. 
+ + Returns: + - x, with the same shape as src + - updated cached_key + - updated cached_nonlin_attn + - updated cached_val1 + - updated cached_val2 + - updated cached_conv1 + - updated cached_conv2 + """ + src_orig = src + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights, cached_key = self.self_attn_weights.streaming_forward( + src, + pos_emb=pos_emb, + cached_key=cached_key, + left_context_len=left_context_len, + key_padding_mask=src_key_padding_mask, + ) + + src = src + self.feed_forward1(src) + + na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( + src, + attn_weights[0:1], + cached_x=cached_nonlin_attn, + left_context_len=left_context_len, + ) + src = src + na + + self_attn, cached_val1 = self.self_attn1.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val1, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv1 = self.conv_module1.streaming_forward( + src, + cache=cached_conv1, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward2(src) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn, cached_val2 = self.self_attn2.streaming_forward( + src, + attn_weights=attn_weights, + cached_val=cached_val2, + left_context_len=left_context_len, + ) + src = src + self_attn + + src_conv, cached_conv2 = self.conv_module2.streaming_forward( + src, + cache=cached_conv2, + src_key_padding_mask=src_key_padding_mask[:, left_context_len:], + ) + src = src + src_conv + + src = src + self.feed_forward3(src) + + src = self.norm(src) + + src = self.bypass(src_orig, src) + + return ( + src, + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) + + +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + pos_dim: the dimension for the relative positional encoding + + Examples:: + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + pos_dim: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, + ) -> None: + super().__init__() + self.encoder_pos = CompactRelPositionalEncoding( + pos_dim, dropout_rate=0.15, length_factor=1.0 + ) + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert 0 <= warmup_begin <= warmup_end + + delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin # interpreted as a training batch index + for i in range(num_layers): + cur_end = cur_begin + delta + self.layers[i].bypass.skip_rate = ScheduledFloat( + (cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0, + ) + cur_begin = cur_end + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. 
+ + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + pos_emb = self.encoder_pos(src) + output = src + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + chunk_size=chunk_size, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + output = output * feature_mask + + return output + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape + (batch_size, left_context_len + seq_len); True means masked position. + May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + pos_emb = self.encoder_pos(src, left_context_len) + output = src + + new_states = [] + for i, mod in enumerate(self.layers): + ( + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ) = states[i * 6 : (i + 1) * 6] + ( + output, + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ) = mod.streaming_forward( + output, + pos_emb, + cached_key=cached_key, + cached_nonlin_attn=cached_nonlin_attn, + cached_val1=cached_val1, + cached_val2=cached_val2, + cached_conv1=cached_conv1, + cached_conv2=cached_conv2, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + new_states += [ + new_cached_key, + new_cached_nonlin_attn, + new_cached_val1, + new_cached_val2, + new_cached_conv1, + new_cached_conv2, + ] + + return output, new_states + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. 
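+
+    Concretely, forward() returns src_orig + (src - src_orig) * bypass_scale, where
+    bypass_scale is a learned per-channel parameter that is limited to the range
+    [scale_min, scale_max] during training.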
+ """ + + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0, + ): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 correponds to bypassing + # this module. + if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value( + self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) + ) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + mask = ( + torch.rand((batch_size, 1), device=ans.device) + < straight_through_rate + ) + ans = torch.maximum(ans, mask.to(ans.dtype)) + return ans + + def forward(self, src_orig: Tensor, src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + bypass_scale = self._get_bypass_scale(src.shape[1]) + return src_orig + (src - src_orig) * bypass_scale + + +class DownsampledZipformer2Encoder(nn.Module): + r""" + DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, + after convolutional downsampling, and then upsampled again at the output, and combined + with the origin input, so that the output has the same shape as the input. + """ + + def __init__( + self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike + ): + super(DownsampledZipformer2Encoder, self).__init__() + self.downsample_factor = downsample + self.downsample = SimpleDownsample(dim, downsample, dropout) + self.num_layers = encoder.num_layers + self.encoder = encoder + self.upsample = SimpleUpsample(dim, downsample) + self.out_combiner = BypassModule(dim, straight_through_rate=0) + + def forward( + self, + src: Tensor, + chunk_size: int = -1, + feature_mask: Union[Tensor, float] = 1.0, + attn_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Downsample, go through encoder, upsample. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). + True means masked position. May be None. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. 
+ """ + src_orig = src + src = self.downsample(src) + ds = self.downsample_factor + if attn_mask is not None: + attn_mask = attn_mask[::ds, ::ds] + + src = self.encoder( + src, + chunk_size=chunk_size // ds, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src) + + def streaming_forward( + self, + src: Tensor, + states: List[Tensor], + left_context_len: int, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, List[Tensor]]: + r"""Downsample, go through encoder, upsample, in streaming forward mode. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is + (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + left_context_len: Number of left context frames. + src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); + True means masked position. May be None. + + Returns: + - output, a Tensor with the same shape as src. + - updated states + """ + src_orig = src + src = self.downsample(src) + + src, new_states = self.encoder.streaming_forward( + src, + states=states, + left_context_len=left_context_len, + src_key_padding_mask=src_key_padding_mask, + ) + src = self.upsample(src) + # remove any extra frames that are not a multiple of downsample_factor + src = src[: src_orig.shape[0]] + + return self.out_combiner(src_orig, src), new_states + + +class SimpleDownsample(torch.nn.Module): + """ + Does downsampling with attention, by weighted sum, and a projection.. + """ + + def __init__(self, channels: int, downsample: int, dropout: FloatLike): + super(SimpleDownsample, self).__init__() + + self.bias = nn.Parameter(torch.zeros(downsample)) + + self.name = None # will be set from training code + self.dropout = copy.deepcopy(dropout) + + self.downsample = downsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, in_channels) + Returns a tensor of shape + ( (seq_len+downsample-1)//downsample, batch_size, channels) + """ + (seq_len, batch_size, in_channels) = src.shape + ds = self.downsample + d_seq_len = (seq_len + ds - 1) // ds + + # Pad to an exact multiple of self.downsample + # right-pad src, repeating the last element. + pad = d_seq_len * ds - seq_len + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + assert src.shape[0] == d_seq_len * ds + + src = src.reshape(d_seq_len, ds, batch_size, in_channels) + + weights = self.bias.softmax(dim=0) + # weights: (downsample, 1, 1) + weights = weights.unsqueeze(-1).unsqueeze(-1) + + # ans1 is the first `in_channels` channels of the output + ans = (src * weights).sum(dim=1) + + return ans + + +class SimpleUpsample(torch.nn.Module): + """ + A very simple form of upsampling that mostly just repeats the input, but + also adds a position-specific bias. 
+ """ + + def __init__(self, num_channels: int, upsample: int): + super(SimpleUpsample, self).__init__() + self.upsample = upsample + + def forward(self, src: Tensor) -> Tensor: + """ + x: (seq_len, batch_size, num_channels) + Returns a tensor of shape + ( (seq_len*upsample), batch_size, num_channels) + """ + upsample = self.upsample + (seq_len, batch_size, num_channels) = src.shape + src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src.reshape(seq_len * upsample, batch_size, num_channels) + return src + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) + + + Args: + embed_dim: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. + length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + """ + + def __init__( + self, + embed_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0 + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0 + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + def extend_pe(self, x: Tensor, left_context_len: int = 0): + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: + # self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return self.pe.to(dtype=x.dtype, device=x.device) + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = self.embed_dim**0.5 + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. 
+ # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = ( + compression_length + * x.sign() + * ((x.abs() + compression_length).log() - math.log(compression_length)) + ) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + # self.pe = pe.to(dtype=x.dtype) + return pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. + + Args: + x (Tensor): Input tensor (time, batch, `*`). + left_context_len: (int): Length of cached left context. + + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). + """ + pe = self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) + pos_emb = pe[ + pe.size(0) // 2 + - x_size_left + + 1 : pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] + pos_emb = pos_emb.unsqueeze(0) + return self.dropout(pos_emb) + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding. + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + pos_dim: dimension of the positional encoding vectors, e.g. 128. + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + pos_dim: int, + num_heads: int, + query_head_dim: int, + pos_head_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_head_dim = pos_head_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.name = None # will be overwritten in training code; for diagnostics. 
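+        # The in_proj constructed below packs the query, key and positional-query
+        # projections into a single linear layer with
+        # (query_head_dim + key_head_dim + pos_head_dim) * num_heads output channels.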
+ + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear( + embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 + ) + + self.whiten_keys = Whiten( + num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025, + ) + + # add a balancer for the keys that runs with very small probability, and + # tries to enforce that all dimensions have mean around zero. The + # weights produced by this module are invariant to adding a constant to + # the keys, so the derivative of the bias is mathematically zero; but + # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero + # bias because the small numerical roundoff tends to have a non-random + # sign. This module is intended to prevent that. Use a very small + # probability; that should be suffixient to fix the problem. + self.balance_keys = Balancer( + key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025, + ) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear( + pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05 + ) + + # the following are for diagnosics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + attn_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim) + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), + interpreted as ([batch_size,] tgt_seq_len, src_seq_len) + saying which positions are allowed to attend to which other positions. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. 
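+        # After these permutes, torch.matmul(q, k) below produces attn_scores of
+        # shape (head, batch, time1, time2).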
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + use_pos_scores = False + if torch.jit.is_scripting() or torch.jit.is_tracing(): + # We can't put random.random() in the same line + use_pos_scores = True + elif not self.training or random.random() >= float(self.pos_emb_skip_rate): + use_pos_scores = True + + if use_pos_scores: + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] + pos_scores = torch.matmul(p, pos_emb) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, seq_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif self.training and random.random() < 0.1: + # This is a harder way of limiting the attention scores to not be + # too large. It incurs a penalty if any of them has an absolute + # value greater than 50.0. this should be outside the normal range + # of the attention scores. We use this mechanism instead of, say, + # something added to the loss function involving the entropy, + # because once the entropy gets very small gradients through the + # softmax can become very small, and we'd get zero derivatives. The + # choices of 1.0e-04 as the scale on the penalty makes this + # mechanism vulnerable to the absolute scale of the loss function, + # but we view this as a failsafe to avoid "implausible" parameter + # values rather than a regularization method that should be active + # under normal circumstances. + attn_scores = penalize_abs_values_gt( + attn_scores, limit=25.0, penalty=1.0e-04, name=self.name + ) + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + if attn_mask is not None: + assert attn_mask.dtype == torch.bool + # use -1000 to avoid nan's where attn_mask and key_padding_mask make + # all scores zero. It's important that this be large enough that exp(-1000) + # is exactly zero, for reasons related to const_attention_rate, it + # compares the final weights with zero. 
+ attn_scores = attn_scores.masked_fill(attn_mask, -1000) + + if key_padding_mask is not None: + assert key_padding_mask.shape == ( + batch_size, + seq_len, + ), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if torch.jit.is_scripting() or torch.jit.is_tracing(): + pass + elif random.random() < 0.001 and not self.training: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + def streaming_forward( + self, + x: Tensor, + pos_emb: Tensor, + cached_key: Tensor, + left_context_len: int, + key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) + cached_key: cached attention key tensor of left context, + of shape (left_context_len, batch_size, key_dim) + left_context_len: number of left context frames. + key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + + Returns: + - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + - updated cached attention key tensor of left context. + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_head_dim = self.pos_head_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + # self-attention + q = x[..., 0:query_dim] + k = x[..., query_dim : 2 * query_dim] + # p is the position-encoding query + p = x[..., 2 * query_dim :] + assert p.shape[-1] == num_heads * pos_head_dim + + # Pad cached left contexts + assert cached_key.shape[0] == left_context_len, ( + cached_key.shape[0], + left_context_len, + ) + k = torch.cat([cached_key, k], dim=0) + # Update cached left contexts + cached_key = k[-left_context_len:, ...] + + # The length of key + k_len = k.shape[0] + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + k = k.reshape(k_len, batch_size, num_heads, query_head_dim) + + # time1 refers to target, time2 refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) + p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + + attn_scores = torch.matmul(q, k) + + pos_emb = self.linear_pos(pos_emb) + seq_len2 = 2 * seq_len - 1 + left_context_len + pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( + 2, 0, 3, 1 + ) + # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) + + # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) + # [where seq_len2 represents relative position.] 
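+        # Descriptive note: seq_len2 = 2*seq_len - 1 + left_context_len is the number
+        # of distinct relative offsets between the seq_len query frames and the
+        # (left_context_len + seq_len) key frames; the gather/as_strided logic below
+        # re-indexes these relative scores into absolute (time1, time2) positions.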
+ pos_scores = torch.matmul(p, pos_emb) + + if torch.jit.is_tracing(): + (num_heads, batch_size, time1, n) = pos_scores.shape + rows = torch.arange(start=time1 - 1, end=-1, step=-1) + cols = torch.arange(k_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + pos_scores = pos_scores.reshape(-1, n) + pos_scores = torch.gather(pos_scores, dim=1, index=indexes) + pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) + # the following .as_strided() expression converts the last axis of pos_scores from relative + # to absolute position. I don't know whether I might have got the time-offsets backwards or + # not, but let this code define which way round it is supposed to be. + else: + pos_scores = pos_scores.as_strided( + (num_heads, batch_size, seq_len, k_len), + ( + pos_scores.stride(0), + pos_scores.stride(1), + pos_scores.stride(2) - pos_scores.stride(3), + pos_scores.stride(3), + ), + storage_offset=pos_scores.stride(3) * (seq_len - 1), + ) + + attn_scores = attn_scores + pos_scores + + assert attn_scores.shape == ( + num_heads, + batch_size, + seq_len, + k_len, + ), attn_scores.shape + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + attn_weights = attn_scores.softmax(dim=-1) + + return attn_weights, cached_key + + def _print_attn_entropy(self, attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = ( + -((attn_weights + 1.0e-20).log() * attn_weights) + .sum(dim=-1) + .mean(dim=(1, 2)) + ) + logging.info( + f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" + ) + + +class SelfAttention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim: the input and output embedding dimension + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) + + self.out_proj = ScaledLinear( + num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. 
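+
+        Note:
+           per head h, this computes attn_weights[h] @ v_h, where v_h is the h-th
+           head slice of in_proj(x); the head outputs are concatenated and passed
+           through out_proj. The Whiten module applied afterwards only affects
+           gradients during training and acts as an identity in the forward pass.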
+ """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_val: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), + with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect + attn_weights.sum(dim=-1) == 1. + cached_val: cached attention value tensor of left context, + of shape (left_context_len, batch_size, value_dim) + left_context_len: number of left context frames. + + Returns: + - attention weighted output, a tensor with the same shape as x. + - updated cached attention value tensor of left context. + """ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + seq_len2 = seq_len + left_context_len + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) + + x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) + + # Pad cached left contexts + assert cached_val.shape[0] == left_context_len, ( + cached_val.shape[0], + left_context_len, + ) + x = torch.cat([cached_val, x], dim=0) + # Update cached left contexts + cached_val = x[-left_context_len:, ...] + + x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, seq_len, value_head_dim) + + x = ( + x.permute(2, 1, 0, 3) + .contiguous() + .view(seq_len, batch_size, num_heads * value_head_dim) + ) + + # returned value is of shape (seq_len, batch_size, embed_dim), like the input. 
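+        # Note: unlike forward(), no Whiten module is applied here; Whiten only
+        # modifies gradients during training and is an identity at inference time,
+        # so it can be omitted from the streaming (export) path.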
+ x = self.out_proj(x) + + return x, cached_val + + +class FeedforwardModule(nn.Module): + """Feedforward module in Zipformer2 model.""" + + def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(embed_dim, feedforward_dim) + + self.hidden_balancer = Balancer( + feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0, + ) + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear( + feedforward_dim, + embed_dim, + activation="SwooshL", + dropout_p=dropout, + dropout_shared_dim=0, + bias=True, + initial_scale=0.1, + ) + + self.out_whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward(self, x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. + x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, + # because we noticed that well-trained instances of this module have abs-value before the sigmoid + # starting from about 3, and poorly-trained instances of the module have smaller abs values + # before the sigmoid. + self.balancer = Balancer( + hidden_channels, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), + max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), + min_abs=0.5, + max_abs=5.0, + ) + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + self.identity3 = Identity() # for diagnostics. + + self.out_proj = ScaledLinear( + hidden_channels, channels, bias=True, initial_scale=0.05 + ) + + self.whiten1 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.whiten2 = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. 
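+
+        # Descriptive note: x has been gated by tanh(s); it is next mixed across the
+        # time axis with the externally supplied attention weights (re-used from the
+        # self-attention module), which takes the place of an actual convolution,
+        # and is then gated again by y before out_proj.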
+ + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + y = self.identity2(y) + x = x * y + x = self.identity3(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + def streaming_forward( + self, + x: Tensor, + attn_weights: Tensor, + cached_x: Tensor, + left_context_len: int, + ) -> Tuple[Tensor, Tensor]: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) + attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + cached_x: left context, a Tensor of shape + (num_heads, batch_size, left_context_len, head_dim) + left_context_len: number of left context frames. + Returns: + - a Tensor with the same shape as x + - updated left context with same shape as cached_x + """ + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=2) + + # s will go through tanh. + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = x * s + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == ( + num_heads, + batch_size, + seq_len, + left_context_len + seq_len, + ) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + + # Pad cached tensor + assert cached_x.shape[2] == left_context_len, ( + cached_x.shape[2], + left_context_len, + ) + x_pad = torch.cat([cached_x, x], dim=2) + # Update cached tensor + cached_x = x_pad[:, :, -left_context_len:, :] + + x = torch.matmul(attn_weights, x_pad) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + x = x * y + + x = self.out_proj(x) + return x, cached_x + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Zipformer2 model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + causal: bool, + ) -> None: + """Construct a ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + bottleneck_dim = channels + self.causal = causal + + self.in_proj = nn.Linear( + channels, + 2 * bottleneck_dim, + ) + # the gradients on in_proj are a little noisy, likely to do with the + # sigmoid in glu. + + # after in_proj we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. 
(for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.balancer1 = Balancer( + bottleneck_dim, + channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), + max_positive=1.0, + min_abs=1.5, + max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), + ) + + self.activation1 = Identity() # for diagnostics + + self.sigmoid = nn.Sigmoid() + + self.activation2 = Identity() # for diagnostics + + assert kernel_size % 2 == 1 + + self.depthwise_conv = ( + ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + if causal + else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, + groups=bottleneck_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + ) + + self.balancer2 = Balancer( + bottleneck_dim, + channel_dim=1, + min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), + max_positive=1.0, + min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), + max_abs=10.0, + ) + + self.whiten = Whiten( + num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01, + ) + + self.out_proj = ActivationDropoutAndLinear( + bottleneck_dim, + channels, + activation="SwooshR", + dropout_p=0.0, + initial_scale=0.05, + ) + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, + ) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.balancer1(s) + s = self.sigmoid(s) + x = self.activation1(x) # identity. + x = x * s + x = self.activation2(x) # identity + + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + if ( + not torch.jit.is_scripting() + and not torch.jit.is_tracing() + and chunk_size >= 0 + ): + # Not support exporting a model for simulated streaming decoding + assert ( + self.causal + ), "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) + + x = self.balancer2(x) + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.whiten(x) # (time, batch, channels) + x = self.out_proj(x) # (time, batch, channels) + + return x + + def streaming_forward( + self, + x: Tensor, + cache: Tensor, + src_key_padding_mask: Tensor, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module in streaming forward mode. + + Args: + x: Input tensor (#time, batch, channels). 
+ cache: cached left context for depthwise_conv of shape + (#batch, channels, left_pad) + src_key_padding_mask: the mask for the src keys per batch (optional): + (batch, #time), contains True in masked positions. + + Returns: + - Output tensor (#time, batch, channels). + - Updated cache (#batch, channels, left_pad) + """ + + x = self.in_proj(x) # (time, batch, 2*channels) + + x, s = x.chunk(2, dim=2) + s = self.sigmoid(s) + x = x * s + # (time, batch, channels) + + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) + + x = x.permute(2, 0, 1) # (time, batch, channels) + + x = self.out_proj(x) # (time, batch, channels) + + return x, cache + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + + c = Zipformer2( + encoder_dim=(64, 96), + encoder_unmasked_dim=(48, 64), + num_heads=(4, 4), + causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,), + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) \ No newline at end of file -- Gitee From de2dd27b8525405aae019b69ab1f6b4808596242 Mon Sep 17 00:00:00 2001 From: han_yifeng <hanyifeng2@huawei.com> Date: Sat, 9 Mar 2024 18:08:34 +0800 Subject: [PATCH 2/2] =?UTF-8?q?Zipformer=20PT=E7=BC=96=E8=AF=91=E9=9D=9E?= =?UTF-8?q?=E6=B5=81=E5=BC=8F=E4=B8=8A=E7=BA=BF-part1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../built-in/audio/Zipformer/README.md | 218 +- .../audio/Zipformer/export-onnx.patch | 94 + .../built-in/audio/Zipformer/export-onnx.py | 625 ----- ...ie_ts_enc.py => export_torch_aie_model.py} | 42 +- .../icefall_pt/export_torch_aie_ts_dec.py | 62 - .../icefall_pt/export_torch_aie_ts_join.py | 67 - .../Zipformer/icefall_pt/model_pt_dec.py | 8 +- .../Zipformer/icefall_pt/model_pt_enc.py | 9 +- .../Zipformer/icefall_pt/model_pt_join.py | 8 +- .../audio/Zipformer/icefall_pt/pt_val_dec.py | 96 +- .../audio/Zipformer/icefall_pt/pt_val_enc.py | 90 +- .../audio/Zipformer/icefall_pt/pt_val_join.py | 94 +- .../audio/Zipformer/modify_decoder.py | 25 - .../built-in/audio/Zipformer/zipformer.patch | 59 + .../built-in/audio/Zipformer/zipformer.py | 2438 ----------------- 15 files changed, 335 insertions(+), 3600 deletions(-) create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.patch delete mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py rename AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/{export_torch_aie_ts_enc.py => export_torch_aie_model.py} (59%) delete mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py delete 
mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py delete mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.patch delete mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md b/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md index 7eee8c3e05..619ce41e49 100644 --- a/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md @@ -9,7 +9,6 @@ - [模型推理性能精度](#ZH-CN_TOPIC_0000001172201573) - # 概述<a name="ZH-CN_TOPIC_0000001172161501"></a> (来自论文摘要)Conformer 已成为自动语音识别 (ASR) 中最流行的编码器模型。它将卷积模块添加到变压器中以学习局部和全局依赖性。 @@ -17,7 +16,6 @@ 2)重新组织了具有更多模块的块结构,其中我们重新使用注意力权重以提高效率;3)LayerNorm的一种修改形式称为BiasNorm,允许我们保留一些长度信息;4)新的激活函数SwooshR和SwooshL比Swish效果更好。 我们还提出了一个新的优化器,称为 ScaledAdam,它通过每个张量的当前尺度来缩放更新以保持相对变化大致相同,并且还显式地学习参数尺度。它比 Adam 实现了更快的收敛和更好的性能。 在 LibriSpeech、Aishell-1 和 WenetSpeech 数据集上进行的大量实验证明了我们提出的 Zipformer 相对于其他最先进的 ASR 模型的有效性。 - # 推理环境准备\[所有版本\]<a name="ZH-CN_TOPIC_0000001126281702"></a> @@ -25,14 +23,12 @@ **表 1** 版本配套表 -| 配套 | 版本 | -|-----------------------|-------------| -| CANN | 7.0.RC1 | - | -| Python | 3.9.11 | -| torch | 2.0.1 | -| Ascend-cann-torch-aie | - -| Ascend-cann-aie | - -| 芯片类型 | Ascend310P3 | - | +| 配套 | 版本 | +|-----------------------------|-------------| +| CANN | 8.0.T5 | - | +| Python | 3.10.13 | +| torch | 2.1.0 | +| 芯片类型 | Ascend310P3 | - | # 快速上手<a name="ZH-CN_TOPIC_0000001126281700"></a> @@ -51,8 +47,9 @@ export K2_MAKE_ARGS="-j6" python3 setup.py install ``` - 若执行以上命令遇到错误,请参考[此链接](https://k2-fsa.github.io/k2/installation/from_source.html)。 - 3.(GPU)x86环境。从[此链接](https://k2-fsa.github.io/k2/cuda.html)下载对应CUDA版本的whl文件,然后使用pip进行安装。 + * **若编译失败,尝试再次编译前,需要先删除build文件夹。** + * 若执行以上命令遇到错误,请参考[此链接](https://k2-fsa.github.io/k2/installation/from_source.html)。 + 3. (GPU)x86环境。从[此链接](https://k2-fsa.github.io/k2/cuda.html)下载对应CUDA版本的whl文件,然后使用pip进行安装。 4. 验证k2是否安装成功 ```shell python3 -m k2.version @@ -61,13 +58,21 @@ ```shell pip install lhotse pip install kaldifeat + + apt install libsndfile1 + ``` + * kaldifeat若安装失败,请执行以下命令使用源码安装: + ```shell + git clone https://github.com/csukuangfj/kaldifeat.git + cd kaldifeat + python3 setup.py install ``` 3. 安装icefall ```shell git clone https://github.com/k2-fsa/icefall.git - git reset --hard e2fcb42f5f176d9e39eb38506ab99d0a3adaf202 cd icefall + git reset --hard e2fcb42f5f176d9e39eb38506ab99d0a3adaf202 pip install -r requirements.txt ``` 4. 将icefall加入环境变量, "/path/to/icefall"替换为icefall文件夹所在的路径。 @@ -75,41 +80,42 @@ ```shell export PYTHONPATH=/path/to/icefall:$PYTHONPATH ``` - ## 模型下载 -1. 安装 git lfs - ```shell - curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash - - sudo apt-get install git-lfs - git lfs install --skip-repo - ``` -2. 下载模型 - ```shell - git clone https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615 - ``` - 若下载失败,请尝试从以上链接手动下载文件。模型转换和推理时只需要用到以下文件: - - data/lang_char/tokens.txt +从[此链接](https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615)下载模型相关文件。 +模型转换和推理时只需要用到以下文件: + - data/lang_char/tokens.txt - exp/epoch-12.pt +下载完后,整理成如下目录结构: +```shell +icefall-asr-zipformer-wenetspeech-20230615 +├── data +│ └── lang_char +│ └── tokens.txt +└── exp + └── epoch-12.pt +``` ## 模型推理 -1. 
替换代码 - 使用本目录下的zipformer.py替换egs/librispeech/ASR/zipformer/zipformer.py - export-onnx.py替换egs/librispeech/ASR/zipformer/export-onnx.py - -2. 将本代码仓的icefall_pt目录拷贝到工程的根目录./icefall下。 +1. 打代码补丁(首先将export_onnx.patch与zipformer.patch拷贝到icefall/egs/librispeech/ASR/zipformer目录下) + ```shell + cd icefall/egs/librispeech/ASR/zipformer/ + + patch < export_onnx.patch + patch < zipformer.patch + ``` +2. 将本代码仓的icefall_pt目录拷贝到icefall工程根目录下。 -3. 导出onnx模型,用于精度测试。 +3. 导出onnx模型与torchscript模型。 ```shell - cd icefall/egs/librispeech/ASR/zipformer - # 注意将"icefall-asr-zipformer-streaming-wenetspeech-20230615"修改为实际路径 - python ./export-onnx.py \ - # 暂时没有使用以下参数 - --tokens icefall-asr-zipformer-streaming-wenetspeech-20230615/data/lang_char/tokens.txt \ + cd icefall/ + repo=/path/to/icefall//egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615 + # 注意将"icefall-asr-zipformer-wenetspeech-20230615"修改为实际路径 + python ./egs/librispeech/ASR/zipformer/export-onnx.py \ + --tokens $repo/data/lang_char/tokens.txt \ --use-averaged-model 0 \ --epoch 12 \ --avg 1 \ - --exp-dir icefall-asr-zipformer-streaming-wenetspeech-20230615/exp \ + --exp-dir $repo/exp \ --num-encoder-layers "2,2,3,4,3,2" \ --downsampling-factor "1,2,4,8,4,2" \ --feedforward-dim "512,768,1024,1536,1024,768" \ @@ -123,99 +129,87 @@ --cnn-module-kernel "31,31,15,15,15,31" \ --decoder-dim 512 \ --joiner-dim 512 \ - --causal True \ - --chunk-size 16 \ - --left-context-frames 128 + --causal False \ + --chunk-size "16, 32, 64, -1" \ + --left-context-frames "64, 128, 256, -1" ``` - 执行结束后,会在“icefall-asr-zipformer-streaming-wenetspeech-20230615/exp”目录下生成三个onnx文件与三个ts文件: + 执行结束后,会在“icefall-asr-zipformer-wenetspeech-20230615/exp”目录下生成三个onnx文件: - encoder-epoch-12-avg-1.onnx - decoder-epoch-12-avg-1.onnx - joiner-epoch-12-avg-1.onnx - encoder-epoch-12-avg-1.pt - decoder-epoch-12-avg-1.pt - joiner-epoch-12-avg-1.pt -4. 对torchscript模型进行编译。 +4. 对torchscript模型使用PT插件进行编译。 ```shell cd icefall/icefall_pt - - # 注意将"icefall-asr-zipformer-streaming-wenetspeech-20230615"修改为实际路径 - python ./export_torch_aie_ts_enc.py - python ./export_torch_aie_ts_dec.py - python ./export_torch_aie_ts_join.py + python export_torch_aie_model.py + ``` + 参数说明: + ```shell + --torch_script_path:torhscript模型路径 + --export_part:选择编译部分(encoder,decoder或joiner) + --soc_version:硬件版本 + --batch_size + --save_path:编译后的模型保存路径 ``` - 执行结束后,会在“icefall_pt/pt_compiled_model”目录下生成三个编译好的torchscript文件: - - encoder-epoch-12-avg-1_torch_aie_bs1.pt - - decoder-epoch-12-avg-1_torch_aie_bs1.pt - - joiner-epoch-12-avg-1_torch_aie_bs1.pt + 执行结束后,会在save_path对应目录下生成三个编译好的torchscript文件: + - encoder-epoch-12-avg-1_mindietorch_bs1.pt + - decoder-epoch-12-avg-1_mindietorch_bs1.pt + - joiner-epoch-12-avg-1_mindietorch_bs1.pt 5. 运行推理样例 - 1. 下载样例语音数据(测试暂时使用全0输入) - ```shell - cd icefall/egs/librispeech/ASR/zipformer - wget https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav - ``` - 2. 执行推理 - ```shell - # 注意将"icefall-asr-zipformer-streaming-wenetspeech-20230615"修改为实际路径 - cd icefall/icefall_pt - python ./pt_val_enc.py - python ./pt_val_dec.py - python ./pt_val_join.py - ``` - 执行结束后,会在icefall_pt/result下看到三个模型的推理结果: - 同时输出PT执行的性能 - - -6. 
精度测试 ```shell - cd icefall/egs/librispeech/ASR/zipformer - 将pretrained.py中的encoder,decoder与joiner的输入改为全全0:enc(1,100,80)+常量100,dec(1,2),joiner(1,512)+(1,512),并保存三个输出 - python ./pretrained.py - 三个输出分别比较余弦相似度 + cd icefall/icefall_pt + python pt_val_enc.py + python pt_val_dec.py + python pt_val_join.py ``` - 执行结束后,结果为: + 参数说明: + ```shell + --model:PT模型路径 + --need_compile:是否需要编译。若model指向torchscript模型,则需要进行编译后运行 + --soc_version:硬件版本 + --batch_size + --result_path:模型运行结果保存路径 + --device_id:硬件编号 + --multi:数据加倍的倍数(默认生成20条数据,multi为1。若multi为n,则数据实际为n*20条。注意,若n不为1,则不生成result文件) + ``` +6. 精度测试(将onnx_test下的四个测试脚本拷贝到icefall/egs/librispeech/ASR/zipformer) + ```shell + cd icefall/egs/librispeech/ASR + python ./zipformer/onnx_test_enc.py --encoder-model-filename $repo/exp/encoder-epoch-12-avg-1.onnx --tokens $repo/data/lang_char/tokens.txt + python ./zipformer/onnx_test_dec.py --decoder-model-filename $repo/exp/decoder-epoch-12-avg-1.onnx --tokens $repo/data/lang_char/tokens.txt + python ./zipformer/onnx_test_join.py --joiner-model-filename $repo/exp/joiner-epoch-12-avg-1.onnx --tokens $repo/data/lang_char/tokens.txt + ``` + 执行结束后,将得到的结果路径与步骤5中得到的结果路径分别配置到脚本cosine_similarity_test.py中,并运行查看相似度: + ```shell + python cosine_similarity_test.py + ``` + 可以得出与T4及A10对比的精度结果: ```shell - encoder:0.97+ - decoder:0.99+ - joiner:0.99+ + Cosine Similarity PT_T4_ENC :0.9999+ + Cosine Similarity PT_T4_DEC :1 + Cosine Similarity PT_T4_JOIN :1 + Cosine Similarity PT_A10_ENC :0.9999+ + Cosine Similarity PT_A10_DEC :1 + Cosine Similarity PT_A10_JOIN :1 ``` 7. 性能测试 - 1. aie模型性能测试 - ```shell - 执行pt_val_enc/dec/join.py时会打印 - ``` + 1. pt模型性能测试 + 第5步运行pt_val脚本时会打印PT模型性能 + 2. onnx模型性能测试。 - 1. (可选)若使用GPU,请确保已安装CUDA和pytorch-gpu版本,同时需安装onnxruntime-gpu,如下所示: - ```shell - - ``` - 执行结束后,三个模型的性能信息会打印在命令行,如下所示: - ```shell - Encoder latency: 964.9530 ms - Encoder throughput: 1.0363 fps - Decoder latency: 0.4806 ms - Decoder throughput: 2080.7143 fps - Joiner latency: 0.4994 ms - Joiner throughput: 2002.3092 fps - ``` - 3. 
om性能测试: - 在原本onnx模型的基础上,encoder模型进行onnxsim - ```shell - onnxsim encoder-epoch-12-avg-1.onnx encoder-epoch-12-avg-1_sim.onnx - onnxsim decoder-epoch-12-avg-1.onnx decoder-epoch-12-avg-1_sim.onnx - python modify_decoder.py - python3 -m ais_bench --model encoder_linux_aarch64.om --loop=2000 - python3 -m ais_bench --model decoder.om --loop=2000 - python3 -m ais_bench --model joiner.om --loop=2000 - ``` + 第6步运行onnx_test脚本时会打印onnx模型性能 + # 模型推理性能精度<a name="ZH-CN_TOPIC_0000001172201573"></a> Zipformer流式模型由三个子模型组成,分别是encoder、decoder和joiner,其性能如下表所示: -| 模型 | pt插件 - 310P性能(时延/吞吐率) | T4性能(时延/吞吐率) | A10性能(时延/吞吐率) | -|---------|-----------------------|--------------------|--------------------| -| encoder | 20.4 ms / 49 fps | 24.7 ms / 40 fps | 19 ms / 52 fps | -| decoder | 0.19 ms / 5156 fps | 0.59 ms / 1684 fps | 0.13 ms / 7604 fps | -| joiner | 0.22 ms / 4448 fps | 0.13 ms / 7645 fps | 0.11 ms / 9224 fps | -| 端到端 | 20.81 ms / 48 fps | 25.42 ms / 39 fps | 19.24 ms / 52 fps | +| 模型 | pt插件 - 310P性能(时延/吞吐率) | T4 onnx性能(时延/吞吐率) | A10 onnx性能(时延/吞吐率) | +|---------|---------------------------|---------------------------|---------------------------| +| encoder | 43.1391 ms / 23.1807 fps | 25.6406 ms / 39.0005 fps | 16.2751 ms / 61.4434 fps | +| decoder | 0.4387 ms / 2279.2528 fps | 0.5691 ms / 1757.0740 fps | 0.1219 ms / 8200.5706 fps | +| joiner | 0.4586 ms / 2180.0839 fps | 0.1526 ms / 6551.7825 fps | 0.1107 ms / 9026.4239 fps | +| 端到端 | 44.0364 ms / 22.7084 fps | 26.3623 ms / 37.9329 fps | 16.5077 ms / 60.5777 fps | diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.patch b/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.patch new file mode 100644 index 0000000000..c9580a6096 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.patch @@ -0,0 +1,94 @@ +--- export-onnx.py 2024-03-08 15:33:12.340000000 +0800 ++++ export-onnx.py 2024-03-08 15:33:12.340000000 +0800 +@@ -27,10 +27,10 @@ + + 2. Export the model to ONNX + +-./zipformer/export-onnx.py \ +- --tokens $repo/data/lang_bpe_500/tokens.txt \ ++python3 ./export-onnx.py \ ++ --tokens $repo/data/lang_char/tokens.txt \ + --use-averaged-model 0 \ +- --epoch 99 \ ++ --epoch 12 \ + --avg 1 \ + --exp-dir $repo/exp \ + --num-encoder-layers "2,2,3,4,3,2" \ +@@ -92,7 +92,7 @@ + parser.add_argument( + "--epoch", + type=int, +- default=28, ++ default=12, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", +@@ -111,7 +111,7 @@ + parser.add_argument( + "--avg", + type=int, +- default=15, ++ default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", +@@ -120,7 +120,7 @@ + parser.add_argument( + "--use-averaged-model", + type=str2bool, +- default=True, ++ default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." 
+@@ -131,7 +131,7 @@ + parser.add_argument( + "--exp-dir", + type=str, +- default="zipformer/exp", ++ default="/home/devkit/hanyifeng/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, +@@ -140,7 +140,7 @@ + parser.add_argument( + "--tokens", + type=str, +- default="data/lang_bpe_500/tokens.txt", ++ default="/home/devkit/hanyifeng/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/data/lang_char/tokens.txt", + help="Path to the tokens.txt", + ) + +@@ -298,7 +298,7 @@ + x_lens = torch.tensor([100], dtype=torch.int64) + + encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) +- ++ encoder_model.save(str(encoder_filename).replace("onnx", "pt")) + torch.onnx.export( + encoder_model, + (x, x_lens), +@@ -352,7 +352,9 @@ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + +- y = torch.zeros(10, context_size, dtype=torch.int64) ++ y = torch.zeros(1, context_size, dtype=torch.int64) ++ ts_decoder_model = torch.jit.trace(decoder_model, y) ++ ts_decoder_model.save(str(decoder_filename).replace("onnx", "pt")) + decoder_model = torch.jit.script(decoder_model) + torch.onnx.export( + decoder_model, +@@ -393,8 +395,10 @@ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + +- projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) +- projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) ++ projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) ++ projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) ++ ts_joiner_model = torch.jit.trace(joiner_model, (projected_encoder_out, projected_decoder_out)) ++ ts_joiner_model.save(str(joiner_filename).replace("onnx", "pt")) + + torch.onnx.export( + joiner_model, diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py deleted file mode 100644 index 805f7405c8..0000000000 --- a/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py +++ /dev/null @@ -1,625 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) -# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) - -""" -This script exports a transducer model from PyTorch to ONNX. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. 
Export the model to ONNX - -python3 ./egs/librispeech/ASR/zipformer/export-onnx.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --num-encoder-layers "2,2,3,4,3,2" \ - --downsampling-factor "1,2,4,8,4,2" \ - --feedforward-dim "512,768,1024,1536,1024,768" \ - --num-heads "4,4,4,8,4,4" \ - --encoder-dim "192,256,384,512,384,256" \ - --query-head-dim 32 \ - --value-head-dim 12 \ - --pos-head-dim 4 \ - --pos-dim 48 \ - --encoder-unmasked-dim "192,192,256,256,256,192" \ - --cnn-module-kernel "31,31,15,15,15,31" \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --causal False \ - --chunk-size "16,32,64,-1" \ - --left-context-frames "64,128,256,-1" - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -See ./onnx_pretrained.py and ./onnx_check.py for how to -use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import k2 -import onnx -import torch -import torch.nn as nn -from decoder import Decoder -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_model, get_params -from zipformer import Zipformer2 - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import make_pad_mask, num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=12, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=False, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="/home/devkit/hanyifeng/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="/home/devkit/hanyifeng/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/data/lang_char/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -class OnnxEncoder(nn.Module): - """A wrapper for Zipformer and the encoder_proj from the joiner""" - - def __init__( - self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear - ): - """ - Args: - encoder: - A Zipformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. - """ - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - self.encoder_proj = encoder_proj - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of Zipformer.forward - - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 1-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) - - encoder_out_lens, A 1-D tensor of shape (N,) - """ - x, x_lens = self.encoder_embed(x, x_lens) - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) - encoder_out = self.encoder_proj(encoder_out) - # Now encoder_out is of shape (N, T, joiner_dim) - - return encoder_out, encoder_out_lens - - -class OnnxDecoder(nn.Module): - """A wrapper for Decoder and the decoder_proj from the joiner""" - - def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): - super().__init__() - self.decoder = decoder - self.decoder_proj = decoder_proj - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, context_size). - Returns - Return a 2-D tensor of shape (N, joiner_dim) - """ - need_pad = False - decoder_output = self.decoder(y, need_pad=need_pad) - decoder_output = decoder_output.squeeze(1) - output = self.decoder_proj(decoder_output) - - return output - - -class OnnxJoiner(nn.Module): - """A wrapper for the joiner""" - - def __init__(self, output_linear: nn.Linear): - super().__init__() - self.output_linear = output_linear - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) - return logit - - -def export_encoder_model_onnx( - encoder_model: OnnxEncoder, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T', joiner_dim) - - encoder_out_lens, a tensor of shape (N,) - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. 
- opset_version: - The opset version to use. - """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - - encoder_model = torch.jit.trace(encoder_model, (x, x_lens)) - encoder_model.save(str(encoder_filename).replace("onnx", "pt")) - torch.onnx.export( - encoder_model, - (x, x_lens), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) - - meta_data = { - "model_type": "zipformer2", - "version": "1", - "model_author": "k2-fsa", - "comment": "non-streaming zipformer2", - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=encoder_filename, meta_data=meta_data) - - -def export_decoder_model_onnx( - decoder_model: OnnxDecoder, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - context_size = decoder_model.decoder.context_size - vocab_size = decoder_model.decoder.vocab_size - - y = torch.zeros(10, context_size, dtype=torch.int64) - ts_decoder_model = torch.jit.trace(decoder_model, y) - ts_decoder_model.save(str(decoder_filename).replace("onnx", "pt")) - decoder_model = torch.jit.script(decoder_model) - - torch.onnx.export( - decoder_model, - y, - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - - meta_data = { - "context_size": str(context_size), - "vocab_size": str(vocab_size), - } - add_meta_data(filename=decoder_filename, meta_data=meta_data) - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. 
- The exported joiner model has two inputs: - - - encoder_out: a tensor of shape (N, joiner_dim) - - decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - """ - joiner_dim = joiner_model.output_linear.weight.shape[1] - logging.info(f"joiner dim: {joiner_dim}") - - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - ts_joiner_model = torch.jit.trace(joiner_model, (projected_encoder_out, projected_decoder_out)) - ts_joiner_model.save(str(joiner_filename).replace("onnx", "pt")) - - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - meta_data = { - "joiner_dim": str(joiner_dim), - } - add_meta_data(filename=joiner_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table["<blk>"] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - 
filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) - - encoder = OnnxEncoder( - encoder=model.encoder, - encoder_embed=model.encoder_embed, - encoder_proj=model.joiner.encoder_proj, - ) - - decoder = OnnxDecoder( - decoder=model.decoder, - decoder_proj=model.joiner.decoder_proj, - ) - - joiner = OnnxJoiner(output_linear=model.joiner.output_linear) - - encoder_num_param = sum([p.numel() for p in encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - - opset_version = 13 - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" - export_encoder_model_onnx( - encoder, - encoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported encoder to {encoder_filename}") - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" - export_decoder_model_onnx( - decoder, - decoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported decoder to {decoder_filename}") - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" - export_joiner_model_onnx( - joiner, - joiner_filename, - opset_version=opset_version, - ) - logging.info(f"Exported joiner to {joiner_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul", "Gather"], - weight_type=QuantType.QInt8, - ) - - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_model.py similarity index 59% rename from 
AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py rename to AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_model.py index 05c9e64e06..0019d4937f 100644 --- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_model.py @@ -1,4 +1,4 @@ -# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,34 +16,35 @@ import os import argparse import torch -import torch_aie -from torch_aie import _enums +import mindietorch +from mindietorch import _enums -def export_torch_aie(opt): +def export_mindietorch(opt): trace_model = torch.jit.load(opt.torch_script_path) trace_model.eval() - torch_aie.set_device(0) + mindietorch.set_device(0) inputs = [] - # x = torch.zeros(1, 100, 80, dtype=torch.float32) - # x_lens = torch.tensor([100], dtype=torch.int64) - inputs.append(torch_aie.Input([1, 100, 80], dtype = torch_aie.dtype.FLOAT)) - inputs.append(torch_aie.Input([1], dtype = torch_aie.dtype.INT64)) + if opt.export_part == 'encoder': + inputs.append(mindietorch.Input([opt.batch_size, 100, 80], dtype = mindietorch.dtype.FLOAT)) + inputs.append(mindietorch.Input([opt.batch_size], dtype = mindietorch.dtype.INT64)) + elif opt.export_part == 'decoder': + inputs.append(mindietorch.Input([opt.batch_size, 2], dtype=mindietorch.dtype.INT64)) + else: + inputs.append(mindietorch.Input([opt.batch_size, 512], dtype=mindietorch.dtype.FLOAT)) + inputs.append(mindietorch.Input([opt.batch_size, 512], dtype=mindietorch.dtype.FLOAT)) - torchaie_model = torch_aie.compile( + torchaie_model = mindietorch.compile( trace_model, inputs=inputs, - precision_policy=_enums.PrecisionPolicy.FP16, - # truncate_long_and_double=True, - # require_full_compilation=False, - # allow_tensor_replace_int=False, - # min_block_size=3, - # torch_executed_ops=[], - soc_version='Ascend310P3', + precision_policy=_enums.PrecisionPolicy.FP32, + soc_version=opt.soc_version, optimization_level=0 - ) + ) suffix = os.path.splitext(opt.torch_script_path)[-1] - saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt.batch_size}" + suffix + saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_mindietorch_bs{opt.batch_size}" + suffix + if not os.path.exists(opt.save_path): + os.makedirs(opt.save_path) torchaie_model.save(os.path.join(opt.save_path, saved_name)) print("torch aie tdnn compiled done. 
saved model is ", os.path.join(opt.save_path, saved_name)) @@ -51,6 +52,7 @@ def export_torch_aie(opt): def parse_opt(): parser = argparse.ArgumentParser() parser.add_argument('--torch_script_path', type=str, default='../egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/encoder-epoch-12-avg-1.pt', help='trace model path') + parser.add_argument('--export_part', type=str, default='encoder', help='the part of model(encoder, decoder, and joiner) to be exported.') parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version') parser.add_argument('--batch_size', type=int, default=1, help='batch size') parser.add_argument('--save_path', type=str, default='./pt_compiled_model/', help='compiled model path') @@ -58,7 +60,7 @@ def parse_opt(): return opt def main(opt): - export_torch_aie(opt) + export_mindietorch(opt) if __name__ == '__main__': opt = parse_opt() diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py deleted file mode 100644 index bb341fca41..0000000000 --- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import argparse - -import torch -import torch_aie -from torch_aie import _enums - -def export_torch_aie(opt): - trace_model = torch.jit.load(opt.torch_script_path) - trace_model.eval() - - torch_aie.set_device(0) - inputs = [] - # inputs.append(torch_aie.Input([10, 2], dtype = torch_aie.dtype.INT64)) - inputs.append(torch_aie.Input([opt.batch_size, 2], dtype = torch_aie.dtype.INT64)) - - torchaie_model = torch_aie.compile( - trace_model, - inputs=inputs, - precision_policy=_enums.PrecisionPolicy.FP16, - truncate_long_and_double=True, - require_full_compilation=False, - allow_tensor_replace_int=False, - min_block_size=3, - torch_executed_ops=[], - soc_version='Ascend310P3', - optimization_level=0) - suffix = os.path.splitext(opt.torch_script_path)[-1] - saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt.batch_size}" + suffix - torchaie_model.save(os.path.join(opt.save_path, saved_name)) - print("torch aie tdnn compiled done. 
saved model is ", os.path.join(opt.save_path, saved_name)) - - -def parse_opt(): - parser = argparse.ArgumentParser() - parser.add_argument('--torch_script_path', type=str, default='../egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/decoder-epoch-12-avg-1.pt', help='trace model path') - parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version') - parser.add_argument('--batch_size', type=int, default=1, help='batch size') - parser.add_argument('--save_path', type=str, default='./pt_compiled_model/', help='compiled model path') - opt = parser.parse_args() - return opt - -def main(opt): - export_torch_aie(opt) - -if __name__ == '__main__': - opt = parse_opt() - main(opt) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py deleted file mode 100644 index ae0b1e5b47..0000000000 --- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import argparse - -import torch -import torch_aie -from torch_aie import _enums - -def export_torch_aie(opt): - trace_model = torch.jit.load(opt.torch_script_path) - trace_model.eval() - - torch_aie.set_device(0) - inputs = [] - # inputs.append(torch_aie.Input([11, 512], dtype = torch_aie.dtype.FLOAT)) - # inputs.append(torch_aie.Input([11, 512], dtype = torch_aie.dtype.FLOAT)) - inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT)) - inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT)) - # projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - # projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - # inputs.append(torch_aie.Input([10], dtype = torch_aie.dtype.INT64)) - - torchaie_model = torch_aie.compile( - trace_model, - inputs=inputs, - precision_policy=_enums.PrecisionPolicy.FP16, - truncate_long_and_double=True, - require_full_compilation=False, - allow_tensor_replace_int=False, - min_block_size=3, - torch_executed_ops=[], - soc_version='Ascend310P3', - optimization_level=0) - suffix = os.path.splitext(opt.torch_script_path)[-1] - saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt.batch_size}" + suffix - torchaie_model.save(os.path.join(opt.save_path, saved_name)) - print("torch aie tdnn compiled done. 
saved model is ", os.path.join(opt.save_path, saved_name)) - - -def parse_opt(): - parser = argparse.ArgumentParser() - parser.add_argument('--torch_script_path', type=str, default='../egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/joiner-epoch-12-avg-1.pt', help='trace model path') - parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version') - parser.add_argument('--batch_size', type=int, default=1, help='batch size') - parser.add_argument('--save_path', type=str, default='./pt_compiled_model/', help='compiled model path') - opt = parser.parse_args() - return opt - -def main(opt): - export_torch_aie(opt) - -if __name__ == '__main__': - opt = parse_opt() - main(opt) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py index d8934b5f29..9149479660 100644 --- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py @@ -1,4 +1,4 @@ -# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ import time from tqdm import tqdm import torch -import torch_aie +import mindietorch def forward_infer(model, dataloader, batchsize, device_id): @@ -37,8 +37,8 @@ def forward_infer(model, dataloader, batchsize, device_id): def pt_infer(model, input_li_1, device_id, loop_num, inference_time): input_npu_li_1 = input_li_1.to("npu:" + str(device_id)) - stream = torch_aie.npu.Stream("npu:" + str(device_id)) - with torch_aie.npu.stream(stream): + stream = mindietorch.npu.Stream("npu:" + str(device_id)) + with mindietorch.npu.stream(stream): inf_start = time.time() output_npu = model.forward(input_npu_li_1) stream.synchronize() diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py index 61396050b6..c2197fc9e0 100644 --- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py @@ -1,4 +1,4 @@ -# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,9 +14,8 @@ import time from tqdm import tqdm -import numpy as np import torch -import torch_aie +import mindietorch def forward_infer(model, dataloader, batchsize, device_id): @@ -38,8 +37,8 @@ def pt_infer(model, input_li_1, input_li_2, device_id, loop_num, inference_time) input_npu_li_1 = input_li_1.to("npu:" + str(device_id)) input_npu_li_2 = input_li_2.to("npu:" + str(device_id)) - stream = torch_aie.npu.Stream("npu:" + str(device_id)) - with torch_aie.npu.stream(stream): + stream = mindietorch.npu.Stream("npu:" + str(device_id)) + with mindietorch.npu.stream(stream): inf_start = time.time() output_npu = model.forward(input_npu_li_1, input_npu_li_2) stream.synchronize() diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py index 45b8ec0f93..858046037d 100644 --- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py @@ -1,4 +1,4 @@ -# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ import time from tqdm import tqdm import torch -import torch_aie +import mindietorch def forward_infer(model, dataloader, batchsize, device_id): @@ -38,8 +38,8 @@ def pt_infer(model, input_li_1, input_li_2, device_id, loop_num, inference_time) input_npu_li_1 = input_li_1.to("npu:" + str(device_id)) input_npu_li_2 = input_li_2.to("npu:" + str(device_id)) - stream = torch_aie.npu.Stream("npu:" + str(device_id)) - with torch_aie.npu.stream(stream): + stream = mindietorch.npu.Stream("npu:" + str(device_id)) + with mindietorch.npu.stream(stream): inf_start = time.time() output_npu = model.forward(input_npu_li_1, input_npu_li_2) stream.synchronize() diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py index 3106753822..8363fe5177 100644 --- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py @@ -1,4 +1,4 @@ -# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,9 +17,9 @@ import copy import argparse import torch -import torch_aie +import mindietorch import numpy as np -from torch_aie import _enums +from mindietorch import _enums from torch.utils.data import dataloader from model_pt_dec import forward_infer @@ -58,56 +58,12 @@ class _RepeatSampler: while True: yield from iter(self.sampler) -# def collate_fn(batch): -# """ -# data preprocessing -# """ -# def func(p): -# """ -# data size -# """ -# return p[0].size(1) - -# batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True) -# longest_sample = max(batch, key=func)[0] -# freq_size = longest_sample.size(0) -# minibatch_size = len(batch) -# max_seqlength = longest_sample.size(1) -# inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength) -# input_percentages = torch.FloatTensor(minibatch_size) -# for x in range(minibatch_size): -# sample = batch[x] -# tensor = sample[0] -# seq_length = tensor.size(1) -# inputs[x][0].narrow(1, 0, seq_length).copy_(tensor) -# input_percentages[x] = seq_length / float(max_seqlength) -# return inputs, input_percentages, [] - def get_dataloader(opt): - # with open(to_absolute_path(opt.label_file)) as label_file: - # labels = json.load(label_file) - - # dataset = SpectrogramDataset( - # audio_conf=DataConfig.spect, - # input_path=opt.data_file, - # labels=labels, - # normalize=True, - # aug_cfg=DataConfig.augmentation - # ) - # inputs, input_percentages, _ = collate_fn(dataset) - # input_sizes = input_percentages.mul_(int(inputs.size(3))).int() - # print(inputs[0]) - # print(input_sizes.tolist()) - - # datasets = [[inputs[i], input_sizes[i]] for i in range(len(input_sizes))] x = torch.zeros(2, dtype=torch.int64) - # x_lens = torch.tensor([100], dtype=torch.int64) - # x_lens = 100 datasets = [] for i in range(20): datasets.append([copy.deepcopy(x)]) - # print(datasets) while len(datasets) % opt.batch_size != 0: datasets.append(datasets[-1]) m = 1 @@ -136,66 +92,40 @@ def save_tensor_arr_to_file(arr, file_path): with open(file_path, "w", encoding='utf-8') as f: f.write(write_sen) -def save_size_to_file(size, file_path): - write_sen = "" + str(size) + " " - with open(file_path, "w", encoding='utf-8') as f: - f.write(write_sen) def main(opt): # load model model = torch.jit.load(opt.model) batch_size = opt.batch_size - torch_aie.set_device(opt.device_id) + mindietorch.set_device(opt.device_id) if opt.need_compile: inputs = [] - inputs.append(torch_aie.Input([10, 2], dtype=torch_aie.dtype.INT64)) - # inputs.append(torch_aie.Input([opt.batch_size], dtype=torch_aie.dtype.INT32)) - - model = torch_aie.compile( + inputs.append(mindietorch.Input([batch_size, 2], dtype=mindietorch.dtype.INT64)) + model = mindietorch.compile( model, inputs=inputs, - precision_policy=_enums.PrecisionPolicy.FP16, - truncate_long_and_double=True, - require_full_compilation=False, - allow_tensor_replace_int=False, - min_block_size=3, - torch_executed_ops=[], - soc_version='Ascend310P3', - optimization_level=0) - + precision_policy=_enums.PrecisionPolicy.FP32, + soc_version=opt.soc_version, + optimization_level=0 + ) dataloader = get_dataloader(opt) pred_results = forward_infer(model, dataloader, batch_size, opt.device_id) - # for index, res in enumerate(pred_results): - # print(index, " ", res) - # print(res[0].shape) if opt.batch_size == 1 and opt.multi == 1: result_path = opt.result_path if(os.path.exists(result_path) == False): - os.makedirs(result_path) + os.makedirs(result_path) for index, res in enumerate(pred_results): - # for i in range(batch_size): - # 
result_fname_0 = 'data' + str(index * batch_size + i + 1) + '_0.txt' - # result_fname_1 = 'data' + str(index * batch_size + i + 1) + '_1.txt' result_fname_0 = 'data' + str(index) + '_0.txt' - # result_fname_1 = 'data' + str(index) + '_1.txt' - # res = np.array(res) - # save_tensor_arr_to_file(np.array(res[0][i]), os.path.join(result_path, result_fname_0)) - # save_size_to_file(res[1].numpy()[i], os.path.join(result_path, result_fname_1)) save_tensor_arr_to_file(np.array(res), os.path.join(result_path, result_fname_0)) - # save_size_to_file(res[1].numpy()[0], os.path.join(result_path, result_fname_1)) - - if __name__ == '__main__': - parser = argparse.ArgumentParser(description='DeepSpeech2 offline model inference.') + parser = argparse.ArgumentParser(description='Zipformer offline model decoder inference.') parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version') - parser.add_argument('--model', type=str, default="./pt_compiled_model/decoder-epoch-12-avg-1_torch_aie_bs1.pt", help='ts model path') + parser.add_argument('--model', type=str, default="./pt_compiled_model/decoder-epoch-12-avg-1_mindietorch_bs1.pt", help='ts model path') parser.add_argument('--need_compile', action="store_true", help='if the loaded model needs to be compiled or not') parser.add_argument('--batch_size', type=int, default=1, help='batch size') parser.add_argument('--device_id', type=int, default=0, help='device id') - parser.add_argument('--data_file', default='./deepspeech.pytorch/data/an4_test_manifest.json') - parser.add_argument('--label_file', default='./deepspeech.pytorch/labels.json') parser.add_argument('--result_path', default='result/decoder') parser.add_argument('--multi', type=int, default=1, help='multiples of dataset replication for enough infer loop. if multi != 1, the pred result will not be stored.') opt = parser.parse_args() diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py index 496e767fe0..55af8f7239 100644 --- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py @@ -1,4 +1,4 @@ -# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,9 +17,9 @@ import copy import argparse import torch -import torch_aie +import mindietorch import numpy as np -from torch_aie import _enums +from mindietorch import _enums from torch.utils.data import dataloader from model_pt_enc import forward_infer @@ -58,56 +58,12 @@ class _RepeatSampler: while True: yield from iter(self.sampler) -# def collate_fn(batch): -# """ -# data preprocessing -# """ -# def func(p): -# """ -# data size -# """ -# return p[0].size(1) - -# batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True) -# longest_sample = max(batch, key=func)[0] -# freq_size = longest_sample.size(0) -# minibatch_size = len(batch) -# max_seqlength = longest_sample.size(1) -# inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength) -# input_percentages = torch.FloatTensor(minibatch_size) -# for x in range(minibatch_size): -# sample = batch[x] -# tensor = sample[0] -# seq_length = tensor.size(1) -# inputs[x][0].narrow(1, 0, seq_length).copy_(tensor) -# input_percentages[x] = seq_length / float(max_seqlength) -# return inputs, input_percentages, [] - - def get_dataloader(opt): - # with open(to_absolute_path(opt.label_file)) as label_file: - # labels = json.load(label_file) - - # dataset = SpectrogramDataset( - # audio_conf=DataConfig.spect, - # input_path=opt.data_file, - # labels=labels, - # normalize=True, - # aug_cfg=DataConfig.augmentation - # ) - # inputs, input_percentages, _ = collate_fn(dataset) - # input_sizes = input_percentages.mul_(int(inputs.size(3))).int() - # print(inputs[0]) - # print(input_sizes.tolist()) - - # datasets = [[inputs[i], input_sizes[i]] for i in range(len(input_sizes))] x = torch.zeros(100, 80, dtype=torch.float32) - # x_lens = torch.tensor([100], dtype=torch.int64) x_lens = 100 datasets = [] for i in range(20): datasets.append([copy.deepcopy(x), copy.deepcopy(x_lens)]) - # print(datasets) while len(datasets) % opt.batch_size != 0: datasets.append(datasets[-1]) m = 1 @@ -116,7 +72,7 @@ def get_dataloader(opt): datasets += datasets_orig m += 1 - loader = InfiniteDataLoader # only DataLoader allows for attribute updates + loader = InfiniteDataLoader # only DataLoader allows for attribute updates print("OPT_BATCHSIZE: ", opt.batch_size) return loader(datasets, batch_size=opt.batch_size, @@ -145,57 +101,41 @@ def main(opt): # load model model = torch.jit.load(opt.model) batch_size = opt.batch_size - torch_aie.set_device(opt.device_id) + mindietorch.set_device(opt.device_id) if opt.need_compile: inputs = [] - inputs.append(torch_aie.Input([opt.batch_size, 100, 80], dtype=torch_aie.dtype.FLOAT)) - inputs.append(torch_aie.Input([opt.batch_size], dtype=torch_aie.dtype.INT32)) + inputs.append(mindietorch.Input([opt.batch_size, 100, 80], dtype=mindietorch.dtype.FLOAT)) + inputs.append(mindietorch.Input([opt.batch_size], dtype=mindietorch.dtype.INT32)) - model = torch_aie.compile( + model = mindietorch.compile( model, inputs=inputs, - precision_policy=_enums.PrecisionPolicy.FP16, - truncate_long_and_double=True, - require_full_compilation=False, - allow_tensor_replace_int=False, - min_block_size=3, - torch_executed_ops=[], - soc_version='Ascend310P3', - optimization_level=0) + precision_policy=_enums.PrecisionPolicy.FP32, + soc_version=opt.soc_version, + optimization_level=0 + ) dataloader = get_dataloader(opt) pred_results = forward_infer(model, dataloader, batch_size, opt.device_id) - for index, res in enumerate(pred_results): - print(index, " ", res) - print(res[0].shape) if opt.batch_size == 1 and opt.multi == 1: result_path = 
opt.result_path if(os.path.exists(result_path) == False): - os.makedirs(result_path) + os.makedirs(result_path) for index, res in enumerate(pred_results): - # for i in range(batch_size): - # result_fname_0 = 'data' + str(index * batch_size + i + 1) + '_0.txt' - # result_fname_1 = 'data' + str(index * batch_size + i + 1) + '_1.txt' result_fname_0 = 'data' + str(index) + '_0.txt' result_fname_1 = 'data' + str(index) + '_1.txt' - # res = np.array(res) - # save_tensor_arr_to_file(np.array(res[0][i]), os.path.join(result_path, result_fname_0)) - # save_size_to_file(res[1].numpy()[i], os.path.join(result_path, result_fname_1)) save_tensor_arr_to_file(np.array(res[0]), os.path.join(result_path, result_fname_0)) save_size_to_file(res[1].numpy()[0], os.path.join(result_path, result_fname_1)) - if __name__ == '__main__': - parser = argparse.ArgumentParser(description='DeepSpeech2 offline model inference.') + parser = argparse.ArgumentParser(description='Zipformer offline model encoder inference.') parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version') - parser.add_argument('--model', type=str, default="./pt_compiled_model/encoder-epoch-12-avg-1_torch_aie_bs1.pt", help='ts model path') + parser.add_argument('--model', type=str, default="./pt_compiled_model/encoder-epoch-12-avg-1_mindietorch_bs1.pt", help='ts model path') parser.add_argument('--need_compile', action="store_true", help='if the loaded model needs to be compiled or not') parser.add_argument('--batch_size', type=int, default=1, help='batch size') parser.add_argument('--device_id', type=int, default=0, help='device id') - parser.add_argument('--data_file', default='./deepspeech.pytorch/data/an4_test_manifest.json') - parser.add_argument('--label_file', default='./deepspeech.pytorch/labels.json') parser.add_argument('--result_path', default='result/encoder') parser.add_argument('--multi', type=int, default=1, help='multiples of dataset replication for enough infer loop. if multi != 1, the pred result will not be stored.') opt = parser.parse_args() diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py index a4fe80f4de..9f5b31ae6c 100644 --- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py @@ -1,4 +1,4 @@ -# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -17,9 +17,9 @@ import copy import argparse import torch -import torch_aie +import mindietorch import numpy as np -from torch_aie import _enums +from mindietorch import _enums from torch.utils.data import dataloader from model_pt_join import forward_infer @@ -58,57 +58,13 @@ class _RepeatSampler: while True: yield from iter(self.sampler) -# def collate_fn(batch): -# """ -# data preprocessing -# """ -# def func(p): -# """ -# data size -# """ -# return p[0].size(1) - -# batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True) -# longest_sample = max(batch, key=func)[0] -# freq_size = longest_sample.size(0) -# minibatch_size = len(batch) -# max_seqlength = longest_sample.size(1) -# inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength) -# input_percentages = torch.FloatTensor(minibatch_size) -# for x in range(minibatch_size): -# sample = batch[x] -# tensor = sample[0] -# seq_length = tensor.size(1) -# inputs[x][0].narrow(1, 0, seq_length).copy_(tensor) -# input_percentages[x] = seq_length / float(max_seqlength) -# return inputs, input_percentages, [] - def get_dataloader(opt): - # with open(to_absolute_path(opt.label_file)) as label_file: - # labels = json.load(label_file) - - # dataset = SpectrogramDataset( - # audio_conf=DataConfig.spect, - # input_path=opt.data_file, - # labels=labels, - # normalize=True, - # aug_cfg=DataConfig.augmentation - # ) - # inputs, input_percentages, _ = collate_fn(dataset) - # input_sizes = input_percentages.mul_(int(inputs.size(3))).int() - # print(inputs[0]) - # print(input_sizes.tolist()) - - # datasets = [[inputs[i], input_sizes[i]] for i in range(len(input_sizes))] x = torch.zeros(512, dtype=torch.float32) y = torch.zeros(512, dtype=torch.float32) - # x_lens = torch.tensor([100], dtype=torch.int64) - # x_lens = 100 datasets = [] for i in range(20): datasets.append([copy.deepcopy(x), copy.deepcopy(y)]) - # print(datasets) while len(datasets) % opt.batch_size != 0: datasets.append(datasets[-1]) m = 1 @@ -137,66 +93,44 @@ def save_tensor_arr_to_file(arr, file_path): with open(file_path, "w", encoding='utf-8') as f: f.write(write_sen) -def save_size_to_file(size, file_path): - write_sen = "" + str(size) + " " - with open(file_path, "w", encoding='utf-8') as f: - f.write(write_sen) def main(opt): # load model model = torch.jit.load(opt.model) batch_size = opt.batch_size - torch_aie.set_device(opt.device_id) + mindietorch.set_device(opt.device_id) if opt.need_compile: inputs = [] - inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT)) - inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT)) + inputs.append(mindietorch.Input([opt.batch_size, 512], dtype = mindietorch.dtype.FLOAT)) + inputs.append(mindietorch.Input([opt.batch_size, 512], dtype = mindietorch.dtype.FLOAT)) - model = torch_aie.compile( + model = mindietorch.compile( model, inputs=inputs, - precision_policy=_enums.PrecisionPolicy.FP16, - truncate_long_and_double=True, - require_full_compilation=False, - allow_tensor_replace_int=False, - min_block_size=3, - torch_executed_ops=[], - soc_version='Ascend310P3', - optimization_level=0) + precision_policy=_enums.PrecisionPolicy.FP32, + soc_version=opt.soc_version, + optimization_level=0 + ) dataloader = get_dataloader(opt) pred_results = forward_infer(model, dataloader, batch_size, opt.device_id) - # for index, res in enumerate(pred_results): - # print(index, " ", res) - # print(res[0].shape) if opt.batch_size == 1 and opt.multi == 1: result_path = opt.result_path 
if(os.path.exists(result_path) == False): - os.makedirs(result_path) + os.makedirs(result_path) for index, res in enumerate(pred_results): - # for i in range(batch_size): - # result_fname_0 = 'data' + str(index * batch_size + i + 1) + '_0.txt' - # result_fname_1 = 'data' + str(index * batch_size + i + 1) + '_1.txt' result_fname_0 = 'data' + str(index) + '_0.txt' - # result_fname_1 = 'data' + str(index) + '_1.txt' - # res = np.array(res) - # save_tensor_arr_to_file(np.array(res[0][i]), os.path.join(result_path, result_fname_0)) - # save_size_to_file(res[1].numpy()[i], os.path.join(result_path, result_fname_1)) save_tensor_arr_to_file(np.array(res), os.path.join(result_path, result_fname_0)) - # save_size_to_file(res[1].numpy()[0], os.path.join(result_path, result_fname_1)) - if __name__ == '__main__': - parser = argparse.ArgumentParser(description='DeepSpeech2 offline model inference.') + parser = argparse.ArgumentParser(description='Zipformer offline model joiner inference.') parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version') - parser.add_argument('--model', type=str, default="./pt_compiled_model/joiner-epoch-12-avg-1_torch_aie_bs1.pt", help='ts model path') + parser.add_argument('--model', type=str, default="./pt_compiled_model/joiner-epoch-12-avg-1_mindietorch_bs1.pt", help='ts model path') parser.add_argument('--need_compile', action="store_true", help='if the loaded model needs to be compiled or not') parser.add_argument('--batch_size', type=int, default=1, help='batch size') parser.add_argument('--device_id', type=int, default=0, help='device id') - parser.add_argument('--data_file', default='./deepspeech.pytorch/data/an4_test_manifest.json') - parser.add_argument('--label_file', default='./deepspeech.pytorch/labels.json') parser.add_argument('--result_path', default='result/joiner') parser.add_argument('--multi', type=int, default=1, help='multiples of dataset replication for enough infer loop. 
if multi != 1, the pred result will not be stored.') opt = parser.parse_args() diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py deleted file mode 100644 index 866216820f..0000000000 --- a/AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py +++ /dev/null @@ -1,25 +0,0 @@ -from argparse import ArgumentParser - -from auto_optimizer import OnnxGraph - - -def main(): - parser = ArgumentParser() - parser.add_argument("--onnx", type=str, required=True) - args = parser.parse_args() - - graph = OnnxGraph.parse(args.onnx) - graph.remove("/decoder/Clip") - gather = graph["/decoder/embedding/Gather"] - gather.inputs[1] = "y" - graph.update_map() - graph.infershape() - - g_sim = graph.simplify() - save_path = args.onnx.replace(".onnx", "_modified.onnx") - g_sim.save(save_path) - print("Modified model saved to ", save_path) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.patch b/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.patch new file mode 100644 index 0000000000..d5a19abee9 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.patch @@ -0,0 +1,59 @@ +--- zipformer.py 2024-03-08 15:18:44.384000000 +0800 ++++ zipformer.py 2024-03-08 15:18:42.056000000 +0800 +@@ -1415,7 +1415,7 @@ + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + +- def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: ++ def extend_pe(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Reset the positional encodings.""" + T = x.size(0) + left_context_len + +@@ -1423,8 +1423,7 @@ + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= T * 2 - 1: +- self.pe = self.pe.to(dtype=x.dtype, device=x.device) +- return ++ return self.pe.to(dtype=x.dtype, device=x.device) + + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) +@@ -1458,12 +1457,13 @@ + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + +- pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) +- pe[:, 0::2] = cosines +- pe[:, 1::2] = sines +- pe[:, -1] = 1.0 # for bias. ++ cos_shape0 = cosines.shape[0] ++ bias_one = torch.ones(cos_shape0, 1) ++ pe = torch.cat((cosines.unsqueeze(2), sines.unsqueeze(2)), dim=2) ++ pe = pe.reshape(cos_shape0, -1) ++ pe = torch.cat((pe[:, :-1], bias_one), dim=1) + +- self.pe = pe.to(dtype=x.dtype) ++ return pe.to(dtype=x.dtype) + + def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: + """Create positional encoding. +@@ -1475,14 +1475,14 @@ + Returns: + positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). 
+ """ +- self.extend_pe(x, left_context_len) ++ pe = self.extend_pe(x, left_context_len) + x_size_left = x.size(0) + left_context_len + # length of positive side: x.size(0) + left_context_len + # length of negative side: x.size(0) +- pos_emb = self.pe[ +- self.pe.size(0) // 2 ++ pos_emb = pe[ ++ pe.size(0) // 2 + - x_size_left +- + 1 : self.pe.size(0) // 2 # noqa E203 ++ + 1 : pe.size(0) // 2 # noqa E203 + + x.size(0), + :, + ] diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py deleted file mode 100644 index 163eb5f6b2..0000000000 --- a/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py +++ /dev/null @@ -1,2438 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import logging -import math -import random -import warnings -from typing import List, Optional, Tuple, Union - -import torch -from encoder_interface import EncoderInterface -from scaling import ( - Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. -) -from scaling import ( - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. -) -from scaling import ( - ActivationDropoutAndLinear, - Balancer, - BiasNorm, - ChunkCausalDepthwiseConv1d, - Dropout2, - FloatLike, - ScheduledFloat, - Whiten, - convert_num_channels, - limit_param_value, - penalize_abs_values_gt, - softmax, -) -from torch import Tensor, nn - - -class Zipformer2(EncoderInterface): - """ - Args: - - Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length - as downsampling_factor if they are single ints or one-element tuples. The length of - downsampling_factor defines the number of stacks. - - output_downsampling_factor (int): how much to downsample at the output. Note: - we also downsample by a factor of 2 in the Conv2dSubsampling encoder. - You should probably leave this at 2. - downsampling_factor (Tuple[int]): downsampling factor for each encoder stack. - Note: this is in addition to the downsampling factor of 2 that is applied in - the frontend (self.encoder_embed). - encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per - encoder stack. - num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack - encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of - the encoder stacks for purposes of per-frame dropout (recommend 256 for - now). - query_head_dim (int or Tuple[int]): dimension of query and key per attention - head: per stack, if a tuple.. - pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per - attention head - value_head_dim (int or Tuple[int]): dimension of value in each attention head - num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. - Must be at least 4. 
- feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules - cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module - - pos_dim (int): the dimension of each positional-encoding vector prior to projection, - e.g. 128. - - dropout (float): dropout rate - warmup_batches (float): number of batches to warm up over; this controls - dropout of encoder layers. - causal (bool): if True, support chunkwise causal convolution. This should - not hurt WER as no modeling power is lost, but the convolution modules will be - slightly slower and use more memory. Enables use of the chunk_size and - left_context_chunks options in forward(), which simulates streaming - decoding. - chunk_size: (list of int): only set this to other than [-1] if causal; - the chunk size will be randomly chosen from this list. -1 means no chunking. - left_context_frames: (list of int): determines the number of left- - context chunks for causal training; will be rounded to a number of - chunks. Must not be less than cnn_module_kernel (after factoring in - rounding and downsampling); an error will be thrown if this is violated. - """ - - def __init__( - self, - output_downsampling_factor: int = 2, - downsampling_factor: Tuple[int] = (2, 4), - encoder_dim: Union[int, Tuple[int]] = 384, - num_encoder_layers: Union[int, Tuple[int]] = 4, - encoder_unmasked_dim: Union[int, Tuple[int]] = 256, - query_head_dim: Union[int, Tuple[int]] = 24, - pos_head_dim: Union[int, Tuple[int]] = 4, - value_head_dim: Union[int, Tuple[int]] = 12, - num_heads: Union[int, Tuple[int]] = 8, - feedforward_dim: Union[int, Tuple[int]] = 1536, - cnn_module_kernel: Union[int, Tuple[int]] = 31, - pos_dim: int = 192, - dropout: FloatLike = None, # see code below for default - warmup_batches: float = 4000.0, - causal: bool = False, - chunk_size: Tuple[int] = [-1], - left_context_frames: Tuple[int] = [-1], - ) -> None: - super(Zipformer2, self).__init__() - - if dropout is None: - dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) - - def _to_tuple(x): - """Converts a single int or a 1-tuple of an int to a tuple with the same length - as downsampling_factor""" - if isinstance(x, int): - x = (x,) - if len(x) == 1: - x = x * len(downsampling_factor) - else: - assert len(x) == len(downsampling_factor) and isinstance(x[0], int) - return x - - self.output_downsampling_factor = output_downsampling_factor # int - self.downsampling_factor = downsampling_factor # tuple - self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple - self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( - encoder_unmasked_dim - ) # tuple - num_encoder_layers = _to_tuple(num_encoder_layers) - self.num_encoder_layers = num_encoder_layers - self.query_head_dim = query_head_dim = _to_tuple(query_head_dim) - self.value_head_dim = value_head_dim = _to_tuple(value_head_dim) - pos_head_dim = _to_tuple(pos_head_dim) - self.num_heads = num_heads = _to_tuple(num_heads) - feedforward_dim = _to_tuple(feedforward_dim) - self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - - self.causal = causal - self.chunk_size = chunk_size - self.left_context_frames = left_context_frames - - for u, d in zip(encoder_unmasked_dim, encoder_dim): - assert u <= d - - # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder - encoders = [] - - num_encoders = len(downsampling_factor) - for i in range(num_encoders): - encoder_layer = Zipformer2EncoderLayer( - embed_dim=encoder_dim[i], - pos_dim=pos_dim, - num_heads=num_heads[i], - 
query_head_dim=query_head_dim[i], - pos_head_dim=pos_head_dim[i], - value_head_dim=value_head_dim[i], - feedforward_dim=feedforward_dim[i], - dropout=dropout, - cnn_module_kernel=cnn_module_kernel[i], - causal=causal, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. - encoder = Zipformer2Encoder( - encoder_layer, - num_encoder_layers[i], - pos_dim=pos_dim, - dropout=dropout, - warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), - final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5), - ) - - if downsampling_factor[i] != 1: - encoder = DownsampledZipformer2Encoder( - encoder, - dim=encoder_dim[i], - downsample=downsampling_factor[i], - dropout=dropout, - ) - - encoders.append(encoder) - - self.encoders = nn.ModuleList(encoders) - - self.downsample_output = SimpleDownsample( - max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout - ) - - def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: - """ - In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of - randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all enocder dims larger than - some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoer dim. - - We generate the random masks at this level because we want the 2 masks to 'agree' - all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_subsampling_factor times. - - Args: - x: the embeddings (needed for the shape and dtype and device), of shape - (1, batch_size, encoder_dims0) - """ - num_encoders = len(self.encoder_dim) - if not self.training: - return [1.0] * num_encoders - - (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dim[0] == _encoder_dims0, ( - self.encoder_dim[0], - _encoder_dims0, - ) - - feature_mask_dropout_prob = 0.125 - - # mask1 shape: (1, batch_size, 1) - mask1 = ( - torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob - ).to(x.dtype) - - # mask2 has additional sequences masked, about twice the number. - mask2 = torch.logical_and( - mask1, - ( - torch.rand(1, batch_size, 1, device=x.device) - > feature_mask_dropout_prob - ).to(x.dtype), - ) - - # dim: (1, batch_size, 2) - mask = torch.cat((mask1, mask2), dim=-1) - - feature_masks = [] - for i in range(num_encoders): - channels = self.encoder_dim[i] - feature_mask = torch.ones( - 1, batch_size, channels, dtype=x.dtype, device=x.device - ) - u1 = self.encoder_unmasked_dim[i] - u2 = u1 + (channels - u1) // 2 - - feature_mask[:, :, u1:u2] *= mask[..., 0:1] - feature_mask[:, :, u2:] *= mask[..., 1:2] - - feature_masks.append(feature_mask) - - return feature_masks - - def get_chunk_info(self) -> Tuple[int, int]: - """ - Returns chunk_size and left_context_chunks. 
- """ - if not self.causal: - return -1, -1 - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert len(self.chunk_size) == 1, self.chunk_size - chunk_size = self.chunk_size[0] - else: - chunk_size = random.choice(self.chunk_size) - - if chunk_size == -1: - left_context_chunks = -1 - else: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - assert len(self.left_context_frames) == 1, self.left_context_frames - left_context_frames = self.left_context_frames[0] - else: - left_context_frames = random.choice(self.left_context_frames) - # Note: in Python, -1 // n == -1 for n > 0 - left_context_chunks = left_context_frames // chunk_size - if left_context_chunks == 0: - left_context_chunks = 1 - - return chunk_size, left_context_chunks - - def forward( - self, - x: Tensor, - x_lens: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - """ - outputs = [] - if torch.jit.is_scripting() or torch.jit.is_tracing(): - feature_masks = [1.0] * len(self.encoder_dim) - else: - feature_masks = self.get_feature_masks(x) - - chunk_size, left_context_chunks = self.get_chunk_info() - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - # Not support exporting a model for simulating streaming decoding - attn_mask = None - else: - attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) - - for i, module in enumerate(self.encoders): - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) - - x = module( - x, - chunk_size=chunk_size, - feature_mask=feature_masks[i], - src_key_padding_mask=( - None - if src_key_padding_mask is None - else src_key_padding_mask[..., ::ds] - ), - attn_mask=attn_mask, - ) - outputs.append(x) - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - return x, lengths - - def _get_attn_mask( - self, x: Tensor, chunk_size: int, left_context_chunks: int - ) -> Optional[Tensor]: - """ - Return None if chunk_size == -1, else return attention mask of shape - (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True - means a masked position. - Args: - x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). 
- chunk_size: chunk size, must divide - """ - if chunk_size <= 0: - return None - assert all(chunk_size % d == 0 for d in self.downsampling_factor) - if left_context_chunks >= 0: - num_encoders = len(self.encoder_dim) - assert all( - chunk_size * left_context_chunks - >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] - for i in range(num_encoders) - ) - else: - left_context_chunks = 1000000 - - seq_len = x.shape[0] - - # t is frame index, shape (seq_len,) - t = torch.arange(seq_len, dtype=torch.int32, device=x.device) - # c is chunk index for each frame, shape (seq_len,) - if torch.jit.is_scripting() or torch.jit.is_tracing(): - c = t // chunk_size - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - c = t // chunk_size - src_c = c - tgt_c = c.unsqueeze(-1) - - attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) - if __name__ == "__main__": - logging.info(f"attn_mask = {attn_mask}") - return attn_mask - - def _get_full_dim_output(self, outputs: List[Tensor]): - num_encoders = len(self.encoder_dim) - assert len(outputs) == num_encoders - output_dim = max(self.encoder_dim) - output_pieces = [outputs[-1]] - cur_dim = self.encoder_dim[-1] - for i in range(num_encoders - 2, -1, -1): - d = self.encoder_dim[i] - if d > cur_dim: - this_output = outputs[i] - output_pieces.append(this_output[..., cur_dim:d]) - cur_dim = d - assert cur_dim == output_dim - return torch.cat(output_pieces, dim=-1) - - def streaming_forward( - self, - x: Tensor, - x_lens: Tensor, - states: List[Tensor], - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (seq_len, batch_size, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - states: list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - src_key_padding_mask: - The mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - - updated states - """ - outputs = [] - new_states = [] - layer_offset = 0 - - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - ds = self.downsampling_factor[i] - x = convert_num_channels(x, self.encoder_dim[i]) - - x, new_layer_states = module.streaming_forward( - x, - states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], - left_context_len=self.left_context_frames[0] // ds, - src_key_padding_mask=src_key_padding_mask[..., ::ds], - ) - layer_offset += num_layers - outputs.append(x) - new_states += new_layer_states - - # if the last output has the largest dimension, x will be unchanged, - # it will be the same as outputs[-1]. Otherwise it will be concatenated - # from different pieces of 'outputs', taking each dimension from the - # most recent output that has it present. - x = self._get_full_dim_output(outputs) - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. 
- assert self.output_downsampling_factor == 2 - if torch.jit.is_scripting() or torch.jit.is_tracing(): - lengths = (x_lens + 1) // 2 - else: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - lengths = (x_lens + 1) // 2 - - return x, lengths, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[Tensor]: - """Get initial states. - - A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - """ - states = [] - for i, module in enumerate(self.encoders): - num_layers = module.num_layers - embed_dim = self.encoder_dim[i] - ds = self.downsampling_factor[i] - num_heads = self.num_heads[i] - key_dim = self.query_head_dim[i] * num_heads - value_dim = self.value_head_dim[i] * num_heads - downsample_left = self.left_context_frames[0] // ds - nonlin_attn_head_dim = 3 * embed_dim // 4 - conv_left_pad = self.cnn_module_kernel[i] // 2 - for layer in range(num_layers): - cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( - device - ) - cached_nonlin_attn = torch.zeros( - 1, batch_size, downsample_left, nonlin_attn_head_dim - ).to(device) - cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( - device - ) - cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( - device - ) - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( - device - ) - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( - device - ) - states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - return states - - -def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: - return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x) - - -def _balancer_schedule(min_prob: float): - return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) - - -class Zipformer2EncoderLayer(nn.Module): - """ - Args: - embed_dim: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. 
- - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - value_head_dim: int, - feedforward_dim: int, - dropout: FloatLike = 0.1, - cnn_module_kernel: int = 31, - causal: bool = False, - attention_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 - ), - conv_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0 - ), - const_attention_rate: FloatLike = ScheduledFloat( - (0.0, 0.25), (4000.0, 0.025), default=0 - ), - ff2_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) - ), - ff3_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0) - ), - bypass_skip_rate: FloatLike = ScheduledFloat( - (0.0, 0.5), (4000.0, 0.02), default=0 - ), - ) -> None: - super(Zipformer2EncoderLayer, self).__init__() - self.embed_dim = embed_dim - - # self.bypass implements layer skipping as well as bypass; see its default values. - self.bypass = BypassModule( - embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0 - ) - # bypass_mid is bypass used in the middle of the layer. - self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0) - - # skip probability for dynamic modules (meaning: anything but feedforward). - self.attention_skip_rate = copy.deepcopy(attention_skip_rate) - # an additional skip probability that applies to ConvModule to stop it from - # contributing too much early on. - self.conv_skip_rate = copy.deepcopy(conv_skip_rate) - - # ff2_skip_rate is to prevent the ff2 module from having output that's too big - # compared to its residual. 
- self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) - self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) - - self.const_attention_rate = copy.deepcopy(const_attention_rate) - - self.self_attn_weights = RelPositionMultiheadAttentionWeights( - embed_dim, - pos_dim=pos_dim, - num_heads=num_heads, - query_head_dim=query_head_dim, - pos_head_dim=pos_head_dim, - dropout=0.0, - ) - - self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim) - - self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim) - - self.feed_forward1 = FeedforwardModule( - embed_dim, (feedforward_dim * 3) // 4, dropout - ) - - self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) - - self.feed_forward3 = FeedforwardModule( - embed_dim, (feedforward_dim * 5) // 4, dropout - ) - - self.nonlin_attention = NonlinAttention( - embed_dim, hidden_channels=3 * embed_dim // 4 - ) - - self.conv_module1 = ConvolutionModule( - embed_dim, cnn_module_kernel, causal=causal - ) - - self.conv_module2 = ConvolutionModule( - embed_dim, cnn_module_kernel, causal=causal - ) - - # TODO: remove it - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - - self.norm = BiasNorm(embed_dim) - - self.balancer1 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - min_abs=0.2, - max_abs=4.0, - ) - - # balancer for output of NonlinAttentionModule - self.balancer_na = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), - prob=0.05, # out of concern for memory usage - ) - - # balancer for output of feedforward2, prevent it from staying too - # small. give this a very small probability, even at the start of - # training, it's to fix a rare problem and it's OK to fix it slowly. - self.balancer_ff2 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), - max_abs=2.0, - prob=0.05, - ) - - self.balancer_ff3 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=0.7, - min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), - max_abs=4.0, - prob=0.05, - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(4.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.balancer2 = Balancer( - embed_dim, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - min_abs=0.1, - max_abs=4.0, - ) - - def get_sequence_dropout_mask( - self, x: Tensor, dropout_rate: float - ) -> Optional[Tensor]: - if ( - dropout_rate == 0.0 - or not self.training - or torch.jit.is_scripting() - or torch.jit.is_tracing() - ): - return None - batch_size = x.shape[1] - mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) - return mask - - def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: - """ - Apply sequence-level dropout to x. - x shape: (seq_len, batch_size, embed_dim) - """ - dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) - if dropout_mask is None: - return x - else: - return x * dropout_mask - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - chunk_size: int = -1, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). 
- pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: - A tensor which has the same shape as src - """ - src_orig = src - - # dropout rate for non-feedforward submodules - if torch.jit.is_scripting() or torch.jit.is_tracing(): - attention_skip_rate = 0.0 - else: - attention_skip_rate = ( - float(self.attention_skip_rate) if self.training else 0.0 - ) - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights = self.self_attn_weights( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - - src = src + self.feed_forward1(src) - - self_attn_dropout_mask = self.get_sequence_dropout_mask( - src, attention_skip_rate - ) - - selected_attn_weights = attn_weights[0:1] - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif not self.training and random.random() < float(self.const_attention_rate): - # Make attention weights constant. The intention is to - # encourage these modules to do something similar to an - # averaging-over-time operation. - # only need the mask, can just use the 1st one and expand later - selected_attn_weights = selected_attn_weights[0:1] - selected_attn_weights = (selected_attn_weights > 0.0).to( - selected_attn_weights.dtype - ) - selected_attn_weights = selected_attn_weights * ( - 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True) - ) - - na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) - - src = src + ( - na if self_attn_dropout_mask is None else na * self_attn_dropout_mask - ) - - self_attn = self.self_attn1(src, attn_weights) - - src = src + ( - self_attn - if self_attn_dropout_mask is None - else self_attn * self_attn_dropout_mask - ) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.conv_module1( - src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask - ), - conv_skip_rate, - ) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - ff2_skip_rate = 0.0 - else: - ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate - ) - - # bypass in the middle of the layer. 
- src = self.bypass_mid(src_orig, src) - - self_attn = self.self_attn2(src, attn_weights) - - src = src + ( - self_attn - if self_attn_dropout_mask is None - else self_attn * self_attn_dropout_mask - ) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - conv_skip_rate = 0.0 - else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.conv_module2( - src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask - ), - conv_skip_rate, - ) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - ff3_skip_rate = 0.0 - else: - ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0 - src = src + self.sequence_dropout( - self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate - ) - - src = self.balancer1(src) - src = self.norm(src) - - src = self.bypass(src_orig, src) - - src = self.balancer2(src) - src = self.whiten(src) - - return src - - def streaming_forward( - self, - src: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - cached_nonlin_attn: Tensor, - cached_val1: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - """Pass the input through the encoder layer in streaming forward mode. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or - (batch_size, left_context_len+2*seq_len-1, pos_emb_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - cached_val1: cached left context for the first attention module, - of shape (left_context_len, batch_size, value_dim) - cached_val2: cached left context for the second attention module, - of shape (left_context_len, batch_size, value_dim) - cached_conv1: cached left context for the first convolution module, - of shape (batch_size, channels, left_pad) - cached_conv2: cached left context for the second convolution module, - of shape (batch_size, channels, left_pad) - left_context_len: number of left context frames. - src_key_padding_mask: the mask for padding, of shape - (batch_size, left_context_len + seq_len); True means masked position. - May be None. 
- - Returns: - - x, with the same shape as src - - updated cached_key - - updated cached_nonlin_attn - - updated cached_val1 - - updated cached_val2 - - updated cached_conv1 - - updated cached_conv2 - """ - src_orig = src - - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - attn_weights, cached_key = self.self_attn_weights.streaming_forward( - src, - pos_emb=pos_emb, - cached_key=cached_key, - left_context_len=left_context_len, - key_padding_mask=src_key_padding_mask, - ) - - src = src + self.feed_forward1(src) - - na, cached_nonlin_attn = self.nonlin_attention.streaming_forward( - src, - attn_weights[0:1], - cached_x=cached_nonlin_attn, - left_context_len=left_context_len, - ) - src = src + na - - self_attn, cached_val1 = self.self_attn1.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val1, - left_context_len=left_context_len, - ) - src = src + self_attn - - src_conv, cached_conv1 = self.conv_module1.streaming_forward( - src, - cache=cached_conv1, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], - ) - src = src + src_conv - - src = src + self.feed_forward2(src) - - # bypass in the middle of the layer. - src = self.bypass_mid(src_orig, src) - - self_attn, cached_val2 = self.self_attn2.streaming_forward( - src, - attn_weights=attn_weights, - cached_val=cached_val2, - left_context_len=left_context_len, - ) - src = src + self_attn - - src_conv, cached_conv2 = self.conv_module2.streaming_forward( - src, - cache=cached_conv2, - src_key_padding_mask=src_key_padding_mask[:, left_context_len:], - ) - src = src + src_conv - - src = src + self.feed_forward3(src) - - src = self.norm(src) - - src = self.bypass(src_orig, src) - - return ( - src, - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class Zipformer2Encoder(nn.Module): - r"""Zipformer2Encoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - pos_dim: the dimension for the relative positional encoding - - Examples:: - >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) - >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - pos_dim: int, - dropout: float, - warmup_begin: float, - warmup_end: float, - initial_layerdrop_rate: float = 0.5, - final_layerdrop_rate: float = 0.05, - ) -> None: - super().__init__() - self.encoder_pos = CompactRelPositionalEncoding( - pos_dim, dropout_rate=0.15, length_factor=1.0 - ) - - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - assert 0 <= warmup_begin <= warmup_end - - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin # interpreted as a training batch index - for i in range(num_layers): - cur_end = cur_begin + delta - self.layers[i].bypass.skip_rate = ScheduledFloat( - (cur_begin, initial_layerdrop_rate), - (cur_end, final_layerdrop_rate), - default=0.0, - ) - cur_begin = cur_end - - def forward( - self, - src: Tensor, - chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. 
- - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. - """ - pos_emb = self.encoder_pos(src) - output = src - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - output = output * feature_mask - - for i, mod in enumerate(self.layers): - output = mod( - output, - pos_emb, - chunk_size=chunk_size, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - output = output * feature_mask - - return output - - def streaming_forward( - self, - src: Tensor, - states: List[Tensor], - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - left_context_len: Number of left context frames. - src_key_padding_mask: the mask for padding, of shape - (batch_size, left_context_len + seq_len); True means masked position. - May be None. - - Returns: - - output, a Tensor with the same shape as src. - - updated states - """ - pos_emb = self.encoder_pos(src, left_context_len) - output = src - - new_states = [] - for i, mod in enumerate(self.layers): - ( - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ) = states[i * 6 : (i + 1) * 6] - ( - output, - new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, - new_cached_conv1, - new_cached_conv2, - ) = mod.streaming_forward( - output, - pos_emb, - cached_key=cached_key, - cached_nonlin_attn=cached_nonlin_attn, - cached_val1=cached_val1, - cached_val2=cached_val2, - cached_conv1=cached_conv1, - cached_conv2=cached_conv2, - left_context_len=left_context_len, - src_key_padding_mask=src_key_padding_mask, - ) - new_states += [ - new_cached_key, - new_cached_nonlin_attn, - new_cached_val1, - new_cached_val2, - new_cached_conv1, - new_cached_conv2, - ] - - return output, new_states - - -class BypassModule(nn.Module): - """ - An nn.Module that implements a learnable bypass scale, and also randomized per-sequence - layer-skipping. The bypass is limited during early stages of training to be close to - "straight-through", i.e. to not do the bypass operation much initially, in order to - force all the modules to learn something. 
- """ - - def __init__( - self, - embed_dim: int, - skip_rate: FloatLike = 0.0, - straight_through_rate: FloatLike = 0.0, - scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), - scale_max: FloatLike = 1.0, - ): - super().__init__() - self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) - self.skip_rate = copy.deepcopy(skip_rate) - self.straight_through_rate = copy.deepcopy(straight_through_rate) - self.scale_min = copy.deepcopy(scale_min) - self.scale_max = copy.deepcopy(scale_max) - - def _get_bypass_scale(self, batch_size: int): - # returns bypass-scale of shape (num_channels,), - # or (batch_size, num_channels,). This is actually the - # scale on the non-residual term, so 0 correponds to bypassing - # this module. - if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training: - return self.bypass_scale - else: - ans = limit_param_value( - self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) - ) - skip_rate = float(self.skip_rate) - if skip_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate - ans = ans * mask - # now ans is of shape (batch_size, num_channels), and is zero for sequences - # on which we have randomly chosen to do layer-skipping. - straight_through_rate = float(self.straight_through_rate) - if straight_through_rate != 0.0: - mask = ( - torch.rand((batch_size, 1), device=ans.device) - < straight_through_rate - ) - ans = torch.maximum(ans, mask.to(ans.dtype)) - return ans - - def forward(self, src_orig: Tensor, src: Tensor): - """ - Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) - Returns: something with the same shape as src and src_orig - """ - bypass_scale = self._get_bypass_scale(src.shape[1]) - return src_orig + (src - src_orig) * bypass_scale - - -class DownsampledZipformer2Encoder(nn.Module): - r""" - DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output, and combined - with the origin input, so that the output has the same shape as the input. - """ - - def __init__( - self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike - ): - super(DownsampledZipformer2Encoder, self).__init__() - self.downsample_factor = downsample - self.downsample = SimpleDownsample(dim, downsample, dropout) - self.num_layers = encoder.num_layers - self.encoder = encoder - self.upsample = SimpleUpsample(dim, downsample) - self.out_combiner = BypassModule(dim, straight_through_rate=0) - - def forward( - self, - src: Tensor, - chunk_size: int = -1, - feature_mask: Union[Tensor, float] = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) - attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len), - interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). - True means masked position. May be None. - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. - - Returns: a Tensor with the same shape as src. 
- """ - src_orig = src - src = self.downsample(src) - ds = self.downsample_factor - if attn_mask is not None: - attn_mask = attn_mask[::ds, ::ds] - - src = self.encoder( - src, - chunk_size=chunk_size // ds, - feature_mask=feature_mask, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - - def streaming_forward( - self, - src: Tensor, - states: List[Tensor], - left_context_len: int, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Downsample, go through encoder, upsample, in streaming forward mode. - - Args: - src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). - states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is - (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - left_context_len: Number of left context frames. - src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len); - True means masked position. May be None. - - Returns: - - output, a Tensor with the same shape as src. - - updated states - """ - src_orig = src - src = self.downsample(src) - - src, new_states = self.encoder.streaming_forward( - src, - states=states, - left_context_len=left_context_len, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return self.out_combiner(src_orig, src), new_states - - -class SimpleDownsample(torch.nn.Module): - """ - Does downsampling with attention, by weighted sum, and a projection.. - """ - - def __init__(self, channels: int, downsample: int, dropout: FloatLike): - super(SimpleDownsample, self).__init__() - - self.bias = nn.Parameter(torch.zeros(downsample)) - - self.name = None # will be set from training code - self.dropout = copy.deepcopy(dropout) - - self.downsample = downsample - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, channels) - """ - (seq_len, batch_size, in_channels) = src.shape - ds = self.downsample - d_seq_len = (seq_len + ds - 1) // ds - - # Pad to an exact multiple of self.downsample - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds - - src = src.reshape(d_seq_len, ds, batch_size, in_channels) - - weights = self.bias.softmax(dim=0) - # weights: (downsample, 1, 1) - weights = weights.unsqueeze(-1).unsqueeze(-1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - - return ans - - -class SimpleUpsample(torch.nn.Module): - """ - A very simple form of upsampling that mostly just repeats the input, but - also adds a position-specific bias. 
- """ - - def __init__(self, num_channels: int, upsample: int): - super(SimpleUpsample, self).__init__() - self.upsample = upsample - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) - """ - upsample = self.upsample - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src - - -class CompactRelPositionalEncoding(torch.nn.Module): - """ - Relative positional encoding module. This version is "compact" meaning it is able to encode - the important information about the relative position in a relatively small number of dimensions. - The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) - make very little difference to the embedding. Such differences were potentially important - when encoding absolute position, but not important when encoding relative position because there - is now no need to compare two large offsets with each other. - - Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval - using the atan() function, before doing the fourier transform of that fixed interval. The - atan() function would compress the "long tails" too small, - making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic - function to compress large offsets to a smaller range before applying atan(). - Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long - as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) - - - Args: - embed_dim: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length: just a heuristic for initialization. - length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives - less weight to small differences of offset near the origin. - """ - - def __init__( - self, - embed_dim: int, - dropout_rate: FloatLike, - max_len: int = 1000, - length_factor: float = 1.0, - ) -> None: - """Construct a CompactRelPositionalEncoding object.""" - super(CompactRelPositionalEncoding, self).__init__() - self.embed_dim = embed_dim - assert embed_dim % 2 == 0 - self.dropout = Dropout2(dropout_rate) - self.pe = None - assert length_factor >= 1.0 - self.length_factor = length_factor - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0): - """Reset the positional encodings.""" - T = x.size(0) + left_context_len - - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(0) >= T * 2 - 1: - # self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return self.pe.to(dtype=x.dtype, device=x.device) - - # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] - x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) - - freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) - - # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution - # for small time offsets but less resolution for large time offsets. - compression_length = self.embed_dim**0.5 - # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; - # but it does so more slowly than T for large absolute values of T. 
- # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which - # is important. - x_compressed = ( - compression_length - * x.sign() - * ((x.abs() + compression_length).log() - math.log(compression_length)) - ) - - # if self.length_factor == 1.0, then length_scale is chosen so that the - # FFT can exactly separate points close to the origin (T == 0). So this - # part of the formulation is not really heuristic. - # But empirically, for ASR at least, length_factor > 1.0 seems to work better. - length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) - - # note for machine implementations: if atan is not available, we can use: - # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) - # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) - x_atan = (x_compressed / length_scale).atan() # results between -pi and pi - - cosines = (x_atan * freqs).cos() - sines = (x_atan * freqs).sin() - - pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) - pe[:, 0::2] = cosines - pe[:, 1::2] = sines - pe[:, -1] = 1.0 # for bias. - - # self.pe = pe.to(dtype=x.dtype) - return pe.to(dtype=x.dtype) - - def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor: - """Create positional encoding. - - Args: - x (Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. - - Returns: - positional embedding, of shape (batch, left_context_len + 2*time-1, `*`). - """ - pe = self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - # length of positive side: x.size(0) + left_context_len - # length of negative side: x.size(0) - pos_emb = pe[ - pe.size(0) // 2 - - x_size_left - + 1 : pe.size(0) // 2 # noqa E203 - + x.size(0), - :, - ] - pos_emb = pos_emb.unsqueeze(0) - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttentionWeights(nn.Module): - r"""Module that computes multi-head attention weights with relative position encoding. - Various other modules consume the resulting attention weights: see, for example, the - SimpleAttention module which allows you to compute conventional attention. - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: number of channels at the input to this module, e.g. 256 - pos_dim: dimension of the positional encoding vectors, e.g. 128. - num_heads: number of heads to compute weights for, e.g. 8 - query_head_dim: dimension of the query (and key), per head. e.g. 24. - pos_head_dim: dimension of the projected positional encoding per head, e.g. 4. - dropout: dropout probability for attn_output_weights. Default: 0.0. - pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on - any given call to forward(), in training time. - """ - - def __init__( - self, - embed_dim: int, - pos_dim: int, - num_heads: int, - query_head_dim: int, - pos_head_dim: int, - dropout: float = 0.0, - pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), - ) -> None: - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.query_head_dim = query_head_dim - self.pos_head_dim = pos_head_dim - self.dropout = dropout - self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) - self.name = None # will be overwritten in training code; for diagnostics. 
-
- key_head_dim = query_head_dim
- in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
-
- # the initial_scale is supposed to take over the "scaling" factor of
- # head_dim ** -0.5 that has been used in previous forms of attention,
- # dividing it between the query and key. Note: this module is intended
- # to be used with the ScaledAdam optimizer; with most other optimizers,
- # it would be necessary to apply the scaling factor in the forward function.
- self.in_proj = ScaledLinear(
- embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
- )
-
- self.whiten_keys = Whiten(
- num_groups=num_heads,
- whitening_limit=_whitening_schedule(3.0),
- prob=(0.025, 0.25),
- grad_scale=0.025,
- )
-
- # add a balancer for the keys that runs with very small probability, and
- # tries to enforce that all dimensions have mean around zero. The
- # weights produced by this module are invariant to adding a constant to
- # the keys, so the derivative of the bias is mathematically zero; but
- # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
- # bias because the small numerical roundoff tends to have a non-random
- # sign. This module is intended to prevent that. Use a very small
- # probability; that should be sufficient to fix the problem.
- self.balance_keys = Balancer(
- key_head_dim * num_heads,
- channel_dim=-1,
- min_positive=0.4,
- max_positive=0.6,
- min_abs=0.0,
- max_abs=100.0,
- prob=0.025,
- )
-
- # linear transformation for positional encoding.
- self.linear_pos = ScaledLinear(
- pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
- )
-
- # the following are for diagnostics only, see --print-diagnostics option
- self.copy_pos_query = Identity()
- self.copy_query = Identity()
-
- def forward(
- self,
- x: Tensor,
- pos_emb: Tensor,
- key_padding_mask: Optional[Tensor] = None,
- attn_mask: Optional[Tensor] = None,
- ) -> Tensor:
- r"""
- Args:
- x: input of shape (seq_len, batch_size, embed_dim)
- pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
- key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
- are True in this mask will be ignored as sources in the attention weighting.
- attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
- interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
- saying which positions are allowed to attend to which other positions.
- Returns:
- a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
- interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
- """
- x = self.in_proj(x)
- query_head_dim = self.query_head_dim
- pos_head_dim = self.pos_head_dim
- num_heads = self.num_heads
-
- seq_len, batch_size, _ = x.shape
-
- query_dim = query_head_dim * num_heads
-
- # self-attention
- q = x[..., 0:query_dim]
- k = x[..., query_dim : 2 * query_dim]
- # p is the position-encoding query
- p = x[..., 2 * query_dim :]
- assert p.shape[-1] == num_heads * pos_head_dim
-
- q = self.copy_query(q) # for diagnostics only, does nothing.
- k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
- p = self.copy_pos_query(p) # for diagnostics only, does nothing.
-
- q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
- p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
- k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
-
- # time1 refers to target, time2 refers to source.
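- # After the permutes below, torch.matmul(q, k) yields attn_scores of shape (num_heads, batch_size, time1, time2).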
- q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - use_pos_scores = False - if torch.jit.is_scripting() or torch.jit.is_tracing(): - # We can't put random.random() in the same line - use_pos_scores = True - elif not self.training or random.random() >= float(self.pos_emb_skip_rate): - use_pos_scores = True - - if use_pos_scores: - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( - 2, 0, 3, 1 - ) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - if torch.jit.is_tracing(): - (num_heads, batch_size, time1, n) = pos_scores.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_scores = pos_scores.reshape(-1, n) - pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len) - else: - pos_scores = pos_scores.as_strided( - (num_heads, batch_size, seq_len, seq_len), - ( - pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2) - pos_scores.stride(3), - pos_scores.stride(3), - ), - storage_offset=pos_scores.stride(3) * (seq_len - 1), - ) - - attn_scores = attn_scores + pos_scores - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif self.training and random.random() < 0.1: - # This is a harder way of limiting the attention scores to not be - # too large. It incurs a penalty if any of them has an absolute - # value greater than 50.0. this should be outside the normal range - # of the attention scores. We use this mechanism instead of, say, - # something added to the loss function involving the entropy, - # because once the entropy gets very small gradients through the - # softmax can become very small, and we'd get zero derivatives. The - # choices of 1.0e-04 as the scale on the penalty makes this - # mechanism vulnerable to the absolute scale of the loss function, - # but we view this as a failsafe to avoid "implausible" parameter - # values rather than a regularization method that should be active - # under normal circumstances. - attn_scores = penalize_abs_values_gt( - attn_scores, limit=25.0, penalty=1.0e-04, name=self.name - ) - - assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) - - if attn_mask is not None: - assert attn_mask.dtype == torch.bool - # use -1000 to avoid nan's where attn_mask and key_padding_mask make - # all scores zero. It's important that this be large enough that exp(-1000) - # is exactly zero, for reasons related to const_attention_rate, it - # compares the final weights with zero. 
- attn_scores = attn_scores.masked_fill(attn_mask, -1000) - - if key_padding_mask is not None: - assert key_padding_mask.shape == ( - batch_size, - seq_len, - ), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - # We use our own version of softmax, defined in scaling.py, which should - # save a little of the memory used in backprop by, if we are in - # automatic mixed precision mode (amp / autocast), by only storing the - # half-precision output for backprop purposes. - attn_weights = softmax(attn_scores, dim=-1) - - if torch.jit.is_scripting() or torch.jit.is_tracing(): - pass - elif random.random() < 0.001 and not self.training: - self._print_attn_entropy(attn_weights) - - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) - - return attn_weights - - def streaming_forward( - self, - x: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - left_context_len: int, - key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x: input of shape (seq_len, batch_size, embed_dim) - pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim) - cached_key: cached attention key tensor of left context, - of shape (left_context_len, batch_size, key_dim) - left_context_len: number of left context frames. - key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that - are True in this mask will be ignored as sources in the attention weighting. - - Returns: - - attention weights, of shape (hum_heads, batch_size, seq_len, seq_len2), - interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). - - updated cached attention key tensor of left context. - """ - x = self.in_proj(x) - query_head_dim = self.query_head_dim - pos_head_dim = self.pos_head_dim - num_heads = self.num_heads - - seq_len, batch_size, _ = x.shape - - query_dim = query_head_dim * num_heads - - # self-attention - q = x[..., 0:query_dim] - k = x[..., query_dim : 2 * query_dim] - # p is the position-encoding query - p = x[..., 2 * query_dim :] - assert p.shape[-1] == num_heads * pos_head_dim - - # Pad cached left contexts - assert cached_key.shape[0] == left_context_len, ( - cached_key.shape[0], - left_context_len, - ) - k = torch.cat([cached_key, k], dim=0) - # Update cached left contexts - cached_key = k[-left_context_len:, ...] - - # The length of key - k_len = k.shape[0] - - q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) - k = k.reshape(k_len, batch_size, num_heads, query_head_dim) - - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) - - attn_scores = torch.matmul(q, k) - - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 + left_context_len - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( - 2, 0, 3, 1 - ) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] 
- pos_scores = torch.matmul(p, pos_emb) - - if torch.jit.is_tracing(): - (num_heads, batch_size, time1, n) = pos_scores.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(k_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_scores = pos_scores.reshape(-1, n) - pos_scores = torch.gather(pos_scores, dim=1, index=indexes) - pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - else: - pos_scores = pos_scores.as_strided( - (num_heads, batch_size, seq_len, k_len), - ( - pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2) - pos_scores.stride(3), - pos_scores.stride(3), - ), - storage_offset=pos_scores.stride(3) * (seq_len - 1), - ) - - attn_scores = attn_scores + pos_scores - - assert attn_scores.shape == ( - num_heads, - batch_size, - seq_len, - k_len, - ), attn_scores.shape - - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - - attn_weights = attn_scores.softmax(dim=-1) - - return attn_weights, cached_key - - def _print_attn_entropy(self, attn_weights: Tensor): - # attn_weights: (num_heads, batch_size, seq_len, seq_len) - (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .mean(dim=(1, 2)) - ) - logging.info( - f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}" - ) - - -class SelfAttention(nn.Module): - """ - The simplest possible attention module. This one works with already-computed attention - weights, e.g. as computed by RelPositionMultiheadAttentionWeights. - - Args: - embed_dim: the input and output embedding dimension - num_heads: the number of attention heads - value_head_dim: the value dimension per head - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - value_head_dim: int, - ) -> None: - super().__init__() - self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) - - self.out_proj = ScaledLinear( - num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - Returns: - a tensor with the same shape as x. 
- """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = ( - x.permute(2, 1, 0, 3) - .contiguous() - .view(seq_len, batch_size, num_heads * value_head_dim) - ) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. - x = self.out_proj(x) - x = self.whiten(x) - - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """ - Args: - x: input tensor, of shape (seq_len, batch_size, embed_dim) - attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len), - with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect - attn_weights.sum(dim=-1) == 1. - cached_val: cached attention value tensor of left context, - of shape (left_context_len, batch_size, value_dim) - left_context_len: number of left context frames. - - Returns: - - attention weighted output, a tensor with the same shape as x. - - updated cached attention value tensor of left context. - """ - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - seq_len2 = seq_len + left_context_len - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2) - - x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim) - - # Pad cached left contexts - assert cached_val.shape[0] == left_context_len, ( - cached_val.shape[0], - left_context_len, - ) - x = torch.cat([cached_val, x], dim=0) - # Update cached left contexts - cached_val = x[-left_context_len:, ...] - - x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, value_head_dim) - value_head_dim = x.shape[-1] - - # todo: see whether there is benefit in overriding matmul - x = torch.matmul(attn_weights, x) - # v: (num_heads, batch_size, seq_len, value_head_dim) - - x = ( - x.permute(2, 1, 0, 3) - .contiguous() - .view(seq_len, batch_size, num_heads * value_head_dim) - ) - - # returned value is of shape (seq_len, batch_size, embed_dim), like the input. 
- x = self.out_proj(x) - - return x, cached_val - - -class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer2 model.""" - - def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): - super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(embed_dim, feedforward_dim) - - self.hidden_balancer = Balancer( - feedforward_dim, - channel_dim=-1, - min_positive=0.3, - max_positive=1.0, - min_abs=0.75, - max_abs=5.0, - ) - - # shared_dim=0 means we share the dropout mask along the time axis - self.out_proj = ActivationDropoutAndLinear( - feedforward_dim, - embed_dim, - activation="SwooshL", - dropout_p=dropout, - dropout_shared_dim=0, - bias=True, - initial_scale=0.1, - ) - - self.out_whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward(self, x: Tensor): - x = self.in_proj(x) - x = self.hidden_balancer(x) - # out_proj contains SwooshL activation, then dropout, then linear. - x = self.out_proj(x) - x = self.out_whiten(x) - return x - - -class NonlinAttention(nn.Module): - """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed - from the attention module) in place of actual convolution. We also took out the second nonlinearity, the - one after the attention mechanism. - - Args: - channels (int): The number of channels of conv layers. - """ - - def __init__( - self, - channels: int, - hidden_channels: int, - ) -> None: - super().__init__() - - self.hidden_channels = hidden_channels - - self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) - - # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, - # because we noticed that well-trained instances of this module have abs-value before the sigmoid - # starting from about 3, and poorly-trained instances of the module have smaller abs values - # before the sigmoid. - self.balancer = Balancer( - hidden_channels, - channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), - max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), - min_abs=0.5, - max_abs=5.0, - ) - self.tanh = nn.Tanh() - - self.identity1 = Identity() # for diagnostics. - self.identity2 = Identity() # for diagnostics. - self.identity3 = Identity() # for diagnostics. - - self.out_proj = ScaledLinear( - hidden_channels, channels, bias=True, initial_scale=0.05 - ) - - self.whiten1 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.whiten2 = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(5.0, ratio=3.0), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - def forward( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - Returns: - a Tensor with the same shape as x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=2) - - # s will go through tanh. - - s = self.balancer(s) - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = self.whiten1(x) - x = x * s - x = self.identity1(x) # diagnostics only, it's the identity. 
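- # Below, x is mixed across the time axis with the supplied attention weights (an attention-weighted average over frames) and then gated elementwise by y.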
- - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = torch.matmul(attn_weights, x) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - y = self.identity2(y) - x = x * y - x = self.identity3(x) - - x = self.out_proj(x) - x = self.whiten2(x) - return x - - def streaming_forward( - self, - x: Tensor, - attn_weights: Tensor, - cached_x: Tensor, - left_context_len: int, - ) -> Tuple[Tensor, Tensor]: - """. - Args: - x: a Tensor of shape (seq_len, batch_size, num_channels) - attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) - cached_x: left context, a Tensor of shape - (num_heads, batch_size, left_context_len, head_dim) - left_context_len: number of left context frames. - Returns: - - a Tensor with the same shape as x - - updated left context with same shape as cached_x - """ - x = self.in_proj(x) - - (seq_len, batch_size, _) = x.shape - hidden_channels = self.hidden_channels - - s, x, y = x.chunk(3, dim=2) - - # s will go through tanh. - s = self.tanh(s) - - s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) - x = x * s - - (seq_len, batch_size, embed_dim) = x.shape - num_heads = attn_weights.shape[0] - assert attn_weights.shape == ( - num_heads, - batch_size, - seq_len, - left_context_len + seq_len, - ) - - x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) - # now x: (num_heads, batch_size, seq_len, head_dim) - - # Pad cached tensor - assert cached_x.shape[2] == left_context_len, ( - cached_x.shape[2], - left_context_len, - ) - x_pad = torch.cat([cached_x, x], dim=2) - # Update cached tensor - cached_x = x_pad[:, :, -left_context_len:, :] - - x = torch.matmul(attn_weights, x_pad) - # now x: (num_heads, batch_size, seq_len, head_dim) - x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) - - x = x * y - - x = self.out_proj(x) - return x, cached_x - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer2 model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__( - self, - channels: int, - kernel_size: int, - causal: bool, - ) -> None: - """Construct a ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - bottleneck_dim = channels - self.causal = causal - - self.in_proj = nn.Linear( - channels, - 2 * bottleneck_dim, - ) - # the gradients on in_proj are a little noisy, likely to do with the - # sigmoid in glu. - - # after in_proj we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. 
(for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.balancer1 = Balancer( - bottleneck_dim, - channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)), - max_positive=1.0, - min_abs=1.5, - max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0), - ) - - self.activation1 = Identity() # for diagnostics - - self.sigmoid = nn.Sigmoid() - - self.activation2 = Identity() # for diagnostics - - assert kernel_size % 2 == 1 - - self.depthwise_conv = ( - ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) - if causal - else nn.Conv1d( - in_channels=bottleneck_dim, - out_channels=bottleneck_dim, - groups=bottleneck_dim, - kernel_size=kernel_size, - padding=kernel_size // 2, - ) - ) - - self.balancer2 = Balancer( - bottleneck_dim, - channel_dim=1, - min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), - max_positive=1.0, - min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), - max_abs=10.0, - ) - - self.whiten = Whiten( - num_groups=1, - whitening_limit=_whitening_schedule(7.5), - prob=(0.025, 0.25), - grad_scale=0.01, - ) - - self.out_proj = ActivationDropoutAndLinear( - bottleneck_dim, - channels, - activation="SwooshR", - dropout_p=0.0, - initial_scale=0.05, - ) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - chunk_size: int = -1, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=2) - s = self.balancer1(s) - s = self.sigmoid(s) - x = self.activation1(x) # identity. - x = x * s - x = self.activation2(x) # identity - - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - if ( - not torch.jit.is_scripting() - and not torch.jit.is_tracing() - and chunk_size >= 0 - ): - # Not support exporting a model for simulated streaming decoding - assert ( - self.causal - ), "Must initialize model with causal=True if you use chunk_size" - x = self.depthwise_conv(x, chunk_size=chunk_size) - else: - x = self.depthwise_conv(x) - - x = self.balancer2(x) - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.whiten(x) # (time, batch, channels) - x = self.out_proj(x) # (time, batch, channels) - - return x - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - src_key_padding_mask: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module in streaming forward mode. - - Args: - x: Input tensor (#time, batch, channels). 
- cache: cached left context for depthwise_conv of shape - (#batch, channels, left_pad) - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains True in masked positions. - - Returns: - - Output tensor (#time, batch, channels). - - Updated cache (#batch, channels, left_pad) - """ - - x = self.in_proj(x) # (time, batch, 2*channels) - - x, s = x.chunk(2, dim=2) - s = self.sigmoid(s) - x = x * s - # (time, batch, channels) - - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - x, cache = self.depthwise_conv.streaming_forward(x, cache=cache) - - x = x.permute(2, 0, 1) # (time, batch, channels) - - x = self.out_proj(x) # (time, batch, channels) - - return x, cache - - -class ScalarMultiply(nn.Module): - def __init__(self, scale: float): - super().__init__() - self.scale = scale - - def forward(self, x): - return x * self.scale - - -def _test_zipformer_main(causal: bool = False): - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - - c = Zipformer2( - encoder_dim=(64, 96), - encoder_unmasked_dim=(48, 64), - num_heads=(4, 4), - causal=causal, - chunk_size=(4,) if causal else (-1,), - left_context_frames=(64,), - ) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = c( - torch.randn(seq_len, batch_size, 64), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f[0].sum().backward() - c.eval() - f = c( - torch.randn(seq_len, batch_size, 64), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_zipformer_main(False) - _test_zipformer_main(True) \ No newline at end of file -- Gitee