From 176cf1b165625e474ec2a3d9c07e37f24711e783 Mon Sep 17 00:00:00 2001
From: han_yifeng <hanyifeng2@huawei.com>
Date: Wed, 7 Feb 2024 19:10:44 +0800
Subject: [PATCH 1/2] Add non-streaming README and scripts for the Zipformer model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/Zipformer/README.md        |  221 ++
 .../built-in/audio/Zipformer/export-onnx.py   |  625 +++++
 .../icefall_pt/export_torch_aie_ts_dec.py     |   62 +
 .../icefall_pt/export_torch_aie_ts_enc.py     |   65 +
 .../icefall_pt/export_torch_aie_ts_join.py    |   67 +
 .../Zipformer/icefall_pt/model_pt_dec.py      |   50 +
 .../Zipformer/icefall_pt/model_pt_enc.py      |   51 +
 .../Zipformer/icefall_pt/model_pt_join.py     |   51 +
 .../audio/Zipformer/icefall_pt/pt_val_dec.py  |  202 ++
 .../audio/Zipformer/icefall_pt/pt_val_enc.py  |  202 ++
 .../audio/Zipformer/icefall_pt/pt_val_join.py |  203 ++
 .../audio/Zipformer/modify_decoder.py         |   25 +
 .../built-in/audio/Zipformer/zipformer.py     | 2438 +++++++++++++++++
 13 files changed, 4262 insertions(+)
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/README.md
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py

diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md b/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md
new file mode 100644
index 0000000000..7eee8c3e05
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md
@@ -0,0 +1,221 @@
+# Zipformer Streaming Model - Inference Guide
+
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+- [Inference Performance and Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+
+# Overview<a name="ZH-CN_TOPIC_0000001172161501"></a>
+
+(From the paper abstract) The Conformer has become the most popular encoder model for automatic speech recognition (ASR). It adds convolution modules to a Transformer to learn both local and global dependencies.
+In this work, we describe a faster, more memory-efficient, and better-performing transformer called Zipformer. The modeling changes include: 1) a U-Net-like encoder structure in which the middle stacks operate at lower frame rates;
+2) a reorganized block structure with more modules, in which attention weights are re-used for efficiency; 3) a modified form of LayerNorm, called BiasNorm, that allows retaining some length information; 4) new activation functions, SwooshR and SwooshL, that work better than Swish.
+We also propose a new optimizer, called ScaledAdam, which scales each update by the current scale of the corresponding tensor to keep the relative change roughly constant, and which also explicitly learns the parameter scale. It converges faster and performs better than Adam.
+Extensive experiments on the LibriSpeech, Aishell-1, and WenetSpeech datasets demonstrate the effectiveness of the proposed Zipformer compared with other state-of-the-art ASR models.
+
+
+# Inference Environment Setup \[All Versions\]<a name="ZH-CN_TOPIC_0000001126281702"></a>
+
+- This model requires the following dependencies.
+
+  **Table 1**  Version compatibility
+
+| Dependency            | Version     |
+|-----------------------|-------------|
+| CANN                  | 7.0.RC1     |
+| Python                | 3.9.11      |
+| torch                 | 2.0.1       |
+| Ascend-cann-torch-aie | -           |
+| Ascend-cann-aie       | -           |
+| Chip type             | Ascend310P3 |
+
+# Quick Start<a name="ZH-CN_TOPIC_0000001126281700"></a>
+
+## Environment Setup
+
+1. Install k2
+   1. x86 environment (NPU)
+    ```shell
+    wget https://huggingface.co/csukuangfj/k2/resolve/main/cpu/k2-1.24.4.dev20231220+cpu.torch2.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+    pip install k2-1.24.4.dev20231220+cpu.torch2.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+    ```
+   2. arm environment (NPU/GPU): build from source.
+    ```shell
+    git clone https://github.com/k2-fsa/k2.git
+    cd k2
+    export K2_MAKE_ARGS="-j6"
+    python3 setup.py install
+    ```
+    If the commands above fail, refer to [this link](https://k2-fsa.github.io/k2/installation/from_source.html).
+   3. x86 environment (GPU): download the wheel matching your CUDA version from [this link](https://k2-fsa.github.io/k2/cuda.html) and install it with pip.
+   4. Verify that k2 is installed correctly:
+    ```shell
+    python3 -m k2.version
+    ```
+2. Install other dependencies
+    ```shell
+    pip install lhotse
+    pip install kaldifeat
+    ```
+3. Install icefall
+    ```shell
+    git clone https://github.com/k2-fsa/icefall.git
+    cd icefall
+    git reset --hard e2fcb42f5f176d9e39eb38506ab99d0a3adaf202
+    pip install -r requirements.txt
+    ```
+4. Add icefall to the PYTHONPATH environment variable; replace "/path/to/icefall" with the actual path of the icefall directory.
+   **This step is important; otherwise an error that icefall cannot be found will be reported.**
+    ```shell
+    export PYTHONPATH=/path/to/icefall:$PYTHONPATH
+    ```
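+
+   As an optional sanity check (assuming the layout above), the following import should succeed and print the package location:
+    ```shell
+    python3 -c "import icefall; print(icefall.__file__)"
+    ```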
+
+## Model Download
+1. Install git lfs
+    ```shell
+    curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
+
+    sudo apt-get install git-lfs
+    git lfs install --skip-repo
+    ```
+2. Download the model
+    ```shell
+    git clone https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
+    ```
+   If the download fails, try downloading the files manually from the link above. Only the following files are needed for model conversion and inference:
+    - data/lang_char/tokens.txt
+    - exp/epoch-12.pt
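+
+   If the clone skipped the LFS payloads (for example when GIT_LFS_SKIP_SMUDGE=1 was set), one way to fetch only the required checkpoint is:
+    ```shell
+    cd icefall-asr-zipformer-streaming-wenetspeech-20230615
+    git lfs pull --include "exp/epoch-12.pt"
+    cd ..
+    ```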
+
+## Model Inference
+1. Replace code
+   Replace egs/librispeech/ASR/zipformer/zipformer.py with the zipformer.py in this directory,
+   and replace egs/librispeech/ASR/zipformer/export-onnx.py with the export-onnx.py in this directory.
+
+2. Copy the icefall_pt directory of this repository into the project root directory ./icefall.
+
+3. Export the ONNX models (used for the accuracy test).
+    ```shell
+    cd icefall/egs/librispeech/ASR/zipformer
+    # Note: replace "icefall-asr-zipformer-streaming-wenetspeech-20230615" with the actual path.
+    # The following parameters are currently not used.
+    python ./export-onnx.py \
+      --tokens icefall-asr-zipformer-streaming-wenetspeech-20230615/data/lang_char/tokens.txt \
+      --use-averaged-model 0 \
+      --epoch 12 \
+      --avg 1 \
+      --exp-dir icefall-asr-zipformer-streaming-wenetspeech-20230615/exp \
+      --num-encoder-layers "2,2,3,4,3,2" \
+      --downsampling-factor "1,2,4,8,4,2" \
+      --feedforward-dim "512,768,1024,1536,1024,768" \
+      --num-heads "4,4,4,8,4,4" \
+      --encoder-dim "192,256,384,512,384,256" \
+      --query-head-dim 32 \
+      --value-head-dim 12 \
+      --pos-head-dim 4 \
+      --pos-dim 48 \
+      --encoder-unmasked-dim "192,192,256,256,256,192" \
+      --cnn-module-kernel "31,31,15,15,15,31" \
+      --decoder-dim 512 \
+      --joiner-dim 512 \
+      --causal True \
+      --chunk-size 16 \
+      --left-context-frames 128
+    ```
+   After the command finishes, three ONNX files and three TorchScript files are generated in "icefall-asr-zipformer-streaming-wenetspeech-20230615/exp":
+    - encoder-epoch-12-avg-1.onnx
+    - decoder-epoch-12-avg-1.onnx
+    - joiner-epoch-12-avg-1.onnx
+    - encoder-epoch-12-avg-1.pt
+    - decoder-epoch-12-avg-1.pt
+    - joiner-epoch-12-avg-1.pt
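+
+   As an optional check (the path below assumes the output location above), each traced TorchScript file should load cleanly:
+    ```shell
+    python3 -c "import torch; m = torch.jit.load('icefall-asr-zipformer-streaming-wenetspeech-20230615/exp/encoder-epoch-12-avg-1.pt'); print(type(m))"
+    ```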
+4. Compile the TorchScript models.
+   ```shell
+   cd icefall/icefall_pt
+
+   # Note: replace "icefall-asr-zipformer-streaming-wenetspeech-20230615" with the actual path.
+   python ./export_torch_aie_ts_enc.py
+   python ./export_torch_aie_ts_dec.py
+   python ./export_torch_aie_ts_join.py
+   ```
+   After the commands finish, three compiled TorchScript files are generated in "icefall_pt/pt_compiled_model":
+    - encoder-epoch-12-avg-1_torch_aie_bs1.pt
+    - decoder-epoch-12-avg-1_torch_aie_bs1.pt
+    - joiner-epoch-12-avg-1_torch_aie_bs1.pt
+
+5. Run the inference sample
+   1. Download the sample speech data (the test currently uses all-zero inputs)
+      ```shell
+      cd icefall/egs/librispeech/ASR/zipformer
+      wget https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
+      ```
+   2. Run inference
+      ```shell
+      # Note: replace "icefall-asr-zipformer-streaming-wenetspeech-20230615" with the actual path.
+      cd icefall/icefall_pt
+      python ./pt_val_enc.py
+      python ./pt_val_dec.py
+      python ./pt_val_join.py
+      ```
+      After the commands finish, the inference results of the three models appear under icefall_pt/result,
+      and the performance of the compiled (PT) models is printed at the same time.
+
+      
+6. Accuracy test
+   In pretrained.py, change the encoder, decoder, and joiner inputs to all zeros (encoder: (1,100,80) plus a constant length of 100; decoder: (1,2); joiner: (1,512) and (1,512)) and save the three outputs:
+   ```shell
+   cd icefall/egs/librispeech/ASR/zipformer
+   python ./pretrained.py
+   ```
+   Then compare the cosine similarity of each of the three outputs with the corresponding outputs of the compiled models. The results should be:
+   ```shell
+   encoder: 0.97+
+   decoder: 0.99+
+   joiner:  0.99+
+   ```
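+
+   A minimal sketch of one such comparison (the file names are placeholders for wherever the reference output from pretrained.py and the compiled-model output from pt_val_enc.py were saved as whitespace-separated text):
+   ```python
+   import numpy as np
+
+   # Placeholder paths; adjust to the actual saved outputs.
+   a = np.loadtxt("encoder_out_reference.txt").ravel()
+   b = np.loadtxt("encoder_out_compiled.txt").ravel()
+   print("cosine similarity:", np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
+   ```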
+7. Performance test
+   1. AIE model performance test
+      The latency and throughput are printed when running pt_val_enc.py, pt_val_dec.py, and pt_val_join.py (see step 5).
+   2. ONNX model performance test.
+      1. (Optional) If a GPU is used, make sure CUDA and the GPU build of PyTorch are installed, and also install onnxruntime-gpu:
+      ```shell
+      pip install onnxruntime-gpu
+      ```
+      When the ONNX models are benchmarked, the performance of the three models is printed to the console, for example (a minimal measurement sketch follows the sample output below):
+      ```shell
+      Encoder latency: 964.9530 ms
+      Encoder throughput: 1.0363 fps
+      Decoder latency: 0.4806 ms
+      Decoder throughput: 2080.7143 fps
+      Joiner latency: 0.4994 ms
+      Joiner throughput: 2002.3092 fps
+      ```
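+
+      This patch does not include a dedicated ONNX benchmarking script; the sketch below shows one way such numbers could be measured with onnxruntime (the model path, input names, and loop counts follow the export settings above and are illustrative):
+      ```python
+      import time
+      import numpy as np
+      import onnxruntime as ort
+
+      sess = ort.InferenceSession(
+          "icefall-asr-zipformer-streaming-wenetspeech-20230615/exp/encoder-epoch-12-avg-1.onnx",
+          providers=["CPUExecutionProvider"],  # use ["CUDAExecutionProvider"] on GPU
+      )
+      x = np.zeros((1, 100, 80), dtype=np.float32)
+      x_lens = np.array([100], dtype=np.int64)
+
+      for _ in range(5):  # warm-up
+          sess.run(None, {"x": x, "x_lens": x_lens})
+
+      n = 20
+      start = time.time()
+      for _ in range(n):
+          sess.run(None, {"x": x, "x_lens": x_lens})
+      latency_ms = (time.time() - start) / n * 1000
+      print(f"Encoder latency: {latency_ms:.4f} ms, throughput: {1000 / latency_ms:.4f} fps")
+      ```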
+   3. OM model performance test:
+      Starting from the original ONNX models, simplify the encoder and decoder with onnxsim, run modify_decoder.py, and then benchmark the OM models with ais_bench:
+      ```shell
+      onnxsim encoder-epoch-12-avg-1.onnx encoder-epoch-12-avg-1_sim.onnx
+      onnxsim decoder-epoch-12-avg-1.onnx decoder-epoch-12-avg-1_sim.onnx
+      python modify_decoder.py
+      python3 -m ais_bench --model encoder_linux_aarch64.om --loop=2000
+      python3 -m ais_bench --model decoder.om --loop=2000
+      python3 -m ais_bench --model joiner.om --loop=2000
+      ```
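+
+      The .om files used above are assumed to have been converted from the ONNX models beforehand; one possible conversion with ATC is sketched below (the decoder model name depends on what modify_decoder.py produces, so all model names here are illustrative):
+      ```shell
+      # Illustrative ATC conversion; output names chosen to match the ais_bench calls above.
+      atc --framework=5 --model=encoder-epoch-12-avg-1_sim.onnx --output=encoder_linux_aarch64 --soc_version=Ascend310P3
+      atc --framework=5 --model=decoder-epoch-12-avg-1_sim.onnx --output=decoder --soc_version=Ascend310P3
+      atc --framework=5 --model=joiner-epoch-12-avg-1.onnx --output=joiner --soc_version=Ascend310P3
+      ```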
+# Inference Performance and Accuracy<a name="ZH-CN_TOPIC_0000001172201573"></a>
+
+The streaming Zipformer model consists of three sub-models, the encoder, the decoder, and the joiner. Their performance is shown in the table below:
+
+| Model      | PT plugin on 310P (latency / throughput) | T4 (latency / throughput) | A10 (latency / throughput) |
+|------------|------------------------------------------|---------------------------|----------------------------|
+| encoder    | 20.4 ms / 49 fps                         | 24.7 ms / 40 fps          | 19 ms / 52 fps             |
+| decoder    | 0.19 ms / 5156 fps                       | 0.59 ms / 1684 fps        | 0.13 ms / 7604 fps         |
+| joiner     | 0.22 ms / 4448 fps                       | 0.13 ms / 7645 fps        | 0.11 ms / 9224 fps         |
+| End-to-end | 20.81 ms / 48 fps                        | 25.42 ms / 39 fps         | 19.24 ms / 52 fps          |
+
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py
new file mode 100644
index 0000000000..805f7405c8
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py
@@ -0,0 +1,625 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
+# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
+
+"""
+This script exports a transducer model from PyTorch to ONNX.
+
+We use the pre-trained model from
+https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+as an example to show how to use this file.
+
+1. Download the pre-trained model
+
+cd egs/librispeech/ASR
+
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+repo=$(basename $repo_url)
+
+pushd $repo
+git lfs pull --include "exp/pretrained.pt"
+
+cd exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+2. Export the model to ONNX
+
+python3 ./egs/librispeech/ASR/zipformer/export-onnx.py \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  --use-averaged-model 0 \
+  --epoch 99 \
+  --avg 1 \
+  --exp-dir $repo/exp \
+  --num-encoder-layers "2,2,3,4,3,2" \
+  --downsampling-factor "1,2,4,8,4,2" \
+  --feedforward-dim "512,768,1024,1536,1024,768" \
+  --num-heads "4,4,4,8,4,4" \
+  --encoder-dim "192,256,384,512,384,256" \
+  --query-head-dim 32 \
+  --value-head-dim 12 \
+  --pos-head-dim 4 \
+  --pos-dim 48 \
+  --encoder-unmasked-dim "192,192,256,256,256,192" \
+  --cnn-module-kernel "31,31,15,15,15,31" \
+  --decoder-dim 512 \
+  --joiner-dim 512 \
+  --causal False \
+  --chunk-size "16,32,64,-1" \
+  --left-context-frames "64,128,256,-1"
+
+It will generate the following 3 files inside $repo/exp:
+
+  - encoder-epoch-99-avg-1.onnx
+  - decoder-epoch-99-avg-1.onnx
+  - joiner-epoch-99-avg-1.onnx
+
+See ./onnx_pretrained.py and ./onnx_check.py for how to
+use the exported ONNX models.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict, Tuple
+
+import k2
+import onnx
+import torch
+import torch.nn as nn
+from decoder import Decoder
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_model, get_params
+from zipformer import Zipformer2
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import make_pad_mask, num_tokens, str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=12,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=1,
+        help="Number of checkpoints to average. Automatically select "
+             "consecutive checkpoints before the checkpoint specified by "
+             "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=False,
+        help="Whether to load averaged model. Currently it only supports "
+             "using --epoch. If True, it would decode with the averaged model "
+             "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+             "Actually only the models with epoch number of `epoch-avg` and "
+             "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="/home/devkit/hanyifeng/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        default="/home/devkit/hanyifeng/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/data/lang_char/tokens.txt",
+        help="Path to the tokens.txt",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+    """Add meta data to an ONNX model. It is changed in-place.
+
+    Args:
+      filename:
+        Filename of the ONNX model to be changed.
+      meta_data:
+        Key-value pairs.
+    """
+    model = onnx.load(filename)
+    for key, value in meta_data.items():
+        meta = model.metadata_props.add()
+        meta.key = key
+        meta.value = value
+
+    onnx.save(model, filename)
+
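+# For reference, the stored metadata can be read back with, for example:
+#   m = onnx.load("encoder-epoch-12-avg-1.onnx")
+#   print({p.key: p.value for p in m.metadata_props})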
+
+class OnnxEncoder(nn.Module):
+    """A wrapper for Zipformer and the encoder_proj from the joiner"""
+
+    def __init__(
+            self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
+    ):
+        """
+        Args:
+          encoder:
+            A Zipformer encoder.
+          encoder_proj:
+            The projection layer for encoder from the joiner.
+        """
+        super().__init__()
+        self.encoder = encoder
+        self.encoder_embed = encoder_embed
+        self.encoder_proj = encoder_proj
+
+    def forward(
+            self,
+            x: torch.Tensor,
+            x_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Please see the help information of Zipformer.forward
+
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C)
+          x_lens:
+            A 1-D tensor of shape (N,). Its dtype is torch.int64
+        Returns:
+          Return a tuple containing:
+            - encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
+            - encoder_out_lens, A 1-D tensor of shape (N,)
+        """
+        x, x_lens = self.encoder_embed(x, x_lens)
+        src_key_padding_mask = make_pad_mask(x_lens)
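+        # (N, T, C) -> (T, N, C): the Zipformer encoder expects time-major input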
+        x = x.permute(1, 0, 2)
+        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
+        encoder_out = encoder_out.permute(1, 0, 2)
+        encoder_out = self.encoder_proj(encoder_out)
+        # Now encoder_out is of shape (N, T, joiner_dim)
+
+        return encoder_out, encoder_out_lens
+
+
+class OnnxDecoder(nn.Module):
+    """A wrapper for Decoder and the decoder_proj from the joiner"""
+
+    def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
+        super().__init__()
+        self.decoder = decoder
+        self.decoder_proj = decoder_proj
+
+    def forward(self, y: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          y:
+            A 2-D tensor of shape (N, context_size).
+        Returns
+          Return a 2-D tensor of shape (N, joiner_dim)
+        """
+        need_pad = False
+        decoder_output = self.decoder(y, need_pad=need_pad)
+        decoder_output = decoder_output.squeeze(1)
+        output = self.decoder_proj(decoder_output)
+
+        return output
+
+
+class OnnxJoiner(nn.Module):
+    """A wrapper for the joiner"""
+
+    def __init__(self, output_linear: nn.Linear):
+        super().__init__()
+        self.output_linear = output_linear
+
+    def forward(
+            self,
+            encoder_out: torch.Tensor,
+            decoder_out: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+          encoder_out:
+            A 2-D tensor of shape (N, joiner_dim)
+          decoder_out:
+            A 2-D tensor of shape (N, joiner_dim)
+        Returns:
+          Return a 2-D tensor of shape (N, vocab_size)
+        """
+        logit = encoder_out + decoder_out
+        logit = self.output_linear(torch.tanh(logit))
+        return logit
+
+
+def export_encoder_model_onnx(
+        encoder_model: OnnxEncoder,
+        encoder_filename: str,
+        opset_version: int = 11,
+) -> None:
+    """Export the given encoder model to ONNX format.
+    The exported model has two inputs:
+
+        - x, a tensor of shape (N, T, C); dtype is torch.float32
+        - x_lens, a tensor of shape (N,); dtype is torch.int64
+
+    and it has two outputs:
+
+        - encoder_out, a tensor of shape (N, T', joiner_dim)
+        - encoder_out_lens, a tensor of shape (N,)
+
+    Args:
+      encoder_model:
+        The input encoder model
+      encoder_filename:
+        The filename to save the exported ONNX model.
+      opset_version:
+        The opset version to use.
+    """
+    x = torch.zeros(1, 100, 80, dtype=torch.float32)
+    x_lens = torch.tensor([100], dtype=torch.int64)
+
+    encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
+    encoder_model.save(str(encoder_filename).replace("onnx", "pt"))
+    torch.onnx.export(
+        encoder_model,
+        (x, x_lens),
+        encoder_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["x", "x_lens"],
+        output_names=["encoder_out", "encoder_out_lens"],
+        dynamic_axes={
+            "x": {0: "N", 1: "T"},
+            "x_lens": {0: "N"},
+            "encoder_out": {0: "N", 1: "T"},
+            "encoder_out_lens": {0: "N"},
+        },
+    )
+
+    meta_data = {
+        "model_type": "zipformer2",
+        "version": "1",
+        "model_author": "k2-fsa",
+        "comment": "non-streaming zipformer2",
+    }
+    logging.info(f"meta_data: {meta_data}")
+
+    add_meta_data(filename=encoder_filename, meta_data=meta_data)
+
+
+def export_decoder_model_onnx(
+        decoder_model: OnnxDecoder,
+        decoder_filename: str,
+        opset_version: int = 11,
+) -> None:
+    """Export the decoder model to ONNX format.
+
+    The exported model has one input:
+
+        - y: a torch.int64 tensor of shape (N, decoder_model.context_size)
+
+    and has one output:
+
+        - decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
+
+    Args:
+      decoder_model:
+        The decoder model to be exported.
+      decoder_filename:
+        Filename to save the exported ONNX model.
+      opset_version:
+        The opset version to use.
+    """
+    context_size = decoder_model.decoder.context_size
+    vocab_size = decoder_model.decoder.vocab_size
+
+    y = torch.zeros(10, context_size, dtype=torch.int64)
+    ts_decoder_model = torch.jit.trace(decoder_model, y)
+    ts_decoder_model.save(str(decoder_filename).replace("onnx", "pt"))
+    decoder_model = torch.jit.script(decoder_model)
+
+    torch.onnx.export(
+        decoder_model,
+        y,
+        decoder_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["y"],
+        output_names=["decoder_out"],
+        dynamic_axes={
+            "y": {0: "N"},
+            "decoder_out": {0: "N"},
+        },
+    )
+
+    meta_data = {
+        "context_size": str(context_size),
+        "vocab_size": str(vocab_size),
+    }
+    add_meta_data(filename=decoder_filename, meta_data=meta_data)
+
+
+def export_joiner_model_onnx(
+        joiner_model: nn.Module,
+        joiner_filename: str,
+        opset_version: int = 11,
+) -> None:
+    """Export the joiner model to ONNX format.
+    The exported joiner model has two inputs:
+
+        - encoder_out: a tensor of shape (N, joiner_dim)
+        - decoder_out: a tensor of shape (N, joiner_dim)
+
+    and produces one output:
+
+        - logit: a tensor of shape (N, vocab_size)
+    """
+    joiner_dim = joiner_model.output_linear.weight.shape[1]
+    logging.info(f"joiner dim: {joiner_dim}")
+
+    projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+    projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+    ts_joiner_model = torch.jit.trace(joiner_model, (projected_encoder_out, projected_decoder_out))
+    ts_joiner_model.save(str(joiner_filename).replace("onnx", "pt"))
+
+    torch.onnx.export(
+        joiner_model,
+        (projected_encoder_out, projected_decoder_out),
+        joiner_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=[
+            "encoder_out",
+            "decoder_out",
+        ],
+        output_names=["logit"],
+        dynamic_axes={
+            "encoder_out": {0: "N"},
+            "decoder_out": {0: "N"},
+            "logit": {0: "N"},
+        },
+    )
+    meta_data = {
+        "joiner_dim": str(joiner_dim),
+    }
+    add_meta_data(filename=joiner_filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    model.to(device)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                        : params.avg
+                        ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                        : params.avg + 1
+                        ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
+
+    encoder = OnnxEncoder(
+        encoder=model.encoder,
+        encoder_embed=model.encoder_embed,
+        encoder_proj=model.joiner.encoder_proj,
+    )
+
+    decoder = OnnxDecoder(
+        decoder=model.decoder,
+        decoder_proj=model.joiner.decoder_proj,
+    )
+
+    joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
+
+    encoder_num_param = sum([p.numel() for p in encoder.parameters()])
+    decoder_num_param = sum([p.numel() for p in decoder.parameters()])
+    joiner_num_param = sum([p.numel() for p in joiner.parameters()])
+    total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
+    logging.info(f"encoder parameters: {encoder_num_param}")
+    logging.info(f"decoder parameters: {decoder_num_param}")
+    logging.info(f"joiner parameters: {joiner_num_param}")
+    logging.info(f"total parameters: {total_num_param}")
+
+    if params.iter > 0:
+        suffix = f"iter-{params.iter}"
+    else:
+        suffix = f"epoch-{params.epoch}"
+
+    suffix += f"-avg-{params.avg}"
+
+    opset_version = 13
+
+    logging.info("Exporting encoder")
+    encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
+    export_encoder_model_onnx(
+        encoder,
+        encoder_filename,
+        opset_version=opset_version,
+    )
+    logging.info(f"Exported encoder to {encoder_filename}")
+
+    logging.info("Exporting decoder")
+    decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
+    export_decoder_model_onnx(
+        decoder,
+        decoder_filename,
+        opset_version=opset_version,
+    )
+    logging.info(f"Exported decoder to {decoder_filename}")
+
+    logging.info("Exporting joiner")
+    joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
+    export_joiner_model_onnx(
+        joiner,
+        joiner_filename,
+        opset_version=opset_version,
+    )
+    logging.info(f"Exported joiner to {joiner_filename}")
+
+    # Generate int8 quantization models
+    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
+
+    logging.info("Generate int8 quantization models")
+
+    encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
+    quantize_dynamic(
+        model_input=encoder_filename,
+        model_output=encoder_filename_int8,
+        op_types_to_quantize=["MatMul"],
+        weight_type=QuantType.QInt8,
+    )
+
+    decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
+    quantize_dynamic(
+        model_input=decoder_filename,
+        model_output=decoder_filename_int8,
+        op_types_to_quantize=["MatMul", "Gather"],
+        weight_type=QuantType.QInt8,
+    )
+
+    joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
+    quantize_dynamic(
+        model_input=joiner_filename,
+        model_output=joiner_filename_int8,
+        op_types_to_quantize=["MatMul"],
+        weight_type=QuantType.QInt8,
+    )
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py
new file mode 100644
index 0000000000..bb341fca41
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py
@@ -0,0 +1,62 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+
+import torch
+import torch_aie
+from torch_aie import _enums
+
+def export_torch_aie(opt):
+    trace_model = torch.jit.load(opt.torch_script_path)
+    trace_model.eval()
+
+    torch_aie.set_device(0)
+    inputs = []
+    # inputs.append(torch_aie.Input([10, 2], dtype = torch_aie.dtype.INT64))
+    inputs.append(torch_aie.Input([opt.batch_size, 2], dtype = torch_aie.dtype.INT64))
+
+    torchaie_model = torch_aie.compile(
+        trace_model,
+        inputs=inputs,
+        precision_policy=_enums.PrecisionPolicy.FP16,
+        truncate_long_and_double=True,
+        require_full_compilation=False,
+        allow_tensor_replace_int=False,
+        min_block_size=3,
+        torch_executed_ops=[],
+        soc_version=opt.soc_version,
+        optimization_level=0)
+    suffix = os.path.splitext(opt.torch_script_path)[-1]
+    saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt.batch_size}" + suffix
+    torchaie_model.save(os.path.join(opt.save_path, saved_name))
+    print("torch aie tdnn compiled done. saved model is ", os.path.join(opt.save_path, saved_name))
+
+
+def parse_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--torch_script_path', type=str, default='../egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/decoder-epoch-12-avg-1.pt', help='trace model path')
+    parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
+    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
+    parser.add_argument('--save_path', type=str, default='./pt_compiled_model/', help='compiled model path')
+    opt = parser.parse_args()
+    return opt
+
+def main(opt):
+    export_torch_aie(opt)
+
+if __name__ == '__main__':
+    opt = parse_opt()
+    main(opt)
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py
new file mode 100644
index 0000000000..05c9e64e06
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py
@@ -0,0 +1,65 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+
+import torch
+import torch_aie
+from torch_aie import _enums
+
+def export_torch_aie(opt):
+    trace_model = torch.jit.load(opt.torch_script_path)
+    trace_model.eval()
+
+    torch_aie.set_device(0)
+    inputs = []
+    # x = torch.zeros(1, 100, 80, dtype=torch.float32)
+    # x_lens = torch.tensor([100], dtype=torch.int64)
+    inputs.append(torch_aie.Input([1, 100, 80], dtype = torch_aie.dtype.FLOAT))
+    inputs.append(torch_aie.Input([1], dtype = torch_aie.dtype.INT64))
+
+    torchaie_model = torch_aie.compile(
+        trace_model,
+        inputs=inputs,
+        precision_policy=_enums.PrecisionPolicy.FP16,
+        # truncate_long_and_double=True,
+        # require_full_compilation=False,
+        # allow_tensor_replace_int=False,
+        # min_block_size=3,
+        # torch_executed_ops=[],
+        soc_version=opt.soc_version,
+        optimization_level=0
+        )
+    suffix = os.path.splitext(opt.torch_script_path)[-1]
+    saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt.batch_size}" + suffix
+    torchaie_model.save(os.path.join(opt.save_path, saved_name))
+    print("torch aie tdnn compiled done. saved model is ", os.path.join(opt.save_path, saved_name))
+
+
+def parse_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--torch_script_path', type=str, default='../egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/encoder-epoch-12-avg-1.pt', help='trace model path')
+    parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
+    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
+    parser.add_argument('--save_path', type=str, default='./pt_compiled_model/', help='compiled model path')
+    opt = parser.parse_args()
+    return opt
+
+def main(opt):
+    export_torch_aie(opt)
+
+if __name__ == '__main__':
+    opt = parse_opt()
+    main(opt)
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py
new file mode 100644
index 0000000000..ae0b1e5b47
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py
@@ -0,0 +1,67 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+
+import torch
+import torch_aie
+from torch_aie import _enums
+
+def export_torch_aie(opt):
+    trace_model = torch.jit.load(opt.torch_script_path)
+    trace_model.eval()
+
+    torch_aie.set_device(0)
+    inputs = []
+    # inputs.append(torch_aie.Input([11, 512], dtype = torch_aie.dtype.FLOAT))
+    # inputs.append(torch_aie.Input([11, 512], dtype = torch_aie.dtype.FLOAT))
+    inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT))
+    inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT))
+    # projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+    # projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+    # inputs.append(torch_aie.Input([10], dtype = torch_aie.dtype.INT64))
+
+    torchaie_model = torch_aie.compile(
+        trace_model,
+        inputs=inputs,
+        precision_policy=_enums.PrecisionPolicy.FP16,
+        truncate_long_and_double=True,
+        require_full_compilation=False,
+        allow_tensor_replace_int=False,
+        min_block_size=3,
+        torch_executed_ops=[],
+        soc_version=opt.soc_version,
+        optimization_level=0)
+    suffix = os.path.splitext(opt.torch_script_path)[-1]
+    saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt.batch_size}" + suffix
+    torchaie_model.save(os.path.join(opt.save_path, saved_name))
+    print("torch aie tdnn compiled done. saved model is ", os.path.join(opt.save_path, saved_name))
+
+
+def parse_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--torch_script_path', type=str, default='../egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/joiner-epoch-12-avg-1.pt', help='trace model path')
+    parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
+    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
+    parser.add_argument('--save_path', type=str, default='./pt_compiled_model/', help='compiled model path')
+    opt = parser.parse_args()
+    return opt
+
+def main(opt):
+    export_torch_aie(opt)
+
+if __name__ == '__main__':
+    opt = parse_opt()
+    main(opt)
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py
new file mode 100644
index 0000000000..d8934b5f29
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py
@@ -0,0 +1,50 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from tqdm import tqdm
+
+import torch
+import torch_aie
+
+
+def forward_infer(model, dataloader, batchsize, device_id):
+    pred_results = []
+    inference_time = []
+    loop_num = 0
+    for snd in tqdm(dataloader):
+        result, inference_time = pt_infer(model, snd[0].to(torch.int64), device_id, loop_num, inference_time)
+        pred_results.append(result)
+        loop_num += 1
+
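+    # Average per-sample latency in ms; pt_infer skips the first 5 iterations as warm-up.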
+    avg_inf_time = sum(inference_time) / len(inference_time) / batchsize * 1000
+    print('performance(ms):', avg_inf_time)
+    print("throughput(fps): ", 1000 / avg_inf_time)
+
+    return pred_results
+
+def pt_infer(model, input_li_1, device_id, loop_num, inference_time):
+
+    input_npu_li_1 = input_li_1.to("npu:" + str(device_id))
+    stream = torch_aie.npu.Stream("npu:" + str(device_id))
+    with torch_aie.npu.stream(stream):
+        inf_start = time.time()
+        output_npu = model.forward(input_npu_li_1)
+        stream.synchronize()
+        inf_end = time.time()
+        inf = inf_end - inf_start
+        if loop_num >= 5:   # skip the first 5 iterations as warm-up
+            inference_time.append(inf)
+    results = [output_npu.to("cpu")]
+    return results, inference_time
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py
new file mode 100644
index 0000000000..61396050b6
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py
@@ -0,0 +1,51 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from tqdm import tqdm
+import numpy as np
+import torch
+import torch_aie
+
+
+def forward_infer(model, dataloader, batchsize, device_id):
+    pred_results = []
+    inference_time = []
+    loop_num = 0
+    for snd in tqdm(dataloader):
+        result, inference_time = pt_infer(model, snd[0].to(torch.float32), snd[1].to(torch.int64), device_id, loop_num, inference_time)
+        pred_results.append(result)
+        loop_num += 1
+
+    avg_inf_time = sum(inference_time) / len(inference_time) / batchsize * 1000
+    print('performance(ms):', avg_inf_time)
+    print("throughput(fps): ", 1000 / avg_inf_time)
+
+    return pred_results
+
+def pt_infer(model, input_li_1, input_li_2, device_id, loop_num, inference_time):
+
+    input_npu_li_1 = input_li_1.to("npu:" + str(device_id))
+    input_npu_li_2 = input_li_2.to("npu:" + str(device_id))
+    stream = torch_aie.npu.Stream("npu:" + str(device_id))
+    with torch_aie.npu.stream(stream):
+        inf_start = time.time()
+        output_npu = model.forward(input_npu_li_1, input_npu_li_2)
+        stream.synchronize()
+        inf_end = time.time()
+        inf = inf_end - inf_start
+        if loop_num >= 5:   # skip the first 5 iterations as warm-up
+            inference_time.append(inf)
+    results = [output_npu[0].to("cpu"), output_npu[1].to("cpu")]
+    return results, inference_time
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py
new file mode 100644
index 0000000000..45b8ec0f93
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py
@@ -0,0 +1,51 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from tqdm import tqdm
+
+import torch
+import torch_aie
+
+
+def forward_infer(model, dataloader, batchsize, device_id):
+    pred_results = []
+    inference_time = []
+    loop_num = 0
+    for snd in tqdm(dataloader):
+        result, inference_time = pt_infer(model, snd[0].to(torch.float32), snd[1].to(torch.float32), device_id, loop_num, inference_time)
+        pred_results.append(result)
+        loop_num += 1
+
+    avg_inf_time = sum(inference_time) / len(inference_time) / batchsize * 1000
+    print('performance(ms):', avg_inf_time)
+    print("throughput(fps): ", 1000 / avg_inf_time)
+
+    return pred_results
+
+def pt_infer(model, input_li_1, input_li_2, device_id, loop_num, inference_time):
+
+    input_npu_li_1 = input_li_1.to("npu:" + str(device_id))
+    input_npu_li_2 = input_li_2.to("npu:" + str(device_id))
+    stream = torch_aie.npu.Stream("npu:" + str(device_id))
+    with torch_aie.npu.stream(stream):
+        inf_start = time.time()
+        output_npu = model.forward(input_npu_li_1, input_npu_li_2)
+        stream.synchronize()
+        inf_end = time.time()
+        inf = inf_end - inf_start
+        if loop_num >= 5:   # skip the first 5 iterations as warm-up
+            inference_time.append(inf)
+    results = [output_npu.to("cpu")]
+    return results, inference_time
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py
new file mode 100644
index 0000000000..3106753822
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py
@@ -0,0 +1,202 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import copy
+import argparse
+
+import torch
+import torch_aie
+import numpy as np
+from torch_aie import _enums
+from torch.utils.data import dataloader
+
+from model_pt_dec import forward_infer
+
+
+class InfiniteDataLoader(dataloader.DataLoader):
+    """ Dataloader that reuses workers
+
+    Uses same syntax as vanilla DataLoader
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
+        self.iterator = super().__iter__()
+
+    def __len__(self):
+        return len(self.batch_sampler.sampler)
+
+    def __iter__(self):
+        for i in range(len(self)):
+            yield next(self.iterator)
+
+
+class _RepeatSampler:
+    """ Sampler that repeats forever
+
+    Args:
+        sampler (Sampler)
+    """
+
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __iter__(self):
+        while True:
+            yield from iter(self.sampler)
+
+# def collate_fn(batch):
+#     """
+#     data preprocessing
+#     """
+#     def func(p):
+#         """
+#         data size
+#         """
+#         return p[0].size(1)
+
+#     batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
+#     longest_sample = max(batch, key=func)[0]
+#     freq_size = longest_sample.size(0)
+#     minibatch_size = len(batch)
+#     max_seqlength = longest_sample.size(1)
+#     inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength)
+#     input_percentages = torch.FloatTensor(minibatch_size)
+#     for x in range(minibatch_size):
+#         sample = batch[x]
+#         tensor = sample[0]
+#         seq_length = tensor.size(1)
+#         inputs[x][0].narrow(1, 0, seq_length).copy_(tensor)
+#         input_percentages[x] = seq_length / float(max_seqlength)
+#     return inputs, input_percentages, []
+
+
+def get_dataloader(opt):
+    # with open(to_absolute_path(opt.label_file)) as label_file:
+    #     labels = json.load(label_file)
+
+    # dataset = SpectrogramDataset(
+    #     audio_conf=DataConfig.spect,
+    #     input_path=opt.data_file,
+    #     labels=labels,
+    #     normalize=True,
+    #     aug_cfg=DataConfig.augmentation
+    # )
+    # inputs, input_percentages, _ = collate_fn(dataset)
+    # input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
+    # print(inputs[0])
+    # print(input_sizes.tolist())
+
+    # datasets = [[inputs[i], input_sizes[i]] for i in range(len(input_sizes))]
+    x = torch.zeros(2, dtype=torch.int64)
+    # x_lens = torch.tensor([100], dtype=torch.int64)
+    # x_lens = 100
+    datasets = []
+    for i in range(20):
+        datasets.append([copy.deepcopy(x)])
+    # print(datasets)
+    while len(datasets) % opt.batch_size != 0:
+        datasets.append(datasets[-1])
+    m = 1
+    datasets_orig = copy.deepcopy(datasets)
+    while m < opt.multi:
+        datasets += datasets_orig
+        m += 1
+
+    loader =  InfiniteDataLoader  # only DataLoader allows for attribute updates
+    print("OPT_BATCHSIZE: ", opt.batch_size)
+    return loader(datasets,
+                  batch_size=opt.batch_size,
+                  shuffle=False,
+                  num_workers=1,
+                  sampler=None,
+                  pin_memory=True)
+
+
+def save_tensor_arr_to_file(arr, file_path):
+    write_sen = ""
+    for m in arr:
+        for l in m:
+            for c in l:
+                write_sen += str(c) + " "
+            write_sen += "\n"
+    with open(file_path, "w", encoding='utf-8') as f:
+        f.write(write_sen)
+
+def save_size_to_file(size, file_path):
+    write_sen = "" + str(size) + " "
+    with open(file_path, "w", encoding='utf-8') as f:
+        f.write(write_sen)
+
+def main(opt):
+    # load model
+    model = torch.jit.load(opt.model)
+    batch_size = opt.batch_size
+    torch_aie.set_device(opt.device_id)
+    if opt.need_compile:
+        inputs = []
+        inputs.append(torch_aie.Input([opt.batch_size, 2], dtype=torch_aie.dtype.INT64))
+        # inputs.append(torch_aie.Input([opt.batch_size], dtype=torch_aie.dtype.INT32))
+
+        model = torch_aie.compile(
+            model,
+            inputs=inputs,
+            precision_policy=_enums.PrecisionPolicy.FP16,
+            truncate_long_and_double=True,
+            require_full_compilation=False,
+            allow_tensor_replace_int=False,
+            min_block_size=3,
+            torch_executed_ops=[],
+            soc_version=opt.soc_version,
+            optimization_level=0)
+
+    dataloader = get_dataloader(opt)
+    pred_results = forward_infer(model, dataloader, batch_size, opt.device_id)
+    # for index, res in enumerate(pred_results):
+    #     print(index, "  ", res)
+    #     print(res[0].shape)
+
+    if opt.batch_size == 1 and opt.multi == 1:
+        result_path = opt.result_path
+        if not os.path.exists(result_path):
+            os.makedirs(result_path)
+        for index, res in enumerate(pred_results):
+            # for i in range(batch_size):
+            # result_fname_0 = 'data' + str(index * batch_size + i + 1) + '_0.txt'
+            # result_fname_1 = 'data' + str(index * batch_size + i + 1) + '_1.txt'
+            result_fname_0 = 'data' + str(index) + '_0.txt'
+            # result_fname_1 = 'data' + str(index) + '_1.txt'
+            # res = np.array(res)
+            # save_tensor_arr_to_file(np.array(res[0][i]), os.path.join(result_path, result_fname_0))
+            # save_size_to_file(res[1].numpy()[i], os.path.join(result_path, result_fname_1))
+            save_tensor_arr_to_file(np.array(res), os.path.join(result_path, result_fname_0))
+            # save_size_to_file(res[1].numpy()[0], os.path.join(result_path, result_fname_1))
+
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Zipformer decoder offline model inference.')
+    parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
+    parser.add_argument('--model', type=str, default="./pt_compiled_model/decoder-epoch-12-avg-1_torch_aie_bs1.pt", help='ts model path')
+    parser.add_argument('--need_compile', action="store_true", help='if the loaded model needs to be compiled or not')
+    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
+    parser.add_argument('--device_id', type=int, default=0, help='device id')
+    parser.add_argument('--data_file', default='./deepspeech.pytorch/data/an4_test_manifest.json')
+    parser.add_argument('--label_file', default='./deepspeech.pytorch/labels.json')
+    parser.add_argument('--result_path', default='result/decoder')
+    parser.add_argument('--multi', type=int, default=1, help='multiples of dataset replication for enough infer loop. if multi != 1, the pred result will not be stored.')
+    opt = parser.parse_args()
+    main(opt)
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py
new file mode 100644
index 0000000000..496e767fe0
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py
@@ -0,0 +1,202 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import copy
+import argparse
+
+import torch
+import torch_aie
+import numpy as np
+from torch_aie import _enums
+from torch.utils.data import dataloader
+
+from model_pt_enc import forward_infer
+
+
+class InfiniteDataLoader(dataloader.DataLoader):
+    """ Dataloader that reuses workers
+
+    Uses same syntax as vanilla DataLoader
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
+        self.iterator = super().__iter__()
+
+    def __len__(self):
+        return len(self.batch_sampler.sampler)
+
+    def __iter__(self):
+        for i in range(len(self)):
+            yield next(self.iterator)
+
+
+class _RepeatSampler:
+    """ Sampler that repeats forever
+
+    Args:
+        sampler (Sampler)
+    """
+
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __iter__(self):
+        while True:
+            yield from iter(self.sampler)
+
+# def collate_fn(batch):
+#     """
+#     data preprocessing
+#     """
+#     def func(p):
+#         """
+#         data size
+#         """
+#         return p[0].size(1)
+
+#     batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
+#     longest_sample = max(batch, key=func)[0]
+#     freq_size = longest_sample.size(0)
+#     minibatch_size = len(batch)
+#     max_seqlength = longest_sample.size(1)
+#     inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength)
+#     input_percentages = torch.FloatTensor(minibatch_size)
+#     for x in range(minibatch_size):
+#         sample = batch[x]
+#         tensor = sample[0]
+#         seq_length = tensor.size(1)
+#         inputs[x][0].narrow(1, 0, seq_length).copy_(tensor)
+#         input_percentages[x] = seq_length / float(max_seqlength)
+#     return inputs, input_percentages, []
+
+
+def get_dataloader(opt):
+    # with open(to_absolute_path(opt.label_file)) as label_file:
+    #     labels = json.load(label_file)
+
+    # dataset = SpectrogramDataset(
+    #     audio_conf=DataConfig.spect,
+    #     input_path=opt.data_file,
+    #     labels=labels,
+    #     normalize=True,
+    #     aug_cfg=DataConfig.augmentation
+    # )
+    # inputs, input_percentages, _ = collate_fn(dataset)
+    # input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
+    # print(inputs[0])
+    # print(input_sizes.tolist())
+
+    # datasets = [[inputs[i], input_sizes[i]] for i in range(len(input_sizes))]
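+    # The SpectrogramDataset pipeline above is left disabled; the encoder is
+    # exercised with dummy inputs instead: 20 copies of an all-zero (100, 80)
+    # feature tensor, each with a fixed valid length of 100 frames.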
+    x = torch.zeros(100, 80, dtype=torch.float32)
+    # x_lens = torch.tensor([100], dtype=torch.int64)
+    x_lens = 100
+    datasets = []
+    for i in range(20):
+        datasets.append([copy.deepcopy(x), copy.deepcopy(x_lens)])
+    # print(datasets)
+    while len(datasets) % opt.batch_size != 0:
+        datasets.append(datasets[-1])
+    m = 1
+    datasets_orig = copy.deepcopy(datasets)
+    while m < opt.multi:
+        datasets += datasets_orig
+        m += 1
+
+    loader = InfiniteDataLoader  # reuses dataloader workers across iterations
+    print("OPT_BATCHSIZE: ", opt.batch_size)
+    return loader(datasets,
+                  batch_size=opt.batch_size,
+                  shuffle=False,
+                  num_workers=1,
+                  sampler=None,
+                  pin_memory=True)
+
+
+def save_tensor_arr_to_file(arr, file_path):
+    write_sen = ""
+    for m in arr:
+        for l in m:
+            for c in l:
+                write_sen += str(c) + " "
+            write_sen += "\n"
+    with open(file_path, "w", encoding='utf-8') as f:
+        f.write(write_sen)
+
+def save_size_to_file(size, file_path):
+    write_sen = "" + str(size) + " "
+    with open(file_path, "w", encoding='utf-8') as f:
+        f.write(write_sen)
+
+def main(opt):
+    # load model
+    model = torch.jit.load(opt.model)
+    batch_size = opt.batch_size
+    torch_aie.set_device(opt.device_id)
+    if opt.need_compile:
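+        # Compile the loaded TorchScript encoder with torch_aie using static input
+        # shapes: FP32 features of shape (batch_size, 100, 80) and an INT32 length
+        # vector of shape (batch_size,).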
+        inputs = []
+        inputs.append(torch_aie.Input([opt.batch_size, 100, 80], dtype=torch_aie.dtype.FLOAT))
+        inputs.append(torch_aie.Input([opt.batch_size], dtype=torch_aie.dtype.INT32))
+
+        model = torch_aie.compile(
+            model,
+            inputs=inputs,
+            precision_policy=_enums.PrecisionPolicy.FP16,
+            truncate_long_and_double=True,
+            require_full_compilation=False,
+            allow_tensor_replace_int=False,
+            min_block_size=3,
+            torch_executed_ops=[],
+            soc_version='Ascend310P3',
+            optimization_level=0)
+
+    dataloader = get_dataloader(opt)
+    pred_results = forward_infer(model, dataloader, batch_size, opt.device_id)
+    for index, res in enumerate(pred_results):
+        print(index, "  ", res)
+        print(res[0].shape)
+
+    if opt.batch_size == 1 and opt.multi == 1:
+        result_path = opt.result_path
+        if not os.path.exists(result_path):
+            os.makedirs(result_path)
+        for index, res in enumerate(pred_results):
+            # for i in range(batch_size):
+            # result_fname_0 = 'data' + str(index * batch_size + i + 1) + '_0.txt'
+            # result_fname_1 = 'data' + str(index * batch_size + i + 1) + '_1.txt'
+            result_fname_0 = 'data' + str(index) + '_0.txt'
+            result_fname_1 = 'data' + str(index) + '_1.txt'
+            # res = np.array(res)
+            # save_tensor_arr_to_file(np.array(res[0][i]), os.path.join(result_path, result_fname_0))
+            # save_size_to_file(res[1].numpy()[i], os.path.join(result_path, result_fname_1))
+            save_tensor_arr_to_file(np.array(res[0]), os.path.join(result_path, result_fname_0))
+            save_size_to_file(res[1].numpy()[0], os.path.join(result_path, result_fname_1))
+
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Zipformer encoder offline model inference.')
+    parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
+    parser.add_argument('--model', type=str, default="./pt_compiled_model/encoder-epoch-12-avg-1_torch_aie_bs1.pt", help='ts model path')
+    parser.add_argument('--need_compile', action="store_true", help='if the loaded model needs to be compiled or not')
+    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
+    parser.add_argument('--device_id', type=int, default=0, help='device id')
+    parser.add_argument('--data_file', default='./deepspeech.pytorch/data/an4_test_manifest.json')
+    parser.add_argument('--label_file', default='./deepspeech.pytorch/labels.json')
+    parser.add_argument('--result_path', default='result/encoder')
+    parser.add_argument('--multi', type=int, default=1, help='number of times to replicate the dataset so there are enough inference iterations; if multi != 1, prediction results are not saved.')
+    opt = parser.parse_args()
+    main(opt)
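+
+# Example invocation (illustrative; adjust paths to your environment):
+#   python3 pt_val_enc.py --model ./pt_compiled_model/encoder-epoch-12-avg-1_torch_aie_bs1.pt --batch_size 1 --device_id 0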
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py
new file mode 100644
index 0000000000..a4fe80f4de
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py
@@ -0,0 +1,203 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import copy
+import argparse
+
+import torch
+import torch_aie
+import numpy as np
+from torch_aie import _enums
+from torch.utils.data import dataloader
+
+from model_pt_join import forward_infer
+
+
+class InfiniteDataLoader(dataloader.DataLoader):
+    """ Dataloader that reuses workers
+
+    Uses same syntax as vanilla DataLoader
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
+        self.iterator = super().__iter__()
+
+    def __len__(self):
+        return len(self.batch_sampler.sampler)
+
+    def __iter__(self):
+        for i in range(len(self)):
+            yield next(self.iterator)
+
+
+class _RepeatSampler:
+    """ Sampler that repeats forever
+
+    Args:
+        sampler (Sampler)
+    """
+
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __iter__(self):
+        while True:
+            yield from iter(self.sampler)
+
+# def collate_fn(batch):
+#     """
+#     data preprocessing
+#     """
+#     def func(p):
+#         """
+#         data size
+#         """
+#         return p[0].size(1)
+
+#     batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
+#     longest_sample = max(batch, key=func)[0]
+#     freq_size = longest_sample.size(0)
+#     minibatch_size = len(batch)
+#     max_seqlength = longest_sample.size(1)
+#     inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength)
+#     input_percentages = torch.FloatTensor(minibatch_size)
+#     for x in range(minibatch_size):
+#         sample = batch[x]
+#         tensor = sample[0]
+#         seq_length = tensor.size(1)
+#         inputs[x][0].narrow(1, 0, seq_length).copy_(tensor)
+#         input_percentages[x] = seq_length / float(max_seqlength)
+#     return inputs, input_percentages, []
+
+
+def get_dataloader(opt):
+    # with open(to_absolute_path(opt.label_file)) as label_file:
+    #     labels = json.load(label_file)
+
+    # dataset = SpectrogramDataset(
+    #     audio_conf=DataConfig.spect,
+    #     input_path=opt.data_file,
+    #     labels=labels,
+    #     normalize=True,
+    #     aug_cfg=DataConfig.augmentation
+    # )
+    # inputs, input_percentages, _ = collate_fn(dataset)
+    # input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
+    # print(inputs[0])
+    # print(input_sizes.tolist())
+
+    # datasets = [[inputs[i], input_sizes[i]] for i in range(len(input_sizes))]
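+    # Dummy joiner inputs: 20 pairs of all-zero 512-dim vectors standing in for
+    # the (projected) encoder output and decoder output fed to the joiner.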
+    x = torch.zeros(512, dtype=torch.float32)
+    y = torch.zeros(512, dtype=torch.float32)
+    # x_lens = torch.tensor([100], dtype=torch.int64)
+    # x_lens = 100
+    datasets = []
+    for i in range(20):
+        datasets.append([copy.deepcopy(x), copy.deepcopy(y)])
+    # print(datasets)
+    while len(datasets) % opt.batch_size != 0:
+        datasets.append(datasets[-1])
+    m = 1
+    datasets_orig = copy.deepcopy(datasets)
+    while m < opt.multi:
+        datasets += datasets_orig
+        m += 1
+
+    loader = InfiniteDataLoader  # reuses dataloader workers across iterations
+    print("OPT_BATCHSIZE: ", opt.batch_size)
+    return loader(datasets,
+                  batch_size=opt.batch_size,
+                  shuffle=False,
+                  num_workers=1,
+                  sampler=None,
+                  pin_memory=True)
+
+
+def save_tensor_arr_to_file(arr, file_path):
+    write_sen = ""
+    for m in arr:
+        for l in m:
+            for c in l:
+                write_sen += str(c) + " "
+            write_sen += "\n"
+    with open(file_path, "w", encoding='utf-8') as f:
+        f.write(write_sen)
+
+def save_size_to_file(size, file_path):
+    write_sen = "" + str(size) + " "
+    with open(file_path, "w", encoding='utf-8') as f:
+        f.write(write_sen)
+
+def main(opt):
+    # load model
+    model = torch.jit.load(opt.model)
+    batch_size = opt.batch_size
+    torch_aie.set_device(opt.device_id)
+    if opt.need_compile:
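+        # Compile the loaded TorchScript joiner with torch_aie using two static
+        # FP32 inputs of shape (batch_size, 512), presumably the projected encoder
+        # and decoder outputs consumed by the joiner.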
+        inputs = []
+        inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT))
+        inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT))
+
+        model = torch_aie.compile(
+            model,
+            inputs=inputs,
+            precision_policy=_enums.PrecisionPolicy.FP16,
+            truncate_long_and_double=True,
+            require_full_compilation=False,
+            allow_tensor_replace_int=False,
+            min_block_size=3,
+            torch_executed_ops=[],
+            soc_version='Ascend310P3',
+            optimization_level=0)
+
+    dataloader = get_dataloader(opt)
+    pred_results = forward_infer(model, dataloader, batch_size, opt.device_id)
+    # for index, res in enumerate(pred_results):
+    #     print(index, "  ", res)
+    #     print(res[0].shape)
+
+    if opt.batch_size == 1 and opt.multi == 1:
+        result_path = opt.result_path
+        if not os.path.exists(result_path):
+            os.makedirs(result_path)
+        for index, res in enumerate(pred_results):
+            # for i in range(batch_size):
+            # result_fname_0 = 'data' + str(index * batch_size + i + 1) + '_0.txt'
+            # result_fname_1 = 'data' + str(index * batch_size + i + 1) + '_1.txt'
+            result_fname_0 = 'data' + str(index) + '_0.txt'
+            # result_fname_1 = 'data' + str(index) + '_1.txt'
+            # res = np.array(res)
+            # save_tensor_arr_to_file(np.array(res[0][i]), os.path.join(result_path, result_fname_0))
+            # save_size_to_file(res[1].numpy()[i], os.path.join(result_path, result_fname_1))
+            save_tensor_arr_to_file(np.array(res), os.path.join(result_path, result_fname_0))
+            # save_size_to_file(res[1].numpy()[0], os.path.join(result_path, result_fname_1))
+
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Zipformer joiner offline model inference.')
+    parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
+    parser.add_argument('--model', type=str, default="./pt_compiled_model/joiner-epoch-12-avg-1_torch_aie_bs1.pt", help='ts model path')
+    parser.add_argument('--need_compile', action="store_true", help='if the loaded model needs to be compiled or not')
+    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
+    parser.add_argument('--device_id', type=int, default=0, help='device id')
+    parser.add_argument('--data_file', default='./deepspeech.pytorch/data/an4_test_manifest.json')
+    parser.add_argument('--label_file', default='./deepspeech.pytorch/labels.json')
+    parser.add_argument('--result_path', default='result/joiner')
+    parser.add_argument('--multi', type=int, default=1, help='number of times to replicate the dataset so there are enough inference iterations; if multi != 1, prediction results are not saved.')
+    opt = parser.parse_args()
+    main(opt)
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py
new file mode 100644
index 0000000000..866216820f
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py
@@ -0,0 +1,25 @@
+from argparse import ArgumentParser
+
+from auto_optimizer import OnnxGraph
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument("--onnx", type=str, required=True)
+    args = parser.parse_args()
+
+    graph = OnnxGraph.parse(args.onnx)
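+    # Remove the Clip node in front of the decoder embedding and feed the graph
+    # input "y" directly into the embedding Gather's index input, then refresh
+    # the node map and re-infer shapes.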
+    graph.remove("/decoder/Clip")
+    gather = graph["/decoder/embedding/Gather"]
+    gather.inputs[1] = "y"
+    graph.update_map()
+    graph.infershape()
+
+    g_sim = graph.simplify()
+    save_path = args.onnx.replace(".onnx", "_modified.onnx")
+    g_sim.save(save_path)
+    print("Modified model saved to ", save_path)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py
new file mode 100644
index 0000000000..163eb5f6b2
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py
@@ -0,0 +1,2438 @@
+#!/usr/bin/env python3
+# Copyright    2022-2023  Xiaomi Corp.        (authors: Daniel Povey,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import logging
+import math
+import random
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+from encoder_interface import EncoderInterface
+from scaling import (
+    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
+)
+from scaling import (
+    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+)
+from scaling import (
+    ActivationDropoutAndLinear,
+    Balancer,
+    BiasNorm,
+    ChunkCausalDepthwiseConv1d,
+    Dropout2,
+    FloatLike,
+    ScheduledFloat,
+    Whiten,
+    convert_num_channels,
+    limit_param_value,
+    penalize_abs_values_gt,
+    softmax,
+)
+from torch import Tensor, nn
+
+
+class Zipformer2(EncoderInterface):
+    """
+    Args:
+
+    Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length
+    as downsampling_factor if they are single ints or one-element tuples.  The length of
+    downsampling_factor defines the number of stacks.
+
+        output_downsampling_factor (int): how much to downsample at the output.  Note:
+            we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
+            You should probably leave this at 2.
+        downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
+           Note: this is in addition to the downsampling factor of 2 that is applied in
+           the frontend (self.encoder_embed).
+        encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
+           encoder stack.
+        num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
+        encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
+            the encoder stacks for purposes of per-frame dropout (recommend 256 for
+            now).
+        query_head_dim (int or Tuple[int]): dimension of query and key per attention
+           head; per stack if given as a tuple.
+        pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per
+           attention head
+        value_head_dim (int or Tuple[int]): dimension of value in each attention head
+        num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
+              Must be at least 4.
+        feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
+        cnn_module_kernel (int or Tuple[int]): Kernel size of convolution module
+
+        pos_dim (int): the dimension of each positional-encoding vector prior to projection,
+            e.g. 128.
+
+        dropout (float): dropout rate
+        warmup_batches (float): number of batches to warm up over; this controls
+          dropout of encoder layers.
+        causal (bool): if True, support chunkwise causal convolution.  This should
+          not hurt WER as no modeling power is lost, but the convolution modules will be
+          slightly slower and use more memory.  Enables use of the chunk_size and
+          left_context_chunks options in forward(), which simulates streaming
+          decoding.
+        chunk_size: (list of int): only set this to other than [-1] if causal;
+           the chunk size will be randomly chosen from this list.  -1 means no chunking.
+        left_context_frames: (list of int): determines the number of left-
+           context chunks for causal training; will be rounded to a number of
+           chunks.  Must not be less than cnn_module_kernel (after factoring in
+           rounding and downsampling); an error will be thrown if this is violated.
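+
+    Example (an illustrative sketch only; these hyper-parameters are assumptions,
+    not the values used by any particular recipe)::
+
+        >>> encoder = Zipformer2(
+        ...     downsampling_factor=(2, 4),
+        ...     encoder_dim=(256, 384),
+        ...     num_encoder_layers=(2, 2),
+        ...     num_heads=(4, 8),
+        ... )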
+    """
+
+    def __init__(
+        self,
+        output_downsampling_factor: int = 2,
+        downsampling_factor: Tuple[int] = (2, 4),
+        encoder_dim: Union[int, Tuple[int]] = 384,
+        num_encoder_layers: Union[int, Tuple[int]] = 4,
+        encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
+        query_head_dim: Union[int, Tuple[int]] = 24,
+        pos_head_dim: Union[int, Tuple[int]] = 4,
+        value_head_dim: Union[int, Tuple[int]] = 12,
+        num_heads: Union[int, Tuple[int]] = 8,
+        feedforward_dim: Union[int, Tuple[int]] = 1536,
+        cnn_module_kernel: Union[int, Tuple[int]] = 31,
+        pos_dim: int = 192,
+        dropout: FloatLike = None,  # see code below for default
+        warmup_batches: float = 4000.0,
+        causal: bool = False,
+        chunk_size: Tuple[int] = [-1],
+        left_context_frames: Tuple[int] = [-1],
+    ) -> None:
+        super(Zipformer2, self).__init__()
+
+        if dropout is None:
+            dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
+
+        def _to_tuple(x):
+            """Converts a single int or a 1-tuple of an int to a tuple with the same length
+            as downsampling_factor"""
+            if isinstance(x, int):
+                x = (x,)
+            if len(x) == 1:
+                x = x * len(downsampling_factor)
+            else:
+                assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
+            return x
+
+        self.output_downsampling_factor = output_downsampling_factor  # int
+        self.downsampling_factor = downsampling_factor  # tuple
+        self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
+        self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
+            encoder_unmasked_dim
+        )  # tuple
+        num_encoder_layers = _to_tuple(num_encoder_layers)
+        self.num_encoder_layers = num_encoder_layers
+        self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
+        self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
+        pos_head_dim = _to_tuple(pos_head_dim)
+        self.num_heads = num_heads = _to_tuple(num_heads)
+        feedforward_dim = _to_tuple(feedforward_dim)
+        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
+
+        self.causal = causal
+        self.chunk_size = chunk_size
+        self.left_context_frames = left_context_frames
+
+        for u, d in zip(encoder_unmasked_dim, encoder_dim):
+            assert u <= d
+
+        # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
+        encoders = []
+
+        num_encoders = len(downsampling_factor)
+        for i in range(num_encoders):
+            encoder_layer = Zipformer2EncoderLayer(
+                embed_dim=encoder_dim[i],
+                pos_dim=pos_dim,
+                num_heads=num_heads[i],
+                query_head_dim=query_head_dim[i],
+                pos_head_dim=pos_head_dim[i],
+                value_head_dim=value_head_dim[i],
+                feedforward_dim=feedforward_dim[i],
+                dropout=dropout,
+                cnn_module_kernel=cnn_module_kernel[i],
+                causal=causal,
+            )
+
+            # For the segment of the warmup period, we let the Conv2dSubsampling
+            # layer learn something.  Then we start to warm up the other encoders.
+            encoder = Zipformer2Encoder(
+                encoder_layer,
+                num_encoder_layers[i],
+                pos_dim=pos_dim,
+                dropout=dropout,
+                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+            )
+
+            if downsampling_factor[i] != 1:
+                encoder = DownsampledZipformer2Encoder(
+                    encoder,
+                    dim=encoder_dim[i],
+                    downsample=downsampling_factor[i],
+                    dropout=dropout,
+                )
+
+            encoders.append(encoder)
+
+        self.encoders = nn.ModuleList(encoders)
+
+        self.downsample_output = SimpleDownsample(
+            max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout
+        )
+
+    def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
+        """
+        In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
+        randomized feature masks, one per encoder.
+        On e.g. 15% of frames, these masks will zero out all encoder dims larger than
+        some supplied number, e.g. >256, so in effect on those frames we are using
+        a smaller encoder dim.
+
+        We generate the random masks at this level because we want the 2 masks to 'agree'
+        all the way up the encoder stack. This will mean that the 1st mask will have
+        mask values repeated self.zipformer_subsampling_factor times.
+
+        Args:
+           x: the embeddings (needed for the shape and dtype and device), of shape
+             (1, batch_size, encoder_dims0)
+        """
+        num_encoders = len(self.encoder_dim)
+        if not self.training:
+            return [1.0] * num_encoders
+
+        (num_frames0, batch_size, _encoder_dims0) = x.shape
+
+        assert self.encoder_dim[0] == _encoder_dims0, (
+            self.encoder_dim[0],
+            _encoder_dims0,
+        )
+
+        feature_mask_dropout_prob = 0.125
+
+        # mask1 shape: (1, batch_size, 1)
+        mask1 = (
+            torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob
+        ).to(x.dtype)
+
+        # mask2 has additional sequences masked, about twice the number.
+        mask2 = torch.logical_and(
+            mask1,
+            (
+                torch.rand(1, batch_size, 1, device=x.device)
+                > feature_mask_dropout_prob
+            ).to(x.dtype),
+        )
+
+        # dim: (1, batch_size, 2)
+        mask = torch.cat((mask1, mask2), dim=-1)
+
+        feature_masks = []
+        for i in range(num_encoders):
+            channels = self.encoder_dim[i]
+            feature_mask = torch.ones(
+                1, batch_size, channels, dtype=x.dtype, device=x.device
+            )
+            u1 = self.encoder_unmasked_dim[i]
+            u2 = u1 + (channels - u1) // 2
+
+            feature_mask[:, :, u1:u2] *= mask[..., 0:1]
+            feature_mask[:, :, u2:] *= mask[..., 1:2]
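+            # e.g. with channels=384 and encoder_unmasked_dim=256: u1=256, u2=320,
+            # so dims 256..319 follow mask1 and dims 320..383 follow the stricter mask2.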
+
+            feature_masks.append(feature_mask)
+
+        return feature_masks
+
+    def get_chunk_info(self) -> Tuple[int, int]:
+        """
+        Returns chunk_size and left_context_chunks.
+        """
+        if not self.causal:
+            return -1, -1
+
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            assert len(self.chunk_size) == 1, self.chunk_size
+            chunk_size = self.chunk_size[0]
+        else:
+            chunk_size = random.choice(self.chunk_size)
+
+        if chunk_size == -1:
+            left_context_chunks = -1
+        else:
+            if torch.jit.is_scripting() or torch.jit.is_tracing():
+                assert len(self.left_context_frames) == 1, self.left_context_frames
+                left_context_frames = self.left_context_frames[0]
+            else:
+                left_context_frames = random.choice(self.left_context_frames)
+            # Note: in Python, -1 // n == -1 for n > 0
+            left_context_chunks = left_context_frames // chunk_size
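+            # e.g. left_context_frames=128 with chunk_size=32 gives 4 left-context chunks.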
+            if left_context_chunks == 0:
+                left_context_chunks = 1
+
+        return chunk_size, left_context_chunks
+
+    def forward(
+        self,
+        x: Tensor,
+        x_lens: Tensor,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
+          x_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `x` before padding.
+          src_key_padding_mask:
+            The mask for padding, of shape (batch_size, seq_len); True means
+            masked position. May be None.
+        Returns:
+          Return a tuple containing 2 tensors:
+            - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
+            - lengths, a tensor of shape (batch_size,) containing the number
+              of frames in `embeddings` before padding.
+        """
+        outputs = []
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            feature_masks = [1.0] * len(self.encoder_dim)
+        else:
+            feature_masks = self.get_feature_masks(x)
+
+        chunk_size, left_context_chunks = self.get_chunk_info()
+
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            # Simulating streaming decoding is not supported in an exported model
+            attn_mask = None
+        else:
+            attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
+
+        for i, module in enumerate(self.encoders):
+            ds = self.downsampling_factor[i]
+            x = convert_num_channels(x, self.encoder_dim[i])
+
+            x = module(
+                x,
+                chunk_size=chunk_size,
+                feature_mask=feature_masks[i],
+                src_key_padding_mask=(
+                    None
+                    if src_key_padding_mask is None
+                    else src_key_padding_mask[..., ::ds]
+                ),
+                attn_mask=attn_mask,
+            )
+            outputs.append(x)
+
+        # if the last output has the largest dimension, x will be unchanged,
+        # it will be the same as outputs[-1].  Otherwise it will be concatenated
+        # from different pieces of 'outputs', taking each dimension from the
+        # most recent output that has it present.
+        x = self._get_full_dim_output(outputs)
+        x = self.downsample_output(x)
+        # the SimpleDownsample class has this rounding behavior.
+        assert self.output_downsampling_factor == 2, self.output_downsampling_factor
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            lengths = (x_lens + 1) // 2
+        else:
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                lengths = (x_lens + 1) // 2
+
+        return x, lengths
+
+    def _get_attn_mask(
+        self, x: Tensor, chunk_size: int, left_context_chunks: int
+    ) -> Optional[Tensor]:
+        """
+        Return None if chunk_size == -1, else return attention mask of shape
+          (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len).  True
+           means a masked position.
+        Args:
+           x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
+          chunk_size: chunk size; must be a multiple of all the downsampling factors.
+        """
+        if chunk_size <= 0:
+            return None
+        assert all(chunk_size % d == 0 for d in self.downsampling_factor)
+        if left_context_chunks >= 0:
+            num_encoders = len(self.encoder_dim)
+            assert all(
+                chunk_size * left_context_chunks
+                >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
+                for i in range(num_encoders)
+            )
+        else:
+            left_context_chunks = 1000000
+
+        seq_len = x.shape[0]
+
+        # t is frame index, shape (seq_len,)
+        t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
+        # c is chunk index for each frame, shape (seq_len,)
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            c = t // chunk_size
+        else:
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                c = t // chunk_size
+        src_c = c
+        tgt_c = c.unsqueeze(-1)
+
+        attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
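+        # Example: with chunk_size=2 and left_context_chunks=1, frames 0..5 fall in
+        # chunks [0, 0, 1, 1, 2, 2]; a frame in chunk 2 may attend to chunks 1 and 2
+        # (frames 2..5), while chunk 0 (frames 0..1) is masked out.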
+        if __name__ == "__main__":
+            logging.info(f"attn_mask = {attn_mask}")
+        return attn_mask
+
+    def _get_full_dim_output(self, outputs: List[Tensor]):
+        num_encoders = len(self.encoder_dim)
+        assert len(outputs) == num_encoders
+        output_dim = max(self.encoder_dim)
+        output_pieces = [outputs[-1]]
+        cur_dim = self.encoder_dim[-1]
+        for i in range(num_encoders - 2, -1, -1):
+            d = self.encoder_dim[i]
+            if d > cur_dim:
+                this_output = outputs[i]
+                output_pieces.append(this_output[..., cur_dim:d])
+                cur_dim = d
+        assert cur_dim == output_dim
+        return torch.cat(output_pieces, dim=-1)
+
+    def streaming_forward(
+        self,
+        x: Tensor,
+        x_lens: Tensor,
+        states: List[Tensor],
+        src_key_padding_mask: Tensor,
+    ) -> Tuple[Tensor, Tensor, List[Tensor]]:
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
+          x_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `x` before padding.
+          states: list of cached tensors of all encoder layers. For layer-i,
+            states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
+            cached_conv1, cached_conv2).
+          src_key_padding_mask:
+            The mask for padding, of shape (batch_size, seq_len); True means
+            masked position. May be None.
+        Returns:
+          Return a tuple containing 3 outputs:
+            - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
+            - lengths, a tensor of shape (batch_size,) containing the number
+              of frames in `embeddings` before padding.
+            - updated states
+        """
+        outputs = []
+        new_states = []
+        layer_offset = 0
+
+        for i, module in enumerate(self.encoders):
+            num_layers = module.num_layers
+            ds = self.downsampling_factor[i]
+            x = convert_num_channels(x, self.encoder_dim[i])
+
+            x, new_layer_states = module.streaming_forward(
+                x,
+                states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
+                left_context_len=self.left_context_frames[0] // ds,
+                src_key_padding_mask=src_key_padding_mask[..., ::ds],
+            )
+            layer_offset += num_layers
+            outputs.append(x)
+            new_states += new_layer_states
+
+        # if the last output has the largest dimension, x will be unchanged,
+        # it will be the same as outputs[-1].  Otherwise it will be concatenated
+        # from different pieces of 'outputs', taking each dimension from the
+        # most recent output that has it present.
+        x = self._get_full_dim_output(outputs)
+        x = self.downsample_output(x)
+        # the SimpleDownsample class has this rounding behavior.
+        assert self.output_downsampling_factor == 2
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            lengths = (x_lens + 1) // 2
+        else:
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                lengths = (x_lens + 1) // 2
+
+        return x, lengths, new_states
+
+    @torch.jit.export
+    def get_init_states(
+        self,
+        batch_size: int = 1,
+        device: torch.device = torch.device("cpu"),
+    ) -> List[Tensor]:
+        """Get initial states.
+
+        A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+        is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+        """
+        states = []
+        for i, module in enumerate(self.encoders):
+            num_layers = module.num_layers
+            embed_dim = self.encoder_dim[i]
+            ds = self.downsampling_factor[i]
+            num_heads = self.num_heads[i]
+            key_dim = self.query_head_dim[i] * num_heads
+            value_dim = self.value_head_dim[i] * num_heads
+            downsample_left = self.left_context_frames[0] // ds
+            nonlin_attn_head_dim = 3 * embed_dim // 4
+            conv_left_pad = self.cnn_module_kernel[i] // 2
+            for layer in range(num_layers):
+                cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(
+                    device
+                )
+                cached_nonlin_attn = torch.zeros(
+                    1, batch_size, downsample_left, nonlin_attn_head_dim
+                ).to(device)
+                cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(
+                    device
+                )
+                cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(
+                    device
+                )
+                cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
+                    device
+                )
+                cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
+                    device
+                )
+                states += [
+                    cached_key,
+                    cached_nonlin_attn,
+                    cached_val1,
+                    cached_val2,
+                    cached_conv1,
+                    cached_conv2,
+                ]
+
+        return states
+
+
+def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
+    return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x)
+
+
+def _balancer_schedule(min_prob: float):
+    return ScheduledFloat((0.0, 0.4), (8000.0, min_prob))
+
+
+class Zipformer2EncoderLayer(nn.Module):
+    """
+    Args:
+        embed_dim: the number of expected features in the input (required).
+        nhead: the number of heads in the multiheadattention models (required).
+        feedforward_dim: the dimension of the feedforward network model (default=2048).
+        dropout: the dropout value (default=0.1).
+        cnn_module_kernel (int): Kernel size of convolution module.
+
+    Examples::
+        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
+        >>> src = torch.rand(10, 32, 512)
+        >>> pos_emb = torch.rand(32, 19, 512)
+        >>> out = encoder_layer(src, pos_emb)
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        pos_dim: int,
+        num_heads: int,
+        query_head_dim: int,
+        pos_head_dim: int,
+        value_head_dim: int,
+        feedforward_dim: int,
+        dropout: FloatLike = 0.1,
+        cnn_module_kernel: int = 31,
+        causal: bool = False,
+        attention_skip_rate: FloatLike = ScheduledFloat(
+            (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+        ),
+        conv_skip_rate: FloatLike = ScheduledFloat(
+            (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
+        ),
+        const_attention_rate: FloatLike = ScheduledFloat(
+            (0.0, 0.25), (4000.0, 0.025), default=0
+        ),
+        ff2_skip_rate: FloatLike = ScheduledFloat(
+            (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+        ),
+        ff3_skip_rate: FloatLike = ScheduledFloat(
+            (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
+        ),
+        bypass_skip_rate: FloatLike = ScheduledFloat(
+            (0.0, 0.5), (4000.0, 0.02), default=0
+        ),
+    ) -> None:
+        super(Zipformer2EncoderLayer, self).__init__()
+        self.embed_dim = embed_dim
+
+        # self.bypass implements layer skipping as well as bypass; see its default values.
+        self.bypass = BypassModule(
+            embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
+        )
+        # bypass_mid is bypass used in the middle of the layer.
+        self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
+
+        # skip probability for dynamic modules (meaning: anything but feedforward).
+        self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
+        # an additional skip probability that applies to ConvModule to stop it from
+        # contributing too much early on.
+        self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
+
+        # ff2_skip_rate is to prevent the ff2 module from having output that's too big
+        # compared to its residual.
+        self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
+        self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
+
+        self.const_attention_rate = copy.deepcopy(const_attention_rate)
+
+        self.self_attn_weights = RelPositionMultiheadAttentionWeights(
+            embed_dim,
+            pos_dim=pos_dim,
+            num_heads=num_heads,
+            query_head_dim=query_head_dim,
+            pos_head_dim=pos_head_dim,
+            dropout=0.0,
+        )
+
+        self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
+
+        self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)
+
+        self.feed_forward1 = FeedforwardModule(
+            embed_dim, (feedforward_dim * 3) // 4, dropout
+        )
+
+        self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)
+
+        self.feed_forward3 = FeedforwardModule(
+            embed_dim, (feedforward_dim * 5) // 4, dropout
+        )
+
+        self.nonlin_attention = NonlinAttention(
+            embed_dim, hidden_channels=3 * embed_dim // 4
+        )
+
+        self.conv_module1 = ConvolutionModule(
+            embed_dim, cnn_module_kernel, causal=causal
+        )
+
+        self.conv_module2 = ConvolutionModule(
+            embed_dim, cnn_module_kernel, causal=causal
+        )
+
+        # TODO: remove it
+        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+
+        self.norm = BiasNorm(embed_dim)
+
+        self.balancer1 = Balancer(
+            embed_dim,
+            channel_dim=-1,
+            min_positive=0.45,
+            max_positive=0.55,
+            min_abs=0.2,
+            max_abs=4.0,
+        )
+
+        # balancer for output of NonlinAttentionModule
+        self.balancer_na = Balancer(
+            embed_dim,
+            channel_dim=-1,
+            min_positive=0.3,
+            max_positive=0.7,
+            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
+            prob=0.05,  # out of concern for memory usage
+        )
+
+        # balancer for output of feedforward2, prevent it from staying too
+        # small.  give this a very small probability, even at the start of
+        # training, it's to fix a rare problem and it's OK to fix it slowly.
+        self.balancer_ff2 = Balancer(
+            embed_dim,
+            channel_dim=-1,
+            min_positive=0.3,
+            max_positive=0.7,
+            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
+            max_abs=2.0,
+            prob=0.05,
+        )
+
+        self.balancer_ff3 = Balancer(
+            embed_dim,
+            channel_dim=-1,
+            min_positive=0.3,
+            max_positive=0.7,
+            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
+            max_abs=4.0,
+            prob=0.05,
+        )
+
+        self.whiten = Whiten(
+            num_groups=1,
+            whitening_limit=_whitening_schedule(4.0, ratio=3.0),
+            prob=(0.025, 0.25),
+            grad_scale=0.01,
+        )
+
+        self.balancer2 = Balancer(
+            embed_dim,
+            channel_dim=-1,
+            min_positive=0.45,
+            max_positive=0.55,
+            min_abs=0.1,
+            max_abs=4.0,
+        )
+
+    def get_sequence_dropout_mask(
+        self, x: Tensor, dropout_rate: float
+    ) -> Optional[Tensor]:
+        if (
+            dropout_rate == 0.0
+            or not self.training
+            or torch.jit.is_scripting()
+            or torch.jit.is_tracing()
+        ):
+            return None
+        batch_size = x.shape[1]
+        mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
+        return mask
+
+    def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
+        """
+        Apply sequence-level dropout to x.
+        x shape: (seq_len, batch_size, embed_dim)
+        """
+        dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
+        if dropout_mask is None:
+            return x
+        else:
+            return x * dropout_mask
+
+    def forward(
+        self,
+        src: Tensor,
+        pos_emb: Tensor,
+        chunk_size: int = -1,
+        attn_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """
+            Pass the input through the encoder layer.
+            Args:
+                src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+             pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
+             chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
+           feature_mask: something that broadcasts with src, that we'll multiply `src`
+                  by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+             attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+                    interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+                   True means masked position. May be None.
+        src_key_padding_mask:  the mask for padding, of shape (batch_size, seq_len); True means
+                 masked position.  May be None.
+
+            Returns:
+               A tensor which has the same shape as src
+        """
+        src_orig = src
+
+        # dropout rate for non-feedforward submodules
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            attention_skip_rate = 0.0
+        else:
+            attention_skip_rate = (
+                float(self.attention_skip_rate) if self.training else 0.0
+            )
+
+        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        attn_weights = self.self_attn_weights(
+            src,
+            pos_emb=pos_emb,
+            attn_mask=attn_mask,
+            key_padding_mask=src_key_padding_mask,
+        )
+
+        src = src + self.feed_forward1(src)
+
+        self_attn_dropout_mask = self.get_sequence_dropout_mask(
+            src, attention_skip_rate
+        )
+
+        selected_attn_weights = attn_weights[0:1]
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            pass
+        elif not self.training and random.random() < float(self.const_attention_rate):
+            # Make attention weights constant.  The intention is to
+            # encourage these modules to do something similar to an
+            # averaging-over-time operation.
+            # only need the mask, can just use the 1st one and expand later
+            selected_attn_weights = selected_attn_weights[0:1]
+            selected_attn_weights = (selected_attn_weights > 0.0).to(
+                selected_attn_weights.dtype
+            )
+            selected_attn_weights = selected_attn_weights * (
+                1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
+            )
+
+        na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
+
+        src = src + (
+            na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
+        )
+
+        self_attn = self.self_attn1(src, attn_weights)
+
+        src = src + (
+            self_attn
+            if self_attn_dropout_mask is None
+            else self_attn * self_attn_dropout_mask
+        )
+
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            conv_skip_rate = 0.0
+        else:
+            conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+        src = src + self.sequence_dropout(
+            self.conv_module1(
+                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
+            ),
+            conv_skip_rate,
+        )
+
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            ff2_skip_rate = 0.0
+        else:
+            ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
+        src = src + self.sequence_dropout(
+            self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
+        )
+
+        # bypass in the middle of the layer.
+        src = self.bypass_mid(src_orig, src)
+
+        self_attn = self.self_attn2(src, attn_weights)
+
+        src = src + (
+            self_attn
+            if self_attn_dropout_mask is None
+            else self_attn * self_attn_dropout_mask
+        )
+
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            conv_skip_rate = 0.0
+        else:
+            conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
+        src = src + self.sequence_dropout(
+            self.conv_module2(
+                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
+            ),
+            conv_skip_rate,
+        )
+
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            ff3_skip_rate = 0.0
+        else:
+            ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
+        src = src + self.sequence_dropout(
+            self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
+        )
+
+        src = self.balancer1(src)
+        src = self.norm(src)
+
+        src = self.bypass(src_orig, src)
+
+        src = self.balancer2(src)
+        src = self.whiten(src)
+
+        return src
+
+    def streaming_forward(
+        self,
+        src: Tensor,
+        pos_emb: Tensor,
+        cached_key: Tensor,
+        cached_nonlin_attn: Tensor,
+        cached_val1: Tensor,
+        cached_val2: Tensor,
+        cached_conv1: Tensor,
+        cached_conv2: Tensor,
+        left_context_len: int,
+        src_key_padding_mask: Tensor,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+        """Pass the input through the encoder layer in streaming forward mode.
+
+        Args:
+            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+            pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or
+              (batch_size, left_context_len+2*seq_len-1, pos_emb_dim)
+            cached_key: cached attention key tensor of left context,
+              of shape (left_context_len, batch_size, key_dim)
+            cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape
+              (num_heads, batch_size, left_context_len, head_dim)
+            cached_val1: cached left context for the first attention module,
+              of shape (left_context_len, batch_size, value_dim)
+            cached_val2: cached left context for the second attention module,
+              of shape (left_context_len, batch_size, value_dim)
+            cached_conv1: cached left context for the first convolution module,
+              of shape (batch_size, channels, left_pad)
+            cached_conv2: cached left context for the second convolution module,
+              of shape (batch_size, channels, left_pad)
+            left_context_len: number of left context frames.
+            src_key_padding_mask:  the mask for padding, of shape
+              (batch_size, left_context_len + seq_len); True means masked position.
+              May be None.
+
+        Returns:
+            - x, with the same shape as src
+            - updated cached_key
+            - updated cached_nonlin_attn
+            - updated cached_val1
+            - updated cached_val2
+            - updated cached_conv1
+            - updated cached_conv2
+        """
+        src_orig = src
+
+        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        attn_weights, cached_key = self.self_attn_weights.streaming_forward(
+            src,
+            pos_emb=pos_emb,
+            cached_key=cached_key,
+            left_context_len=left_context_len,
+            key_padding_mask=src_key_padding_mask,
+        )
+
+        src = src + self.feed_forward1(src)
+
+        na, cached_nonlin_attn = self.nonlin_attention.streaming_forward(
+            src,
+            attn_weights[0:1],
+            cached_x=cached_nonlin_attn,
+            left_context_len=left_context_len,
+        )
+        src = src + na
+
+        self_attn, cached_val1 = self.self_attn1.streaming_forward(
+            src,
+            attn_weights=attn_weights,
+            cached_val=cached_val1,
+            left_context_len=left_context_len,
+        )
+        src = src + self_attn
+
+        src_conv, cached_conv1 = self.conv_module1.streaming_forward(
+            src,
+            cache=cached_conv1,
+            src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
+        )
+        src = src + src_conv
+
+        src = src + self.feed_forward2(src)
+
+        # bypass in the middle of the layer.
+        src = self.bypass_mid(src_orig, src)
+
+        self_attn, cached_val2 = self.self_attn2.streaming_forward(
+            src,
+            attn_weights=attn_weights,
+            cached_val=cached_val2,
+            left_context_len=left_context_len,
+        )
+        src = src + self_attn
+
+        src_conv, cached_conv2 = self.conv_module2.streaming_forward(
+            src,
+            cache=cached_conv2,
+            src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
+        )
+        src = src + src_conv
+
+        src = src + self.feed_forward3(src)
+
+        src = self.norm(src)
+
+        src = self.bypass(src_orig, src)
+
+        return (
+            src,
+            cached_key,
+            cached_nonlin_attn,
+            cached_val1,
+            cached_val2,
+            cached_conv1,
+            cached_conv2,
+        )
+
+
+class Zipformer2Encoder(nn.Module):
+    r"""Zipformer2Encoder is a stack of N encoder layers
+
+    Args:
+        encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
+        num_layers: the number of sub-encoder-layers in the encoder (required).
+        pos_dim: the dimension for the relative positional encoding.
+        dropout: the dropout rate.
+        warmup_begin, warmup_end: training-batch indices between which each layer's
+            layerdrop rate is scheduled from initial_layerdrop_rate down to
+            final_layerdrop_rate.
+
+    Examples::
+        >>> encoder_layer = Zipformer2EncoderLayer(
+        ...     embed_dim=512, pos_dim=192, num_heads=8, query_head_dim=24,
+        ...     pos_head_dim=4, value_head_dim=12, feedforward_dim=1536)
+        >>> zipformer_encoder = Zipformer2Encoder(
+        ...     encoder_layer, num_layers=6, pos_dim=192, dropout=0.1,
+        ...     warmup_begin=0.0, warmup_end=4000.0)
+        >>> src = torch.rand(10, 32, 512)
+        >>> out = zipformer_encoder(src)
+    """
+
+    def __init__(
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        pos_dim: int,
+        dropout: float,
+        warmup_begin: float,
+        warmup_end: float,
+        initial_layerdrop_rate: float = 0.5,
+        final_layerdrop_rate: float = 0.05,
+    ) -> None:
+        super().__init__()
+        self.encoder_pos = CompactRelPositionalEncoding(
+            pos_dim, dropout_rate=0.15, length_factor=1.0
+        )
+
+        self.layers = nn.ModuleList(
+            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+        )
+        self.num_layers = num_layers
+
+        assert 0 <= warmup_begin <= warmup_end
+
+        delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
+        cur_begin = warmup_begin  # interpreted as a training batch index
+        for i in range(num_layers):
+            cur_end = cur_begin + delta
+            self.layers[i].bypass.skip_rate = ScheduledFloat(
+                (cur_begin, initial_layerdrop_rate),
+                (cur_end, final_layerdrop_rate),
+                default=0.0,
+            )
+            cur_begin = cur_end
+
+    def forward(
+        self,
+        src: Tensor,
+        chunk_size: int = -1,
+        feature_mask: Union[Tensor, float] = 1.0,
+        attn_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""Pass the input through the encoder layers in turn.
+
+        Args:
+            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+            chunk_size: the number of frames per chunk; -1 means no chunking.
+            feature_mask: something that broadcasts with src, that we'll multiply `src`
+               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+                 interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+                 True means masked position. May be None.
+            src_key_padding_mask:  the mask for padding, of shape (batch_size, seq_len); True means
+                 masked position.  May be None.
+
+        Returns: a Tensor with the same shape as src.
+        """
+        pos_emb = self.encoder_pos(src)
+        output = src
+
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+            output = output * feature_mask
+
+        for i, mod in enumerate(self.layers):
+            output = mod(
+                output,
+                pos_emb,
+                chunk_size=chunk_size,
+                attn_mask=attn_mask,
+                src_key_padding_mask=src_key_padding_mask,
+            )
+
+            if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+                output = output * feature_mask
+
+        return output
+
+    def streaming_forward(
+        self,
+        src: Tensor,
+        states: List[Tensor],
+        left_context_len: int,
+        src_key_padding_mask: Tensor,
+    ) -> Tuple[Tensor, List[Tensor]]:
+        r"""Pass the input through the encoder layers in turn.
+
+        Args:
+            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+            states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
+              (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+            left_context_len: Number of left context frames.
+            src_key_padding_mask:  the mask for padding, of shape
+              (batch_size, left_context_len + seq_len); True means masked position.
+              May be None.
+
+        Returns:
+          - output, a Tensor with the same shape as src.
+          - updated states
+        """
+        pos_emb = self.encoder_pos(src, left_context_len)
+        output = src
+
+        new_states = []
+        for i, mod in enumerate(self.layers):
+            (
+                cached_key,
+                cached_nonlin_attn,
+                cached_val1,
+                cached_val2,
+                cached_conv1,
+                cached_conv2,
+            ) = states[i * 6 : (i + 1) * 6]
+            (
+                output,
+                new_cached_key,
+                new_cached_nonlin_attn,
+                new_cached_val1,
+                new_cached_val2,
+                new_cached_conv1,
+                new_cached_conv2,
+            ) = mod.streaming_forward(
+                output,
+                pos_emb,
+                cached_key=cached_key,
+                cached_nonlin_attn=cached_nonlin_attn,
+                cached_val1=cached_val1,
+                cached_val2=cached_val2,
+                cached_conv1=cached_conv1,
+                cached_conv2=cached_conv2,
+                left_context_len=left_context_len,
+                src_key_padding_mask=src_key_padding_mask,
+            )
+            new_states += [
+                new_cached_key,
+                new_cached_nonlin_attn,
+                new_cached_val1,
+                new_cached_val2,
+                new_cached_conv1,
+                new_cached_conv2,
+            ]
+
+        return output, new_states
+
+
+class BypassModule(nn.Module):
+    """
+    An nn.Module that implements a learnable bypass scale, and also randomized per-sequence
+    layer-skipping.  The bypass is limited during early stages of training to be close to
+    "straight-through", i.e. to not do the bypass operation much initially, in order to
+    force all the modules to learn something.
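+
+    A minimal usage sketch (shapes chosen only for illustration); the output is
+    src_orig + (src - src_orig) * bypass_scale::
+
+        >>> bypass = BypassModule(embed_dim=192)
+        >>> src_orig = torch.rand(10, 2, 192)   # (seq_len, batch_size, num_channels)
+        >>> src = torch.rand(10, 2, 192)        # e.g. the output of some module applied to src_orig
+        >>> out = bypass(src_orig, src)         # same shape as src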
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        skip_rate: FloatLike = 0.0,
+        straight_through_rate: FloatLike = 0.0,
+        scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
+        scale_max: FloatLike = 1.0,
+    ):
+        super().__init__()
+        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
+        self.skip_rate = copy.deepcopy(skip_rate)
+        self.straight_through_rate = copy.deepcopy(straight_through_rate)
+        self.scale_min = copy.deepcopy(scale_min)
+        self.scale_max = copy.deepcopy(scale_max)
+
+    def _get_bypass_scale(self, batch_size: int):
+        # returns bypass-scale of shape (num_channels,),
+        # or (batch_size, num_channels,).  This is actually the
+        # scale on the non-residual term, so 0 corresponds to bypassing
+        # this module.
+        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
+            return self.bypass_scale
+        else:
+            ans = limit_param_value(
+                self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max)
+            )
+            skip_rate = float(self.skip_rate)
+            if skip_rate != 0.0:
+                mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
+                ans = ans * mask
+                # now ans is of shape (batch_size, num_channels), and is zero for sequences
+                # on which we have randomly chosen to do layer-skipping.
+            straight_through_rate = float(self.straight_through_rate)
+            if straight_through_rate != 0.0:
+                mask = (
+                    torch.rand((batch_size, 1), device=ans.device)
+                    < straight_through_rate
+                )
+                ans = torch.maximum(ans, mask.to(ans.dtype))
+            return ans
+
+    def forward(self, src_orig: Tensor, src: Tensor):
+        """
+        Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
+        Returns: something with the same shape as src and src_orig
+        """
+        bypass_scale = self._get_bypass_scale(src.shape[1])
+        return src_orig + (src - src_orig) * bypass_scale
+
+
+class DownsampledZipformer2Encoder(nn.Module):
+    r"""
+    DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate
+    after downsampling; its output is upsampled again and combined with the original input,
+    so that the output has the same shape as the input.
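+
+    A minimal usage sketch (assuming encoder is an already-constructed Zipformer2Encoder
+    whose layers operate on dim channels; numbers are illustrative only)::
+
+        >>> ds_encoder = DownsampledZipformer2Encoder(encoder, dim=384, downsample=2, dropout=0.1)
+        >>> src = torch.rand(100, 8, 384)       # (seq_len, batch_size, embedding_dim)
+        >>> out = ds_encoder(src)               # same shape as src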
+    """
+
+    def __init__(
+        self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike
+    ):
+        super(DownsampledZipformer2Encoder, self).__init__()
+        self.downsample_factor = downsample
+        self.downsample = SimpleDownsample(dim, downsample, dropout)
+        self.num_layers = encoder.num_layers
+        self.encoder = encoder
+        self.upsample = SimpleUpsample(dim, downsample)
+        self.out_combiner = BypassModule(dim, straight_through_rate=0)
+
+    def forward(
+        self,
+        src: Tensor,
+        chunk_size: int = -1,
+        feature_mask: Union[Tensor, float] = 1.0,
+        attn_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""Downsample, go through encoder, upsample.
+
+        Args:
+            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+            feature_mask: something that broadcasts with src, that we'll multiply `src`
+               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+                 interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+                 True means masked position. May be None.
+            src_key_padding_mask:  the mask for padding, of shape (batch_size, seq_len); True means
+                 masked position.  May be None.
+
+        Returns: a Tensor with the same shape as src.
+        """
+        src_orig = src
+        src = self.downsample(src)
+        ds = self.downsample_factor
+        if attn_mask is not None:
+            attn_mask = attn_mask[::ds, ::ds]
+
+        src = self.encoder(
+            src,
+            chunk_size=chunk_size // ds,
+            feature_mask=feature_mask,
+            attn_mask=attn_mask,
+            src_key_padding_mask=src_key_padding_mask,
+        )
+        src = self.upsample(src)
+        # remove any extra frames that are not a multiple of downsample_factor
+        src = src[: src_orig.shape[0]]
+
+        return self.out_combiner(src_orig, src)
+
+    def streaming_forward(
+        self,
+        src: Tensor,
+        states: List[Tensor],
+        left_context_len: int,
+        src_key_padding_mask: Tensor,
+    ) -> Tuple[Tensor, List[Tensor]]:
+        r"""Downsample, go through encoder, upsample, in streaming forward mode.
+
+        Args:
+            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+            states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
+              (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+            left_context_len: Number of left context frames.
+            src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len);
+              True means masked position. May be None.
+
+        Returns:
+            - output, a Tensor with the same shape as src.
+            - updated states
+        """
+        src_orig = src
+        src = self.downsample(src)
+
+        src, new_states = self.encoder.streaming_forward(
+            src,
+            states=states,
+            left_context_len=left_context_len,
+            src_key_padding_mask=src_key_padding_mask,
+        )
+        src = self.upsample(src)
+        # remove any extra frames that are not a multiple of downsample_factor
+        src = src[: src_orig.shape[0]]
+
+        return self.out_combiner(src_orig, src), new_states
+
+
+class SimpleDownsample(torch.nn.Module):
+    """
+    Does downsampling by a weighted sum over each group of `downsample` frames, with learned (softmax-normalized) weights.
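+
+    A minimal usage sketch (shapes chosen only for illustration)::
+
+        >>> down = SimpleDownsample(channels=256, downsample=2, dropout=0.1)
+        >>> down(torch.rand(9, 4, 256)).shape   # ((9 + 2 - 1) // 2, 4, 256)
+        torch.Size([5, 4, 256])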
+    """
+
+    def __init__(self, channels: int, downsample: int, dropout: FloatLike):
+        super(SimpleDownsample, self).__init__()
+
+        self.bias = nn.Parameter(torch.zeros(downsample))
+
+        self.name = None  # will be set from training code
+        self.dropout = copy.deepcopy(dropout)
+
+        self.downsample = downsample
+
+    def forward(self, src: Tensor) -> Tensor:
+        """
+        src: (seq_len, batch_size, in_channels)
+        Returns a tensor of shape
+           ((seq_len + downsample - 1) // downsample, batch_size, in_channels)
+        """
+        (seq_len, batch_size, in_channels) = src.shape
+        ds = self.downsample
+        d_seq_len = (seq_len + ds - 1) // ds
+
+        # Pad to an exact multiple of self.downsample
+        # right-pad src, repeating the last element.
+        pad = d_seq_len * ds - seq_len
+        src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
+        src = torch.cat((src, src_extra), dim=0)
+        assert src.shape[0] == d_seq_len * ds
+
+        src = src.reshape(d_seq_len, ds, batch_size, in_channels)
+
+        weights = self.bias.softmax(dim=0)
+        # weights: (downsample, 1, 1)
+        weights = weights.unsqueeze(-1).unsqueeze(-1)
+
+        # weighted sum over the `downsample` dimension
+        ans = (src * weights).sum(dim=1)
+
+        return ans
+
+
+class SimpleUpsample(torch.nn.Module):
+    """
+    A very simple form of upsampling that just repeats each input frame `upsample` times.
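+
+    A minimal usage sketch::
+
+        >>> up = SimpleUpsample(num_channels=256, upsample=2)
+        >>> up(torch.rand(5, 4, 256)).shape     # (seq_len * upsample, batch_size, num_channels)
+        torch.Size([10, 4, 256])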
+    """
+
+    def __init__(self, num_channels: int, upsample: int):
+        super(SimpleUpsample, self).__init__()
+        self.upsample = upsample
+
+    def forward(self, src: Tensor) -> Tensor:
+        """
+        src: (seq_len, batch_size, num_channels)
+        Returns a tensor of shape
+           (seq_len * upsample, batch_size, num_channels)
+        """
+        upsample = self.upsample
+        (seq_len, batch_size, num_channels) = src.shape
+        src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
+        src = src.reshape(seq_len * upsample, batch_size, num_channels)
+        return src
+
+
+class CompactRelPositionalEncoding(torch.nn.Module):
+    """
+    Relative positional encoding module.  This version is "compact" meaning it is able to encode
+    the important information about the relative position in a relatively small number of dimensions.
+    The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
+    make very little difference to the embedding.   Such differences were potentially important
+    when encoding absolute position, but not important when encoding relative position because there
+    is now no need to compare two large offsets with each other.
+
+    Our embedding works by projecting the interval [-infinity,infinity] to a finite interval
+    using the atan() function, before doing the Fourier transform of that fixed interval.  The
+    atan() function on its own would compress the "long tails" too much,
+    making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
+    function to compress large offsets to a smaller range before applying atan().
+    Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
+    as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim).
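+
+    A minimal usage sketch (embed_dim must be even; shapes are illustrative only)::
+
+        >>> pos_enc = CompactRelPositionalEncoding(embed_dim=64, dropout_rate=0.0)
+        >>> x = torch.rand(20, 4, 192)          # only the time dim, dtype and device of x are used
+        >>> pos_enc(x).shape                    # (1, 2*time - 1, embed_dim)
+        torch.Size([1, 39, 64])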
+
+
+    Args:
+        embed_dim: Embedding dimension.
+        dropout_rate: Dropout rate.
+        max_len: Maximum input length: just a heuristic for initialization.
+        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
+           less weight to small differences of offset near the origin.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        dropout_rate: FloatLike,
+        max_len: int = 1000,
+        length_factor: float = 1.0,
+    ) -> None:
+        """Construct a CompactRelPositionalEncoding object."""
+        super(CompactRelPositionalEncoding, self).__init__()
+        self.embed_dim = embed_dim
+        assert embed_dim % 2 == 0
+        self.dropout = Dropout2(dropout_rate)
+        self.pe = None
+        assert length_factor >= 1.0
+        self.length_factor = length_factor
+        self.extend_pe(torch.tensor(0.0).expand(max_len))
+
+    def extend_pe(self, x: Tensor, left_context_len: int = 0):
+        """Reset the positional encodings."""
+        T = x.size(0) + left_context_len
+
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(0) >= T * 2 - 1:
+                # self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return self.pe.to(dtype=x.dtype, device=x.device)
+
+        # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
+        x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
+
+        freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
+
+        # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution
+        # for small time offsets but less resolution for large time offsets.
+        compression_length = self.embed_dim**0.5
+        # x_compressed, like x, goes from -infinity to infinity as x goes from -infinity to infinity;
+        # but it does so more slowly than x for large absolute values of x.
+        # The formula is chosen so that d(x_compressed)/dx is 1 around x == 0, which
+        # is important.
+        x_compressed = (
+            compression_length
+            * x.sign()
+            * ((x.abs() + compression_length).log() - math.log(compression_length))
+        )
+
+        # if self.length_factor == 1.0, then length_scale is chosen so that the
+        # FFT can exactly separate points close to the origin (T == 0).  So this
+        # part of the formulation is not really heuristic.
+        # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
+        length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
+
+        # note for machine implementations: if atan is not available, we can use:
+        #   x.sign() * ((1 / (x.abs() + 1)) - 1)  * (-math.pi/2)
+        #  check on wolframalpha.com: plot(sign(x) *  (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
+        x_atan = (x_compressed / length_scale).atan()  # results between -pi and pi
+
+        cosines = (x_atan * freqs).cos()
+        sines = (x_atan * freqs).sin()
+
+        pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
+        pe[:, 0::2] = cosines
+        pe[:, 1::2] = sines
+        pe[:, -1] = 1.0  # for bias.
+
+        # self.pe = pe.to(dtype=x.dtype)
+        return pe.to(dtype=x.dtype)
+
+    def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
+        """Create positional encoding.
+
+        Args:
+            x (Tensor): Input tensor (time, batch, `*`).
+            left_context_len: (int): Length of cached left context.
+
+        Returns:
+            positional embedding, of shape (1, left_context_len + 2*time - 1, embed_dim).
+        """
+        pe = self.extend_pe(x, left_context_len)
+        x_size_left = x.size(0) + left_context_len
+        # length of positive side: x.size(0) + left_context_len
+        # length of negative side: x.size(0)
+        pos_emb = pe[
+            pe.size(0) // 2
+            - x_size_left
+            + 1 : pe.size(0) // 2  # noqa E203
+            + x.size(0),
+            :,
+        ]
+        pos_emb = pos_emb.unsqueeze(0)
+        return self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttentionWeights(nn.Module):
+    r"""Module that computes multi-head attention weights with relative position encoding.
+    Various other modules consume the resulting attention weights: see, for example, the
+    SimpleAttention module which allows you to compute conventional attention.
+
+    This is quite heavily modified from "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context";
+    the differences still need to be written up.
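+
+    A minimal usage sketch (dimensions chosen only for illustration)::
+
+        >>> attn = RelPositionMultiheadAttentionWeights(embed_dim=256, pos_dim=64, num_heads=4,
+        ...                                             query_head_dim=32, pos_head_dim=4)
+        >>> x = torch.rand(10, 2, 256)                  # (seq_len, batch_size, embed_dim)
+        >>> pos_emb = torch.rand(1, 2 * 10 - 1, 64)     # (1, 2*seq_len - 1, pos_dim)
+        >>> weights = attn(x, pos_emb)                  # (num_heads, batch_size, seq_len, seq_len)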
+
+
+    Args:
+           embed_dim: number of channels at the input to this module, e.g. 256
+             pos_dim: dimension of the positional encoding vectors, e.g. 128.
+           num_heads:  number of heads to compute weights for, e.g. 8
+     query_head_dim: dimension of the query (and key), per head.  e.g. 24.
+       pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
+            dropout: dropout probability for attn_output_weights. Default: 0.0.
+       pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
+                     any given call to forward(), in training time.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        pos_dim: int,
+        num_heads: int,
+        query_head_dim: int,
+        pos_head_dim: int,
+        dropout: float = 0.0,
+        pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
+    ) -> None:
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.query_head_dim = query_head_dim
+        self.pos_head_dim = pos_head_dim
+        self.dropout = dropout
+        self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
+        self.name = None  # will be overwritten in training code; for diagnostics.
+
+        key_head_dim = query_head_dim
+        in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
+
+        # the initial_scale is supposed to take over the "scaling" factor of
+        # head_dim ** -0.5 that has been used in previous forms of attention,
+        # dividing it between the query and key.   Note: this module is intended
+        # to be used with the ScaledAdam optimizer; with most other optimizers,
+        # it would be necessary to apply the scaling factor in the forward function.
+        self.in_proj = ScaledLinear(
+            embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
+        )
+
+        self.whiten_keys = Whiten(
+            num_groups=num_heads,
+            whitening_limit=_whitening_schedule(3.0),
+            prob=(0.025, 0.25),
+            grad_scale=0.025,
+        )
+
+        # add a balancer for the keys that runs with very small probability, and
+        # tries to enforce that all dimensions have mean around zero.  The
+        # weights produced by this module are invariant to adding a constant to
+        # the keys, so the derivative of the bias is mathematically zero; but
+        # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
+        # bias because the small numerical roundoff tends to have a non-random
+        # sign.  This module is intended to prevent that.  Use a very small
+        # probability; that should be sufficient to fix the problem.
+        self.balance_keys = Balancer(
+            key_head_dim * num_heads,
+            channel_dim=-1,
+            min_positive=0.4,
+            max_positive=0.6,
+            min_abs=0.0,
+            max_abs=100.0,
+            prob=0.025,
+        )
+
+        # linear transformation for positional encoding.
+        self.linear_pos = ScaledLinear(
+            pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
+        )
+
+        # the following are for diagnostics only, see --print-diagnostics option
+        self.copy_pos_query = Identity()
+        self.copy_query = Identity()
+
+    def forward(
+        self,
+        x: Tensor,
+        pos_emb: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+        attn_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""
+        Args:
+            x: input of shape (seq_len, batch_size, embed_dim)
+            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
+            key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
+               are True in this mask will be ignored as sources in the attention weighting.
+            attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
+               interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
+               saying which positions are allowed to attend to which other positions.
+        Returns:
+           a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
+           interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
+        """
+        x = self.in_proj(x)
+        query_head_dim = self.query_head_dim
+        pos_head_dim = self.pos_head_dim
+        num_heads = self.num_heads
+
+        seq_len, batch_size, _ = x.shape
+
+        query_dim = query_head_dim * num_heads
+
+        # self-attention
+        q = x[..., 0:query_dim]
+        k = x[..., query_dim : 2 * query_dim]
+        # p is the position-encoding query
+        p = x[..., 2 * query_dim :]
+        assert p.shape[-1] == num_heads * pos_head_dim
+
+        q = self.copy_query(q)  # for diagnostics only, does nothing.
+        k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
+        p = self.copy_pos_query(p)  # for diagnostics only, does nothing.
+
+        q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
+        p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+        k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
+
+        # time1 refers to target, time2 refers to source.
+        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
+        p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
+        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
+
+        attn_scores = torch.matmul(q, k)
+
+        use_pos_scores = False
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            # We can't call random.random() under torch.jit scripting/tracing, so always use the positional scores.
+            use_pos_scores = True
+        elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
+            use_pos_scores = True
+
+        if use_pos_scores:
+            pos_emb = self.linear_pos(pos_emb)
+            seq_len2 = 2 * seq_len - 1
+            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+                2, 0, 3, 1
+            )
+            # pos_emb shape now: (head, {1 or batch_size}, pos_head_dim, seq_len2)
+
+            # (head, batch, time1, pos_head_dim) x (head, 1, pos_head_dim, seq_len2) -> (head, batch, time1, seq_len2)
+            #  [where seq_len2 represents relative position.]
+            pos_scores = torch.matmul(p, pos_emb)
+            # the following .as_strided() expression converts the last axis of pos_scores from relative
+            # to absolute position.  I don't know whether I might have got the time-offsets backwards or
+            # not, but let this code define which way round it is supposed to be.
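+            # (Concretely: output element (h, b, i, j) reads pos_scores[h, b, i, (seq_len - 1) + j - i],
+            # i.e. relative offset j - i shifted by seq_len - 1.)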
+            if torch.jit.is_tracing():
+                (num_heads, batch_size, time1, n) = pos_scores.shape
+                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+                cols = torch.arange(seq_len)
+                rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+                indexes = rows + cols
+                pos_scores = pos_scores.reshape(-1, n)
+                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+                pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
+            else:
+                pos_scores = pos_scores.as_strided(
+                    (num_heads, batch_size, seq_len, seq_len),
+                    (
+                        pos_scores.stride(0),
+                        pos_scores.stride(1),
+                        pos_scores.stride(2) - pos_scores.stride(3),
+                        pos_scores.stride(3),
+                    ),
+                    storage_offset=pos_scores.stride(3) * (seq_len - 1),
+                )
+
+            attn_scores = attn_scores + pos_scores
+
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            pass
+        elif self.training and random.random() < 0.1:
+            # This is a harder way of limiting the attention scores to not be
+            # too large.  It incurs a penalty if any of them has an absolute
+            # value greater than 25.0.  This should be outside the normal range
+            # of the attention scores.  We use this mechanism instead of, say,
+            # something added to the loss function involving the entropy,
+            # because once the entropy gets very small, gradients through the
+            # softmax can become very small, and we'd get zero derivatives.  The
+            # choice of 1.0e-04 as the scale on the penalty makes this
+            # mechanism vulnerable to the absolute scale of the loss function,
+            # but we view this as a failsafe to avoid "implausible" parameter
+            # values rather than a regularization method that should be active
+            # under normal circumstances.
+            attn_scores = penalize_abs_values_gt(
+                attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
+            )
+
+        assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
+
+        if attn_mask is not None:
+            assert attn_mask.dtype == torch.bool
+            # use -1000 to avoid nan's where attn_mask and key_padding_mask make
+            # all scores zero.  It's important that this be large enough that exp(-1000)
+            # is exactly zero, for reasons related to const_attention_rate, which
+            # compares the final weights with zero.
+            attn_scores = attn_scores.masked_fill(attn_mask, -1000)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.shape == (
+                batch_size,
+                seq_len,
+            ), key_padding_mask.shape
+            attn_scores = attn_scores.masked_fill(
+                key_padding_mask.unsqueeze(1),
+                -1000,
+            )
+
+        # We use our own version of softmax, defined in scaling.py, which should
+        # save a little of the memory used in backprop: if we are in
+        # automatic mixed precision mode (amp / autocast), it only stores the
+        # half-precision output for backprop purposes.
+        attn_weights = softmax(attn_scores, dim=-1)
+
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            pass
+        elif random.random() < 0.001 and not self.training:
+            self._print_attn_entropy(attn_weights)
+
+        attn_weights = nn.functional.dropout(
+            attn_weights, p=self.dropout, training=self.training
+        )
+
+        return attn_weights
+
+    def streaming_forward(
+        self,
+        x: Tensor,
+        pos_emb: Tensor,
+        cached_key: Tensor,
+        left_context_len: int,
+        key_padding_mask: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        r"""
+        Args:
+            x: input of shape (seq_len, batch_size, embed_dim)
+            pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim)
+            cached_key: cached attention key tensor of left context,
+              of shape (left_context_len, batch_size, key_dim)
+            left_context_len: number of left context frames.
+            key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
+              are True in this mask will be ignored as sources in the attention weighting.
+
+        Returns:
+           - attention weights, of shape (num_heads, batch_size, seq_len, seq_len2),
+             interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
+           - updated cached attention key tensor of left context.
+        """
+        x = self.in_proj(x)
+        query_head_dim = self.query_head_dim
+        pos_head_dim = self.pos_head_dim
+        num_heads = self.num_heads
+
+        seq_len, batch_size, _ = x.shape
+
+        query_dim = query_head_dim * num_heads
+
+        # self-attention
+        q = x[..., 0:query_dim]
+        k = x[..., query_dim : 2 * query_dim]
+        # p is the position-encoding query
+        p = x[..., 2 * query_dim :]
+        assert p.shape[-1] == num_heads * pos_head_dim
+
+        # Pad cached left contexts
+        assert cached_key.shape[0] == left_context_len, (
+            cached_key.shape[0],
+            left_context_len,
+        )
+        k = torch.cat([cached_key, k], dim=0)
+        # Update cached left contexts
+        cached_key = k[-left_context_len:, ...]
+
+        # The length of key
+        k_len = k.shape[0]
+
+        q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
+        p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+        k = k.reshape(k_len, batch_size, num_heads, query_head_dim)
+
+        # time1 refers to target, time2 refers to source.
+        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
+        p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
+        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
+
+        attn_scores = torch.matmul(q, k)
+
+        pos_emb = self.linear_pos(pos_emb)
+        seq_len2 = 2 * seq_len - 1 + left_context_len
+        pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+            2, 0, 3, 1
+        )
+        # pos_emb shape now: (head, {1 or batch_size}, pos_head_dim, seq_len2)
+
+        # (head, batch, time1, pos_head_dim) x (head, 1, pos_head_dim, seq_len2) -> (head, batch, time1, seq_len2)
+        #  [where seq_len2 represents relative position.]
+        pos_scores = torch.matmul(p, pos_emb)
+
+        if torch.jit.is_tracing():
+            (num_heads, batch_size, time1, n) = pos_scores.shape
+            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+            cols = torch.arange(k_len)
+            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+            indexes = rows + cols
+            pos_scores = pos_scores.reshape(-1, n)
+            pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+            pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
+        else:
+            # the following .as_strided() expression converts the last axis of pos_scores from relative
+            # to absolute position.  I don't know whether I might have got the time-offsets backwards or
+            # not, but let this code define which way round it is supposed to be.
+            pos_scores = pos_scores.as_strided(
+                (num_heads, batch_size, seq_len, k_len),
+                (
+                    pos_scores.stride(0),
+                    pos_scores.stride(1),
+                    pos_scores.stride(2) - pos_scores.stride(3),
+                    pos_scores.stride(3),
+                ),
+                storage_offset=pos_scores.stride(3) * (seq_len - 1),
+            )
+
+        attn_scores = attn_scores + pos_scores
+
+        assert attn_scores.shape == (
+            num_heads,
+            batch_size,
+            seq_len,
+            k_len,
+        ), attn_scores.shape
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape
+            attn_scores = attn_scores.masked_fill(
+                key_padding_mask.unsqueeze(1),
+                -1000,
+            )
+
+        attn_weights = attn_scores.softmax(dim=-1)
+
+        return attn_weights, cached_key
+
+    def _print_attn_entropy(self, attn_weights: Tensor):
+        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
+
+        with torch.no_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                attn_weights = attn_weights.to(torch.float32)
+                attn_weights_entropy = (
+                    -((attn_weights + 1.0e-20).log() * attn_weights)
+                    .sum(dim=-1)
+                    .mean(dim=(1, 2))
+                )
+                logging.info(
+                    f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
+                )
+
+
+class SelfAttention(nn.Module):
+    """
+    The simplest possible attention module.  This one works with already-computed attention
+    weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
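+
+    A minimal usage sketch (attn_weights would normally come from
+    RelPositionMultiheadAttentionWeights with a matching number of heads; numbers are
+    illustrative only)::
+
+        >>> self_attn = SelfAttention(embed_dim=256, num_heads=4, value_head_dim=12)
+        >>> x = torch.rand(10, 2, 256)
+        >>> attn_weights = torch.softmax(torch.rand(4, 2, 10, 10), dim=-1)
+        >>> out = self_attn(x, attn_weights)            # same shape as x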
+
+    Args:
+          embed_dim: the input and output embedding dimension
+          num_heads: the number of attention heads
+          value_head_dim: the value dimension per head
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        value_head_dim: int,
+    ) -> None:
+        super().__init__()
+        self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
+
+        self.out_proj = ScaledLinear(
+            num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
+        )
+
+        self.whiten = Whiten(
+            num_groups=1,
+            whitening_limit=_whitening_schedule(7.5, ratio=3.0),
+            prob=(0.025, 0.25),
+            grad_scale=0.01,
+        )
+
+    def forward(
+        self,
+        x: Tensor,
+        attn_weights: Tensor,
+    ) -> Tensor:
+        """
+        Args:
+          x: input tensor, of shape (seq_len, batch_size, embed_dim)
+         attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
+          with seq_len being interpreted as (tgt_seq_len, src_seq_len).  Expect
+          attn_weights.sum(dim=-1) == 1.
+        Returns:
+           a tensor with the same shape as x.
+        """
+        (seq_len, batch_size, embed_dim) = x.shape
+        num_heads = attn_weights.shape[0]
+        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
+        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+        # now x: (num_heads, batch_size, seq_len, value_head_dim)
+        value_head_dim = x.shape[-1]
+
+        # todo: see whether there is benefit in overriding matmul
+        x = torch.matmul(attn_weights, x)
+        # now x: (num_heads, batch_size, seq_len, value_head_dim)
+
+        x = (
+            x.permute(2, 1, 0, 3)
+            .contiguous()
+            .view(seq_len, batch_size, num_heads * value_head_dim)
+        )
+
+        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+        x = self.out_proj(x)
+        x = self.whiten(x)
+
+        return x
+
+    def streaming_forward(
+        self,
+        x: Tensor,
+        attn_weights: Tensor,
+        cached_val: Tensor,
+        left_context_len: int,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            x: input tensor, of shape (seq_len, batch_size, embed_dim)
+            attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
+              with seq_len being interpreted as (tgt_seq_len, src_seq_len).  Expect
+              attn_weights.sum(dim=-1) == 1.
+            cached_val: cached attention value tensor of left context,
+              of shape (left_context_len, batch_size, value_dim)
+            left_context_len: number of left context frames.
+
+        Returns:
+           - attention weighted output, a tensor with the same shape as x.
+           - updated cached attention value tensor of left context.
+        """
+        (seq_len, batch_size, embed_dim) = x.shape
+        num_heads = attn_weights.shape[0]
+        seq_len2 = seq_len + left_context_len
+        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)
+
+        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
+
+        # Pad cached left contexts
+        assert cached_val.shape[0] == left_context_len, (
+            cached_val.shape[0],
+            left_context_len,
+        )
+        x = torch.cat([cached_val, x], dim=0)
+        # Update cached left contexts
+        cached_val = x[-left_context_len:, ...]
+
+        x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+        # now x: (num_heads, batch_size, seq_len, value_head_dim)
+        value_head_dim = x.shape[-1]
+
+        # todo: see whether there is benefit in overriding matmul
+        x = torch.matmul(attn_weights, x)
+        # now x: (num_heads, batch_size, seq_len, value_head_dim)
+
+        x = (
+            x.permute(2, 1, 0, 3)
+            .contiguous()
+            .view(seq_len, batch_size, num_heads * value_head_dim)
+        )
+
+        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+        x = self.out_proj(x)
+
+        return x, cached_val
+
+
+class FeedforwardModule(nn.Module):
+    """Feedforward module in Zipformer2 model."""
+
+    def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
+        super(FeedforwardModule, self).__init__()
+        self.in_proj = nn.Linear(embed_dim, feedforward_dim)
+
+        self.hidden_balancer = Balancer(
+            feedforward_dim,
+            channel_dim=-1,
+            min_positive=0.3,
+            max_positive=1.0,
+            min_abs=0.75,
+            max_abs=5.0,
+        )
+
+        # dropout_shared_dim=0 means we share the dropout mask along the time axis
+        self.out_proj = ActivationDropoutAndLinear(
+            feedforward_dim,
+            embed_dim,
+            activation="SwooshL",
+            dropout_p=dropout,
+            dropout_shared_dim=0,
+            bias=True,
+            initial_scale=0.1,
+        )
+
+        self.out_whiten = Whiten(
+            num_groups=1,
+            whitening_limit=_whitening_schedule(7.5),
+            prob=(0.025, 0.25),
+            grad_scale=0.01,
+        )
+
+    def forward(self, x: Tensor):
+        x = self.in_proj(x)
+        x = self.hidden_balancer(x)
+        # out_proj contains SwooshL activation, then dropout, then linear.
+        x = self.out_proj(x)
+        x = self.out_whiten(x)
+        return x
+
+
+class NonlinAttention(nn.Module):
+    """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
+       from the attention module) in place of actual convolution.  We also took out the second nonlinearity, the
+       one after the attention mechanism.
+
+    Args:
+        channels (int): The number of channels of conv layers.
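+
+    A minimal usage sketch (hidden_channels must be divisible by the number of attention
+    heads; values chosen only for illustration)::
+
+        >>> nonlin_attn = NonlinAttention(channels=256, hidden_channels=192)
+        >>> x = torch.rand(10, 2, 256)
+        >>> attn_weights = torch.softmax(torch.rand(4, 2, 10, 10), dim=-1)
+        >>> out = nonlin_attn(x, attn_weights)          # same shape as x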
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        hidden_channels: int,
+    ) -> None:
+        super().__init__()
+
+        self.hidden_channels = hidden_channels
+
+        self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
+
+        # balancer that goes before the tanh.  We constrain the abs-value here
+        # because we noticed that well-trained instances of this module have abs-value before the
+        # nonlinearity starting from about 3, and poorly-trained instances of the module have
+        # smaller abs values before it.
+        self.balancer = Balancer(
+            hidden_channels,
+            channel_dim=-1,
+            min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
+            max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
+            min_abs=0.5,
+            max_abs=5.0,
+        )
+        self.tanh = nn.Tanh()
+
+        self.identity1 = Identity()  # for diagnostics.
+        self.identity2 = Identity()  # for diagnostics.
+        self.identity3 = Identity()  # for diagnostics.
+
+        self.out_proj = ScaledLinear(
+            hidden_channels, channels, bias=True, initial_scale=0.05
+        )
+
+        self.whiten1 = Whiten(
+            num_groups=1,
+            whitening_limit=_whitening_schedule(5.0),
+            prob=(0.025, 0.25),
+            grad_scale=0.01,
+        )
+
+        self.whiten2 = Whiten(
+            num_groups=1,
+            whitening_limit=_whitening_schedule(5.0, ratio=3.0),
+            prob=(0.025, 0.25),
+            grad_scale=0.01,
+        )
+
+    def forward(
+        self,
+        x: Tensor,
+        attn_weights: Tensor,
+    ) -> Tensor:
+        """.
+                Args:
+                   x: a Tensor of shape (seq_len, batch_size, num_channels)
+        attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
+                Returns:
+                   a Tensor with the same shape as x
+        """
+        x = self.in_proj(x)
+
+        (seq_len, batch_size, _) = x.shape
+        hidden_channels = self.hidden_channels
+
+        s, x, y = x.chunk(3, dim=2)
+
+        # s will go through tanh.
+
+        s = self.balancer(s)
+        s = self.tanh(s)
+
+        s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
+        x = self.whiten1(x)
+        x = x * s
+        x = self.identity1(x)  # diagnostics only, it's the identity.
+
+        (seq_len, batch_size, embed_dim) = x.shape
+        num_heads = attn_weights.shape[0]
+        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+
+        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+        # now x: (num_heads, batch_size, seq_len, head_dim)
+        x = torch.matmul(attn_weights, x)
+        # now x: (num_heads, batch_size, seq_len, head_dim)
+        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+
+        y = self.identity2(y)
+        x = x * y
+        x = self.identity3(x)
+
+        x = self.out_proj(x)
+        x = self.whiten2(x)
+        return x
+
+    def streaming_forward(
+        self,
+        x: Tensor,
+        attn_weights: Tensor,
+        cached_x: Tensor,
+        left_context_len: int,
+    ) -> Tuple[Tensor, Tensor]:
+        """.
+        Args:
+            x: a Tensor of shape (seq_len, batch_size, num_channels)
+            attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
+            cached_x: left context, a Tensor of shape
+              (num_heads, batch_size, left_context_len, head_dim)
+            left_context_len: number of left context frames.
+        Returns:
+            - a Tensor with the same shape as x
+            - updated left context with same shape as cached_x
+        """
+        x = self.in_proj(x)
+
+        (seq_len, batch_size, _) = x.shape
+        hidden_channels = self.hidden_channels
+
+        s, x, y = x.chunk(3, dim=2)
+
+        # s will go through tanh.
+        s = self.tanh(s)
+
+        s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
+        x = x * s
+
+        (seq_len, batch_size, embed_dim) = x.shape
+        num_heads = attn_weights.shape[0]
+        assert attn_weights.shape == (
+            num_heads,
+            batch_size,
+            seq_len,
+            left_context_len + seq_len,
+        )
+
+        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+        # now x: (num_heads, batch_size, seq_len, head_dim)
+
+        # Pad cached tensor
+        assert cached_x.shape[2] == left_context_len, (
+            cached_x.shape[2],
+            left_context_len,
+        )
+        x_pad = torch.cat([cached_x, x], dim=2)
+        # Update cached tensor
+        cached_x = x_pad[:, :, -left_context_len:, :]
+
+        x = torch.matmul(attn_weights, x_pad)
+        # now x: (num_heads, batch_size, seq_len, head_dim)
+        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+
+        x = x * y
+
+        x = self.out_proj(x)
+        return x, cached_x
+
+
+class ConvolutionModule(nn.Module):
+    """ConvolutionModule in Zipformer2 model.
+    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers (must be odd).
+        causal (bool): If True, use a chunk-causal depthwise convolution.
+
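+    A minimal usage sketch (non-causal; kernel_size must be odd; numbers are illustrative only)::
+
+        >>> conv = ConvolutionModule(channels=256, kernel_size=31, causal=False)
+        >>> x = torch.rand(100, 2, 256)                 # (time, batch, channels)
+        >>> out = conv(x)                               # same shape as x
+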
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int,
+        causal: bool,
+    ) -> None:
+        """Construct a ConvolutionModule object."""
+        super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        bottleneck_dim = channels
+        self.causal = causal
+
+        self.in_proj = nn.Linear(
+            channels,
+            2 * bottleneck_dim,
+        )
+        # the gradients on in_proj are a little noisy, likely to do with the
+        # sigmoid in glu.
+
+        # after in_proj we put x through a gated linear unit (nn.functional.glu).
+        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
+        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
+        # between 50 and 100 for different channels.  This will cause very peaky and
+        # sparse derivatives for the sigmoid gating function, which will tend to make
+        # the loss function not learn effectively.  (For most layers the average absolute values
+        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
+        # at the output of pointwise_conv1 is around 0.35 to 0.45 for different
+        # layers, which likely breaks down as 0.5 for the "linear" half and
+        # 0.2 to 0.3 for the part that goes into the sigmoid.)  The idea is that if we
+        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
+        # it will be in a better position to start learning something, i.e. to latch onto
+        # the correct range.
+        self.balancer1 = Balancer(
+            bottleneck_dim,
+            channel_dim=-1,
+            min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
+            max_positive=1.0,
+            min_abs=1.5,
+            max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
+        )
+
+        self.activation1 = Identity()  # for diagnostics
+
+        self.sigmoid = nn.Sigmoid()
+
+        self.activation2 = Identity()  # for diagnostics
+
+        assert kernel_size % 2 == 1
+
+        self.depthwise_conv = (
+            ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size)
+            if causal
+            else nn.Conv1d(
+                in_channels=bottleneck_dim,
+                out_channels=bottleneck_dim,
+                groups=bottleneck_dim,
+                kernel_size=kernel_size,
+                padding=kernel_size // 2,
+            )
+        )
+
+        self.balancer2 = Balancer(
+            bottleneck_dim,
+            channel_dim=1,
+            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
+            max_positive=1.0,
+            min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
+            max_abs=10.0,
+        )
+
+        self.whiten = Whiten(
+            num_groups=1,
+            whitening_limit=_whitening_schedule(7.5),
+            prob=(0.025, 0.25),
+            grad_scale=0.01,
+        )
+
+        self.out_proj = ActivationDropoutAndLinear(
+            bottleneck_dim,
+            channels,
+            activation="SwooshR",
+            dropout_p=0.0,
+            initial_scale=0.05,
+        )
+
+    def forward(
+        self,
+        x: Tensor,
+        src_key_padding_mask: Optional[Tensor] = None,
+        chunk_size: int = -1,
+    ) -> Tensor:
+        """Compute convolution module.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+           src_key_padding_mask: the mask for the src keys per batch (optional):
+               (batch, #time), contains True in masked positions.
+
+        Returns:
+            Tensor: Output tensor (#time, batch, channels).
+
+        """
+
+        x = self.in_proj(x)  # (time, batch, 2*channels)
+
+        x, s = x.chunk(2, dim=2)
+        s = self.balancer1(s)
+        s = self.sigmoid(s)
+        x = self.activation1(x)  # identity.
+        x = x * s
+        x = self.activation2(x)  # identity
+
+        # (time, batch, channels)
+
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        if src_key_padding_mask is not None:
+            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+        if (
+            not torch.jit.is_scripting()
+            and not torch.jit.is_tracing()
+            and chunk_size >= 0
+        ):
+            # Exporting a model for simulated streaming decoding is not supported
+            assert (
+                self.causal
+            ), "Must initialize model with causal=True if you use chunk_size"
+            x = self.depthwise_conv(x, chunk_size=chunk_size)
+        else:
+            x = self.depthwise_conv(x)
+
+        x = self.balancer2(x)
+        x = x.permute(2, 0, 1)  # (time, batch, channels)
+
+        x = self.whiten(x)  # (time, batch, channels)
+        x = self.out_proj(x)  # (time, batch, channels)
+
+        return x
+
+    def streaming_forward(
+        self,
+        x: Tensor,
+        cache: Tensor,
+        src_key_padding_mask: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """Compute convolution module in streaming forward mode.
+
+        Args:
+            x: Input tensor (#time, batch, channels).
+            cache: cached left context for depthwise_conv of shape
+              (#batch, channels, left_pad)
+            src_key_padding_mask: the mask for the src keys per batch (optional):
+              (batch, #time), contains True in masked positions.
+
+        Returns:
+            - Output tensor (#time, batch, channels).
+            - Updated cache (#batch, channels, left_pad)
+        """
+
+        x = self.in_proj(x)  # (time, batch, 2*channels)
+
+        x, s = x.chunk(2, dim=2)
+        s = self.sigmoid(s)
+        x = x * s
+        # (time, batch, channels)
+
+        # exchange the temporal dimension and the feature dimension
+        x = x.permute(1, 2, 0)  # (#batch, channels, time).
+
+        if src_key_padding_mask is not None:
+            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+
+        x, cache = self.depthwise_conv.streaming_forward(x, cache=cache)
+
+        x = x.permute(2, 0, 1)  # (time, batch, channels)
+
+        x = self.out_proj(x)  # (time, batch, channels)
+
+        return x, cache
+
+
+class ScalarMultiply(nn.Module):
+    def __init__(self, scale: float):
+        super().__init__()
+        self.scale = scale
+
+    def forward(self, x):
+        return x * self.scale
+
+
+def _test_zipformer_main(causal: bool = False):
+    batch_size = 5
+    seq_len = 20
+    # Just make sure the forward pass runs.
+
+    c = Zipformer2(
+        encoder_dim=(64, 96),
+        encoder_unmasked_dim=(48, 64),
+        num_heads=(4, 4),
+        causal=causal,
+        chunk_size=(4,) if causal else (-1,),
+        left_context_frames=(64,),
+    )
+    f = c(
+        torch.randn(seq_len, batch_size, 64),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+    )
+    f[0].sum().backward()
+    c.eval()
+    f = c(
+        torch.randn(seq_len, batch_size, 64),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+    )
+    f  # to remove flake8 warnings
+
+
+if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    _test_zipformer_main(False)
+    _test_zipformer_main(True)
\ No newline at end of file
-- 
Gitee


From de2dd27b8525405aae019b69ab1f6b4808596242 Mon Sep 17 00:00:00 2001
From: han_yifeng <hanyifeng2@huawei.com>
Date: Sat, 9 Mar 2024 18:08:34 +0800
Subject: [PATCH 2/2] =?UTF-8?q?Zipformer=20PT=E7=BC=96=E8=AF=91=E9=9D=9E?=
 =?UTF-8?q?=E6=B5=81=E5=BC=8F=E4=B8=8A=E7=BA=BF-part1?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/audio/Zipformer/README.md        |  218 +-
 .../audio/Zipformer/export-onnx.patch         |   94 +
 .../built-in/audio/Zipformer/export-onnx.py   |  625 -----
 ...ie_ts_enc.py => export_torch_aie_model.py} |   42 +-
 .../icefall_pt/export_torch_aie_ts_dec.py     |   62 -
 .../icefall_pt/export_torch_aie_ts_join.py    |   67 -
 .../Zipformer/icefall_pt/model_pt_dec.py      |    8 +-
 .../Zipformer/icefall_pt/model_pt_enc.py      |    9 +-
 .../Zipformer/icefall_pt/model_pt_join.py     |    8 +-
 .../audio/Zipformer/icefall_pt/pt_val_dec.py  |   96 +-
 .../audio/Zipformer/icefall_pt/pt_val_enc.py  |   90 +-
 .../audio/Zipformer/icefall_pt/pt_val_join.py |   94 +-
 .../audio/Zipformer/modify_decoder.py         |   25 -
 .../built-in/audio/Zipformer/zipformer.patch  |   59 +
 .../built-in/audio/Zipformer/zipformer.py     | 2438 -----------------
 15 files changed, 335 insertions(+), 3600 deletions(-)
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.patch
 delete mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py
 rename AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/{export_torch_aie_ts_enc.py => export_torch_aie_model.py} (59%)
 delete mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py
 delete mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py
 delete mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py
 create mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.patch
 delete mode 100644 AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py

diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md b/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md
index 7eee8c3e05..619ce41e49 100644
--- a/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/README.md
@@ -9,7 +9,6 @@
 
 - [Model inference performance and accuracy](#ZH-CN_TOPIC_0000001172201573)
 
-
 # Overview<a name="ZH-CN_TOPIC_0000001172161501"></a>
 
 (From the paper abstract) The Conformer has become the most popular encoder model for automatic speech recognition (ASR). It adds convolution modules to a transformer to learn both local and global dependencies.
@@ -17,7 +16,6 @@
 2) a reorganized block structure with more modules, in which we reuse attention weights for efficiency; 3) a modified form of LayerNorm called BiasNorm, which allows us to retain some length information; 4) new activation functions, SwooshR and SwooshL, which work better than Swish.
 We also propose a new optimizer, called ScaledAdam, which scales each update by the current scale of the tensor to keep the relative change roughly the same, and which also explicitly learns the parameter scale. It achieves faster convergence and better performance than Adam.
 Extensive experiments on the LibriSpeech, Aishell-1, and WenetSpeech datasets demonstrate the effectiveness of our proposed Zipformer over other state-of-the-art ASR models.
-  
 
 # Inference environment setup \[all versions\]<a name="ZH-CN_TOPIC_0000001126281702"></a>
 
@@ -25,14 +23,12 @@
 
   **表 1**  版本配套表
 
-| Component             | Version     | 
-|-----------------------|-------------| 
-| CANN                  | 7.0.RC1     |
-| Python                | 3.9.11      |
-| torch                 | 2.0.1       |
-| Ascend-cann-torch-aie | -           |
-| Ascend-cann-aie       | -           |
-| Chip type             | Ascend310P3 |
+| Component | Version     |
+|-----------|-------------|
+| CANN      | 8.0.T5      |
+| Python    | 3.10.13     |
+| torch     | 2.1.0       |
+| Chip type | Ascend310P3 |
 
 # Quick start<a name="ZH-CN_TOPIC_0000001126281700"></a>
 
@@ -51,8 +47,9 @@
     export K2_MAKE_ARGS="-j6"
     python3 setup.py install
     ```
-    If you encounter errors when running the above commands, refer to [this link](https://k2-fsa.github.io/k2/installation/from_source.html).  
-    3. (GPU) x86 environment: download the whl file matching your CUDA version from [this link](https://k2-fsa.github.io/k2/cuda.html), then install it with pip.  
+    * **If the build fails, delete the build folder before trying to compile again.**
+    * If you encounter errors when running the above commands, refer to [this link](https://k2-fsa.github.io/k2/installation/from_source.html).
+    3. (GPU) x86 environment: download the whl file matching your CUDA version from [this link](https://k2-fsa.github.io/k2/cuda.html), then install it with pip.
     4. Verify that k2 was installed successfully  
     ```shell
     python3 -m k2.version
@@ -61,13 +58,21 @@
     ```shell
     pip install lhotse
     pip install kaldifeat
+    
+    apt install libsndfile1
+    ```
+   * kaldifeat若安装失败,请执行以下命令使用源码安装:
+    ```shell
+    git clone https://github.com/csukuangfj/kaldifeat.git
+    cd kaldifeat
+    python3 setup.py install
     ```
 3. 安装icefall
     ```shell
     git clone https://github.com/k2-fsa/icefall.git
-    git reset --hard e2fcb42f5f176d9e39eb38506ab99d0a3adaf202
    
     cd icefall
+    git reset --hard e2fcb42f5f176d9e39eb38506ab99d0a3adaf202
     pip install -r requirements.txt
     ```
 4. 将icefall加入环境变量, "/path/to/icefall"替换为icefall文件夹所在的路径。
@@ -75,41 +80,42 @@
     ```shell
     export PYTHONPATH=/path/to/icefall:$PYTHONPATH
     ```
-
 ## 模型下载
-1. 安装 git lfs
-    ```shell
-    curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
-
-    sudo apt-get install git-lfs
-    git lfs install --skip-repo
-    ```
-2. 下载模型
-    ```shell
-   git clone https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
-    ```
-   若下载失败,请尝试从以上链接手动下载文件。模型转换和推理时只需要用到以下文件:
-    - data/lang_char/tokens.txt
+从[此链接](https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615)下载模型相关文件。
+模型转换和推理时只需要用到以下文件:  
+    - data/lang_char/tokens.txt  
     - exp/epoch-12.pt
 
+下载完后,整理成如下目录结构:
+```shell
+icefall-asr-zipformer-wenetspeech-20230615
+├── data
+│   └── lang_char
+│       └── tokens.txt
+└── exp
+    └── epoch-12.pt
+```
 ## 模型推理
-1. 替换代码
-   使用本目录下的zipformer.py替换egs/librispeech/ASR/zipformer/zipformer.py
-   export-onnx.py替换egs/librispeech/ASR/zipformer/export-onnx.py
-
-2. 将本代码仓的icefall_pt目录拷贝到工程的根目录./icefall下。
+1. 打代码补丁(首先将export-onnx.patch与zipformer.patch拷贝到icefall/egs/librispeech/ASR/zipformer目录下)
+    ```shell
+    cd icefall/egs/librispeech/ASR/zipformer/
+    
+    patch < export-onnx.patch
+    patch < zipformer.patch
+    ```
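+    打补丁前可先用GNU patch的--dry-run选项做一次预检(仅校验补丁能否正确应用,不会修改文件):
+    ```shell
+    patch --dry-run < export-onnx.patch
+    patch --dry-run < zipformer.patch
+    ```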
+2. 将本代码仓的icefall_pt目录拷贝到icefall工程根目录下。
 
-3. 导出onnx模型,用于精度测试。
+3. 导出onnx模型与torchscript模型。
     ```shell
-   cd icefall/egs/librispeech/ASR/zipformer
-    # 注意将"icefall-asr-zipformer-streaming-wenetspeech-20230615"修改为实际路径
-   python ./export-onnx.py \
-   # 暂时没有使用以下参数
-     --tokens icefall-asr-zipformer-streaming-wenetspeech-20230615/data/lang_char/tokens.txt \
+   cd icefall/
+   repo=/path/to/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615
+    # 注意将"icefall-asr-zipformer-wenetspeech-20230615"修改为实际路径
+   python ./egs/librispeech/ASR/zipformer/export-onnx.py \
+     --tokens $repo/data/lang_char/tokens.txt \
      --use-averaged-model 0 \
      --epoch 12 \
      --avg 1 \
-     --exp-dir icefall-asr-zipformer-streaming-wenetspeech-20230615/exp \
+     --exp-dir $repo/exp \
      --num-encoder-layers "2,2,3,4,3,2" \
      --downsampling-factor "1,2,4,8,4,2" \
      --feedforward-dim "512,768,1024,1536,1024,768" \
@@ -123,99 +129,87 @@
      --cnn-module-kernel "31,31,15,15,15,31" \
      --decoder-dim 512 \
      --joiner-dim 512 \
-     --causal True \
-     --chunk-size 16 \
-     --left-context-frames 128
+     --causal False \
+     --chunk-size "16,32,64,-1" \
+     --left-context-frames "64,128,256,-1"
     ```
-   执行结束后,会在“icefall-asr-zipformer-streaming-wenetspeech-20230615/exp”目录下生成三个onnx文件与三个ts文件:
+   执行结束后,会在“icefall-asr-zipformer-wenetspeech-20230615/exp”目录下生成三个onnx文件与三个torchscript模型文件:
     - encoder-epoch-12-avg-1.onnx
     - decoder-epoch-12-avg-1.onnx
     - joiner-epoch-12-avg-1.onnx
     - encoder-epoch-12-avg-1.pt
     - decoder-epoch-12-avg-1.pt
     - joiner-epoch-12-avg-1.pt
-4. 对torchscript模型进行编译。
+4. 对torchscript模型使用PT插件进行编译。
    ```shell
    cd icefall/icefall_pt
-   
-   # 注意将"icefall-asr-zipformer-streaming-wenetspeech-20230615"修改为实际路径
-   python ./export_torch_aie_ts_enc.py
-   python ./export_torch_aie_ts_dec.py
-   python ./export_torch_aie_ts_join.py
+   python export_torch_aie_model.py
+    ```
+    参数说明:
+   ```shell
+   --torch_script_path:torchscript模型路径
+   --export_part:选择编译部分(encoder,decoder或joiner)
+   --soc_version:硬件版本
+   --batch_size:batch大小
+   --save_path:编译后的模型保存路径
     ```
-    执行结束后,会在“icefall_pt/pt_compiled_model”目录下生成三个编译好的torchscript文件:
-    - encoder-epoch-12-avg-1_torch_aie_bs1.pt
-    - decoder-epoch-12-avg-1_torch_aie_bs1.pt
-    - joiner-epoch-12-avg-1_torch_aie_bs1.pt  
+    执行结束后,会在save_path对应目录下生成三个编译好的torchscript文件:
+    - encoder-epoch-12-avg-1_mindietorch_bs1.pt
+    - decoder-epoch-12-avg-1_mindietorch_bs1.pt
+    - joiner-epoch-12-avg-1_mindietorch_bs1.pt  
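+
+    编译示例(以下路径与参数值仅为示意,请按实际情况替换,三个子模型需分别编译):
+   ```shell
+   python export_torch_aie_model.py --torch_script_path $repo/exp/encoder-epoch-12-avg-1.pt --export_part encoder --soc_version Ascend310P3 --batch_size 1 --save_path ./pt_compiled_model/
+   python export_torch_aie_model.py --torch_script_path $repo/exp/decoder-epoch-12-avg-1.pt --export_part decoder --soc_version Ascend310P3 --batch_size 1 --save_path ./pt_compiled_model/
+   python export_torch_aie_model.py --torch_script_path $repo/exp/joiner-epoch-12-avg-1.pt --export_part joiner --soc_version Ascend310P3 --batch_size 1 --save_path ./pt_compiled_model/
+   ```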
 
 5. 运行推理样例
-   1. 下载样例语音数据(测试暂时使用全0输入)
-      ```shell
-      cd icefall/egs/librispeech/ASR/zipformer
-      wget https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
-      ```
-   2. 执行推理
-      ```shell
-      # 注意将"icefall-asr-zipformer-streaming-wenetspeech-20230615"修改为实际路径
-      cd icefall/icefall_pt
-      python ./pt_val_enc.py
-      python ./pt_val_dec.py
-      python ./pt_val_join.py
-      ```
-      执行结束后,会在icefall_pt/result下看到三个模型的推理结果:
-      同时输出PT执行的性能
-
-      
-6. 精度测试
    ```shell
-   cd icefall/egs/librispeech/ASR/zipformer
-   将pretrained.py中的encoder,decoder与joiner的输入改为全全0:enc(1,100,80)+常量100,dec(1,2),joiner(1,512)+(1,512),并保存三个输出
-   python ./pretrained.py
-   三个输出分别比较余弦相似度
+   cd icefall/icefall_pt
+   python pt_val_enc.py
+   python pt_val_dec.py
+   python pt_val_join.py
    ```
-    执行结束后,结果为:
+    参数说明:
+   ```shell
+   --model:PT模型路径
+   --need_compile:是否需要编译。若model指向torchscript模型,则需要进行编译后运行
+   --soc_version:硬件版本
+   --batch_size:batch大小
+   --result_path:模型运行结果保存路径
+   --device_id:硬件编号
+   --multi:数据加倍的倍数(默认生成20条数据,multi为1。若multi为n,则数据实际为n*20条。注意,若n不为1,则不生成result文件)
+    ```
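+    运行示例(参数值仅为示意,模型路径请替换为第4步实际生成的文件):
+   ```shell
+   python pt_val_enc.py --model ./pt_compiled_model/encoder-epoch-12-avg-1_mindietorch_bs1.pt --batch_size 1 --device_id 0 --result_path result/encoder
+   python pt_val_dec.py --model ./pt_compiled_model/decoder-epoch-12-avg-1_mindietorch_bs1.pt --batch_size 1 --device_id 0 --result_path result/decoder
+   python pt_val_join.py --model ./pt_compiled_model/joiner-epoch-12-avg-1_mindietorch_bs1.pt --batch_size 1 --device_id 0 --result_path result/joiner
+   ```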
+6. 精度测试(将onnx_test下的四个测试脚本拷贝到icefall/egs/librispeech/ASR/zipformer)
+   ```shell
+   cd icefall/egs/librispeech/ASR
+   python ./zipformer/onnx_test_enc.py --encoder-model-filename $repo/exp/encoder-epoch-12-avg-1.onnx --tokens $repo/data/lang_char/tokens.txt
+   python ./zipformer/onnx_test_dec.py --decoder-model-filename $repo/exp/decoder-epoch-12-avg-1.onnx --tokens $repo/data/lang_char/tokens.txt
+   python ./zipformer/onnx_test_join.py --joiner-model-filename $repo/exp/joiner-epoch-12-avg-1.onnx --tokens $repo/data/lang_char/tokens.txt
+   ```
+    执行结束后,将得到的结果路径与步骤5中得到的结果路径分别配置到脚本cosine_similarity_test.py中,并运行查看相似度:
+    ```shell
+   python ./zipformer/cosine_similarity_test.py
+    ```
+   可以得出与T4及A10对比的精度结果:
     ```shell
-   encoder:0.97+
-   decoder:0.99+
-   joiner:0.99+
+   Cosine Similarity PT_T4_ENC :0.9999+
+   Cosine Similarity PT_T4_DEC :1
+   Cosine Similarity PT_T4_JOIN :1
+   Cosine Similarity PT_A10_ENC :0.9999+
+   Cosine Similarity PT_A10_DEC :1
+   Cosine Similarity PT_A10_JOIN :1
     ```
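+
+    余弦相似度的计算思路大致如下(示意代码,并非仓库中cosine_similarity_test.py的实际实现;假设两侧结果均保存为以空白分隔的浮点数文本):
+    ```python
+    import numpy as np
+
+    def cosine_similarity(file_a, file_b):
+        # 读取两份推理结果并展平为一维向量
+        a = np.loadtxt(file_a).reshape(-1)
+        b = np.loadtxt(file_b).reshape(-1)
+        # 余弦相似度 = 内积 / 范数乘积
+        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
+
+    # 两个路径仅为示例,请替换为第5步与第6步实际生成的结果文件
+    print(cosine_similarity("result/encoder/data0_0.txt", "onnx_result/encoder/data0_0.txt"))
+    ```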
 7. 性能测试
-   1. aie模型性能测试
-      ```shell
-      执行pt_val_enc/dec/join.py时会打印
-      ```
+   1. pt模型性能测试
+      第5步运行pt_val脚本时会打印PT模型性能
+
    2. onnx模型性能测试。  
-      1. (可选)若使用GPU,请确保已安装CUDA和pytorch-gpu版本,同时需安装onnxruntime-gpu,如下所示:
-      ```shell
-      
-      ```
-      执行结束后,三个模型的性能信息会打印在命令行,如下所示:
-      ```shell
-      Encoder latency: 964.9530 ms
-      Encoder throughput: 1.0363 fps
-      Decoder latency: 0.4806 ms
-      Decoder throughput: 2080.7143 fps
-      Joiner latency: 0.4994 ms
-      Joiner throughput: 2002.3092 fps
-      ```
-   3. om性能测试:
-      在原本onnx模型的基础上,encoder模型进行onnxsim
-      ```shell
-      onnxsim encoder-epoch-12-avg-1.onnx encoder-epoch-12-avg-1_sim.onnx
-      onnxsim decoder-epoch-12-avg-1.onnx decoder-epoch-12-avg-1_sim.onnx
-      python modify_decoder.py
-      python3 -m ais_bench --model encoder_linux_aarch64.om --loop=2000
-      python3 -m ais_bench --model decoder.om --loop=2000
-      python3 -m ais_bench --model joiner.om --loop=2000
-      ```
+      第6步运行onnx_test脚本时会打印onnx模型性能
+
 # 模型推理性能精度<a name="ZH-CN_TOPIC_0000001172201573"></a>
 
 Zipformer流式模型由三个子模型组成,分别是encoder、decoder和joiner,其性能如下表所示:
 
-| 模型      | pt插件 - 310P性能(时延/吞吐率) | T4性能(时延/吞吐率)       | A10性能(时延/吞吐率)      |
-|---------|-----------------------|--------------------|--------------------|
-| encoder | 20.4 ms / 49 fps      | 24.7 ms / 40 fps   | 19 ms / 52 fps     |
-| decoder | 0.19 ms / 5156 fps    | 0.59 ms / 1684 fps | 0.13 ms / 7604 fps |
-| joiner  | 0.22 ms / 4448 fps    | 0.13 ms / 7645 fps | 0.11 ms / 9224 fps |
-| 端到端     | 20.81 ms / 48 fps     | 25.42 ms /  39 fps | 19.24 ms / 52 fps  |
+| 模型      | pt插件 - 310P性能(时延/吞吐率)     | T4 onnx性能(时延/吞吐率)         | A10 onnx性能(时延/吞吐率)        |
+|---------|---------------------------|---------------------------|---------------------------|
+| encoder | 43.1391 ms / 23.1807 fps  | 25.6406 ms / 39.0005 fps  | 16.2751 ms / 61.4434 fps  |
+| decoder | 0.4387 ms / 2279.2528 fps | 0.5691 ms / 1757.0740 fps | 0.1219 ms / 8200.5706 fps |
+| joiner  | 0.4586 ms / 2180.0839 fps | 0.1526 ms / 6551.7825 fps | 0.1107 ms / 9026.4239 fps |
+| 端到端     | 44.0364 ms / 22.7084 fps  | 26.3623 ms / 37.9329 fps  | 16.5077 ms / 60.5777 fps  |
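+
+注:表中为batch_size=1的测试结果,吞吐率按 1000 / 平均时延(ms) × batch_size 折算。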
 
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.patch b/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.patch
new file mode 100644
index 0000000000..c9580a6096
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.patch
@@ -0,0 +1,94 @@
+--- export-onnx.py	2024-03-08 15:33:12.340000000 +0800
++++ export-onnx.py	2024-03-08 15:33:12.340000000 +0800
+@@ -27,10 +27,10 @@
+ 
+ 2. Export the model to ONNX
+ 
+-./zipformer/export-onnx.py \
+-  --tokens $repo/data/lang_bpe_500/tokens.txt \
++python3 ./export-onnx.py \
++  --tokens $repo/data/lang_char/tokens.txt \
+   --use-averaged-model 0 \
+-  --epoch 99 \
++  --epoch 12 \
+   --avg 1 \
+   --exp-dir $repo/exp \
+   --num-encoder-layers "2,2,3,4,3,2" \
+@@ -92,7 +92,7 @@
+     parser.add_argument(
+         "--epoch",
+         type=int,
+-        default=28,
++        default=12,
+         help="""It specifies the checkpoint to use for averaging.
+         Note: Epoch counts from 0.
+         You can specify --avg to use more checkpoints for model averaging.""",
+@@ -111,7 +111,7 @@
+     parser.add_argument(
+         "--avg",
+         type=int,
+-        default=15,
++        default=1,
+         help="Number of checkpoints to average. Automatically select "
+         "consecutive checkpoints before the checkpoint specified by "
+         "'--epoch' and '--iter'",
+@@ -120,7 +120,7 @@
+     parser.add_argument(
+         "--use-averaged-model",
+         type=str2bool,
+-        default=True,
++        default=False,
+         help="Whether to load averaged model. Currently it only supports "
+         "using --epoch. If True, it would decode with the averaged model "
+         "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+@@ -131,7 +131,7 @@
+     parser.add_argument(
+         "--exp-dir",
+         type=str,
+-        default="zipformer/exp",
++        default="/home/devkit/hanyifeng/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/",
+         help="""It specifies the directory where all training related
+         files, e.g., checkpoints, log, etc, are saved
+         """,
+@@ -140,7 +140,7 @@
+     parser.add_argument(
+         "--tokens",
+         type=str,
+-        default="data/lang_bpe_500/tokens.txt",
++        default="/home/devkit/hanyifeng/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/data/lang_char/tokens.txt",
+         help="Path to the tokens.txt",
+     )
+ 
+@@ -298,7 +298,7 @@
+     x_lens = torch.tensor([100], dtype=torch.int64)
+ 
+     encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
+-
++    encoder_model.save(str(encoder_filename).replace("onnx", "pt"))
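++    # Save the traced TorchScript model alongside the ONNX export; it is compiled later with mindietorch (the PT plugin).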
+     torch.onnx.export(
+         encoder_model,
+         (x, x_lens),
+@@ -352,7 +352,9 @@
+     context_size = decoder_model.decoder.context_size
+     vocab_size = decoder_model.decoder.vocab_size
+ 
+-    y = torch.zeros(10, context_size, dtype=torch.int64)
++    y = torch.zeros(1, context_size, dtype=torch.int64)
++    ts_decoder_model = torch.jit.trace(decoder_model, y)
++    ts_decoder_model.save(str(decoder_filename).replace("onnx", "pt"))
+     decoder_model = torch.jit.script(decoder_model)
+     torch.onnx.export(
+         decoder_model,
+@@ -393,8 +395,10 @@
+     joiner_dim = joiner_model.output_linear.weight.shape[1]
+     logging.info(f"joiner dim: {joiner_dim}")
+ 
+-    projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
+-    projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
++    projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
++    projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
++    ts_joiner_model = torch.jit.trace(joiner_model, (projected_encoder_out, projected_decoder_out))
++    ts_joiner_model.save(str(joiner_filename).replace("onnx", "pt"))
+ 
+     torch.onnx.export(
+         joiner_model,
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py
deleted file mode 100644
index 805f7405c8..0000000000
--- a/AscendIE/TorchAIE/built-in/audio/Zipformer/export-onnx.py
+++ /dev/null
@@ -1,625 +0,0 @@
-#!/usr/bin/env python3
-#
-# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
-# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
-
-"""
-This script exports a transducer model from PyTorch to ONNX.
-
-We use the pre-trained model from
-https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
-as an example to show how to use this file.
-
-1. Download the pre-trained model
-
-cd egs/librispeech/ASR
-
-repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
-GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
-repo=$(basename $repo_url)
-
-pushd $repo
-git lfs pull --include "exp/pretrained.pt"
-
-cd exp
-ln -s pretrained.pt epoch-99.pt
-popd
-
-2. Export the model to ONNX
-
-python3 ./egs/librispeech/ASR/zipformer/export-onnx.py \
-  --tokens $repo/data/lang_bpe_500/tokens.txt \
-  --use-averaged-model 0 \
-  --epoch 99 \
-  --avg 1 \
-  --exp-dir $repo/exp \
-  --num-encoder-layers "2,2,3,4,3,2" \
-  --downsampling-factor "1,2,4,8,4,2" \
-  --feedforward-dim "512,768,1024,1536,1024,768" \
-  --num-heads "4,4,4,8,4,4" \
-  --encoder-dim "192,256,384,512,384,256" \
-  --query-head-dim 32 \
-  --value-head-dim 12 \
-  --pos-head-dim 4 \
-  --pos-dim 48 \
-  --encoder-unmasked-dim "192,192,256,256,256,192" \
-  --cnn-module-kernel "31,31,15,15,15,31" \
-  --decoder-dim 512 \
-  --joiner-dim 512 \
-  --causal False \
-  --chunk-size "16,32,64,-1" \
-  --left-context-frames "64,128,256,-1"
-
-It will generate the following 3 files inside $repo/exp:
-
-  - encoder-epoch-99-avg-1.onnx
-  - decoder-epoch-99-avg-1.onnx
-  - joiner-epoch-99-avg-1.onnx
-
-See ./onnx_pretrained.py and ./onnx_check.py for how to
-use the exported ONNX models.
-"""
-
-import argparse
-import logging
-from pathlib import Path
-from typing import Dict, Tuple
-
-import k2
-import onnx
-import torch
-import torch.nn as nn
-from decoder import Decoder
-from onnxruntime.quantization import QuantType, quantize_dynamic
-from scaling_converter import convert_scaled_to_non_scaled
-from train import add_model_arguments, get_model, get_params
-from zipformer import Zipformer2
-
-from icefall.checkpoint import (
-    average_checkpoints,
-    average_checkpoints_with_averaged_model,
-    find_checkpoints,
-    load_checkpoint,
-)
-from icefall.utils import make_pad_mask, num_tokens, str2bool
-
-
-def get_parser():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-
-    parser.add_argument(
-        "--epoch",
-        type=int,
-        default=12,
-        help="""It specifies the checkpoint to use for averaging.
-        Note: Epoch counts from 0.
-        You can specify --avg to use more checkpoints for model averaging.""",
-    )
-
-    parser.add_argument(
-        "--iter",
-        type=int,
-        default=0,
-        help="""If positive, --epoch is ignored and it
-        will use the checkpoint exp_dir/checkpoint-iter.pt.
-        You can specify --avg to use more checkpoints for model averaging.
-        """,
-    )
-
-    parser.add_argument(
-        "--avg",
-        type=int,
-        default=1,
-        help="Number of checkpoints to average. Automatically select "
-             "consecutive checkpoints before the checkpoint specified by "
-             "'--epoch' and '--iter'",
-    )
-
-    parser.add_argument(
-        "--use-averaged-model",
-        type=str2bool,
-        default=False,
-        help="Whether to load averaged model. Currently it only supports "
-             "using --epoch. If True, it would decode with the averaged model "
-             "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-             "Actually only the models with epoch number of `epoch-avg` and "
-             "`epoch` are loaded for averaging. ",
-    )
-
-    parser.add_argument(
-        "--exp-dir",
-        type=str,
-        default="/home/devkit/hanyifeng/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/",
-        help="""It specifies the directory where all training related
-        files, e.g., checkpoints, log, etc, are saved
-        """,
-    )
-
-    parser.add_argument(
-        "--tokens",
-        type=str,
-        default="/home/devkit/hanyifeng/icefall/egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/data/lang_char/tokens.txt",
-        help="Path to the tokens.txt",
-    )
-
-    parser.add_argument(
-        "--context-size",
-        type=int,
-        default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
-    )
-
-    add_model_arguments(parser)
-
-    return parser
-
-
-def add_meta_data(filename: str, meta_data: Dict[str, str]):
-    """Add meta data to an ONNX model. It is changed in-place.
-
-    Args:
-      filename:
-        Filename of the ONNX model to be changed.
-      meta_data:
-        Key-value pairs.
-    """
-    model = onnx.load(filename)
-    for key, value in meta_data.items():
-        meta = model.metadata_props.add()
-        meta.key = key
-        meta.value = value
-
-    onnx.save(model, filename)
-
-
-class OnnxEncoder(nn.Module):
-    """A wrapper for Zipformer and the encoder_proj from the joiner"""
-
-    def __init__(
-            self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
-    ):
-        """
-        Args:
-          encoder:
-            A Zipformer encoder.
-          encoder_proj:
-            The projection layer for encoder from the joiner.
-        """
-        super().__init__()
-        self.encoder = encoder
-        self.encoder_embed = encoder_embed
-        self.encoder_proj = encoder_proj
-
-    def forward(
-            self,
-            x: torch.Tensor,
-            x_lens: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Please see the help information of Zipformer.forward
-
-        Args:
-          x:
-            A 3-D tensor of shape (N, T, C)
-          x_lens:
-            A 1-D tensor of shape (N,). Its dtype is torch.int64
-        Returns:
-          Return a tuple containing:
-            - encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
-            - encoder_out_lens, A 1-D tensor of shape (N,)
-        """
-        x, x_lens = self.encoder_embed(x, x_lens)
-        src_key_padding_mask = make_pad_mask(x_lens)
-        x = x.permute(1, 0, 2)
-        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
-        encoder_out = encoder_out.permute(1, 0, 2)
-        encoder_out = self.encoder_proj(encoder_out)
-        # Now encoder_out is of shape (N, T, joiner_dim)
-
-        return encoder_out, encoder_out_lens
-
-
-class OnnxDecoder(nn.Module):
-    """A wrapper for Decoder and the decoder_proj from the joiner"""
-
-    def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
-        super().__init__()
-        self.decoder = decoder
-        self.decoder_proj = decoder_proj
-
-    def forward(self, y: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-          y:
-            A 2-D tensor of shape (N, context_size).
-        Returns
-          Return a 2-D tensor of shape (N, joiner_dim)
-        """
-        need_pad = False
-        decoder_output = self.decoder(y, need_pad=need_pad)
-        decoder_output = decoder_output.squeeze(1)
-        output = self.decoder_proj(decoder_output)
-
-        return output
-
-
-class OnnxJoiner(nn.Module):
-    """A wrapper for the joiner"""
-
-    def __init__(self, output_linear: nn.Linear):
-        super().__init__()
-        self.output_linear = output_linear
-
-    def forward(
-            self,
-            encoder_out: torch.Tensor,
-            decoder_out: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Args:
-          encoder_out:
-            A 2-D tensor of shape (N, joiner_dim)
-          decoder_out:
-            A 2-D tensor of shape (N, joiner_dim)
-        Returns:
-          Return a 2-D tensor of shape (N, vocab_size)
-        """
-        logit = encoder_out + decoder_out
-        logit = self.output_linear(torch.tanh(logit))
-        return logit
-
-
-def export_encoder_model_onnx(
-        encoder_model: OnnxEncoder,
-        encoder_filename: str,
-        opset_version: int = 11,
-) -> None:
-    """Export the given encoder model to ONNX format.
-    The exported model has two inputs:
-
-        - x, a tensor of shape (N, T, C); dtype is torch.float32
-        - x_lens, a tensor of shape (N,); dtype is torch.int64
-
-    and it has two outputs:
-
-        - encoder_out, a tensor of shape (N, T', joiner_dim)
-        - encoder_out_lens, a tensor of shape (N,)
-
-    Args:
-      encoder_model:
-        The input encoder model
-      encoder_filename:
-        The filename to save the exported ONNX model.
-      opset_version:
-        The opset version to use.
-    """
-    x = torch.zeros(1, 100, 80, dtype=torch.float32)
-    x_lens = torch.tensor([100], dtype=torch.int64)
-
-    encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
-    encoder_model.save(str(encoder_filename).replace("onnx", "pt"))
-    torch.onnx.export(
-        encoder_model,
-        (x, x_lens),
-        encoder_filename,
-        verbose=False,
-        opset_version=opset_version,
-        input_names=["x", "x_lens"],
-        output_names=["encoder_out", "encoder_out_lens"],
-        dynamic_axes={
-            "x": {0: "N", 1: "T"},
-            "x_lens": {0: "N"},
-            "encoder_out": {0: "N", 1: "T"},
-            "encoder_out_lens": {0: "N"},
-        },
-    )
-
-    meta_data = {
-        "model_type": "zipformer2",
-        "version": "1",
-        "model_author": "k2-fsa",
-        "comment": "non-streaming zipformer2",
-    }
-    logging.info(f"meta_data: {meta_data}")
-
-    add_meta_data(filename=encoder_filename, meta_data=meta_data)
-
-
-def export_decoder_model_onnx(
-        decoder_model: OnnxDecoder,
-        decoder_filename: str,
-        opset_version: int = 11,
-) -> None:
-    """Export the decoder model to ONNX format.
-
-    The exported model has one input:
-
-        - y: a torch.int64 tensor of shape (N, decoder_model.context_size)
-
-    and has one output:
-
-        - decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
-
-    Args:
-      decoder_model:
-        The decoder model to be exported.
-      decoder_filename:
-        Filename to save the exported ONNX model.
-      opset_version:
-        The opset version to use.
-    """
-    context_size = decoder_model.decoder.context_size
-    vocab_size = decoder_model.decoder.vocab_size
-
-    y = torch.zeros(10, context_size, dtype=torch.int64)
-    ts_decoder_model = torch.jit.trace(decoder_model, y)
-    ts_decoder_model.save(str(decoder_filename).replace("onnx", "pt"))
-    decoder_model = torch.jit.script(decoder_model)
-
-    torch.onnx.export(
-        decoder_model,
-        y,
-        decoder_filename,
-        verbose=False,
-        opset_version=opset_version,
-        input_names=["y"],
-        output_names=["decoder_out"],
-        dynamic_axes={
-            "y": {0: "N"},
-            "decoder_out": {0: "N"},
-        },
-    )
-
-    meta_data = {
-        "context_size": str(context_size),
-        "vocab_size": str(vocab_size),
-    }
-    add_meta_data(filename=decoder_filename, meta_data=meta_data)
-
-
-def export_joiner_model_onnx(
-        joiner_model: nn.Module,
-        joiner_filename: str,
-        opset_version: int = 11,
-) -> None:
-    """Export the joiner model to ONNX format.
-    The exported joiner model has two inputs:
-
-        - encoder_out: a tensor of shape (N, joiner_dim)
-        - decoder_out: a tensor of shape (N, joiner_dim)
-
-    and produces one output:
-
-        - logit: a tensor of shape (N, vocab_size)
-    """
-    joiner_dim = joiner_model.output_linear.weight.shape[1]
-    logging.info(f"joiner dim: {joiner_dim}")
-
-    projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
-    projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
-    ts_joiner_model = torch.jit.trace(joiner_model, (projected_encoder_out, projected_decoder_out))
-    ts_joiner_model.save(str(joiner_filename).replace("onnx", "pt"))
-
-    torch.onnx.export(
-        joiner_model,
-        (projected_encoder_out, projected_decoder_out),
-        joiner_filename,
-        verbose=False,
-        opset_version=opset_version,
-        input_names=[
-            "encoder_out",
-            "decoder_out",
-        ],
-        output_names=["logit"],
-        dynamic_axes={
-            "encoder_out": {0: "N"},
-            "decoder_out": {0: "N"},
-            "logit": {0: "N"},
-        },
-    )
-    meta_data = {
-        "joiner_dim": str(joiner_dim),
-    }
-    add_meta_data(filename=joiner_filename, meta_data=meta_data)
-
-
-@torch.no_grad()
-def main():
-    args = get_parser().parse_args()
-    args.exp_dir = Path(args.exp_dir)
-
-    params = get_params()
-    params.update(vars(args))
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
-
-    logging.info(f"device: {device}")
-
-    token_table = k2.SymbolTable.from_file(params.tokens)
-    params.blank_id = token_table["<blk>"]
-    params.vocab_size = num_tokens(token_table) + 1
-
-    logging.info(params)
-
-    logging.info("About to create model")
-    model = get_model(params)
-
-    model.to(device)
-
-    if not params.use_averaged_model:
-        if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                        : params.avg
-                        ]
-            if len(filenames) == 0:
-                raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            elif len(filenames) < params.avg:
-                raise ValueError(
-                    f"Not enough checkpoints ({len(filenames)}) found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            logging.info(f"averaging {filenames}")
-            model.to(device)
-            model.load_state_dict(average_checkpoints(filenames, device=device))
-        elif params.avg == 1:
-            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-        else:
-            start = params.epoch - params.avg + 1
-            filenames = []
-            for i in range(start, params.epoch + 1):
-                if i >= 1:
-                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-            logging.info(f"averaging {filenames}")
-            model.to(device)
-            model.load_state_dict(average_checkpoints(filenames, device=device))
-    else:
-        if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                        : params.avg + 1
-                        ]
-            if len(filenames) == 0:
-                raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            elif len(filenames) < params.avg + 1:
-                raise ValueError(
-                    f"Not enough checkpoints ({len(filenames)}) found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            filename_start = filenames[-1]
-            filename_end = filenames[0]
-            logging.info(
-                "Calculating the averaged model over iteration checkpoints"
-                f" from {filename_start} (excluded) to {filename_end}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
-                )
-            )
-        else:
-            assert params.avg > 0, params.avg
-            start = params.epoch - params.avg
-            assert start >= 1, start
-            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
-            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
-            logging.info(
-                f"Calculating the averaged model over epoch range from "
-                f"{start} (excluded) to {params.epoch}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
-                )
-            )
-
-    model.to("cpu")
-    model.eval()
-
-    convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
-
-    encoder = OnnxEncoder(
-        encoder=model.encoder,
-        encoder_embed=model.encoder_embed,
-        encoder_proj=model.joiner.encoder_proj,
-    )
-
-    decoder = OnnxDecoder(
-        decoder=model.decoder,
-        decoder_proj=model.joiner.decoder_proj,
-    )
-
-    joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
-
-    encoder_num_param = sum([p.numel() for p in encoder.parameters()])
-    decoder_num_param = sum([p.numel() for p in decoder.parameters()])
-    joiner_num_param = sum([p.numel() for p in joiner.parameters()])
-    total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
-    logging.info(f"encoder parameters: {encoder_num_param}")
-    logging.info(f"decoder parameters: {decoder_num_param}")
-    logging.info(f"joiner parameters: {joiner_num_param}")
-    logging.info(f"total parameters: {total_num_param}")
-
-    if params.iter > 0:
-        suffix = f"iter-{params.iter}"
-    else:
-        suffix = f"epoch-{params.epoch}"
-
-    suffix += f"-avg-{params.avg}"
-
-    opset_version = 13
-
-    logging.info("Exporting encoder")
-    encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
-    export_encoder_model_onnx(
-        encoder,
-        encoder_filename,
-        opset_version=opset_version,
-    )
-    logging.info(f"Exported encoder to {encoder_filename}")
-
-    logging.info("Exporting decoder")
-    decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
-    export_decoder_model_onnx(
-        decoder,
-        decoder_filename,
-        opset_version=opset_version,
-    )
-    logging.info(f"Exported decoder to {decoder_filename}")
-
-    logging.info("Exporting joiner")
-    joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
-    export_joiner_model_onnx(
-        joiner,
-        joiner_filename,
-        opset_version=opset_version,
-    )
-    logging.info(f"Exported joiner to {joiner_filename}")
-
-    # Generate int8 quantization models
-    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
-
-    logging.info("Generate int8 quantization models")
-
-    encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
-    quantize_dynamic(
-        model_input=encoder_filename,
-        model_output=encoder_filename_int8,
-        op_types_to_quantize=["MatMul"],
-        weight_type=QuantType.QInt8,
-    )
-
-    decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
-    quantize_dynamic(
-        model_input=decoder_filename,
-        model_output=decoder_filename_int8,
-        op_types_to_quantize=["MatMul", "Gather"],
-        weight_type=QuantType.QInt8,
-    )
-
-    joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
-    quantize_dynamic(
-        model_input=joiner_filename,
-        model_output=joiner_filename_int8,
-        op_types_to_quantize=["MatMul"],
-        weight_type=QuantType.QInt8,
-    )
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    main()
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_model.py
similarity index 59%
rename from AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py
rename to AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_model.py
index 05c9e64e06..0019d4937f 100644
--- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_enc.py
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_model.py
@@ -1,4 +1,4 @@
-# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,34 +16,35 @@ import os
 import argparse
 
 import torch
-import torch_aie
-from torch_aie import _enums
+import mindietorch
+from mindietorch import _enums
 
-def export_torch_aie(opt):
+def export_mindietorch(opt):
     trace_model = torch.jit.load(opt.torch_script_path)
     trace_model.eval()
 
-    torch_aie.set_device(0)
+    mindietorch.set_device(0)
     inputs = []
-    # x = torch.zeros(1, 100, 80, dtype=torch.float32)
-    # x_lens = torch.tensor([100], dtype=torch.int64)
-    inputs.append(torch_aie.Input([1, 100, 80], dtype = torch_aie.dtype.FLOAT))
-    inputs.append(torch_aie.Input([1], dtype = torch_aie.dtype.INT64))
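+    # Static input specs for each sub-model: the encoder takes (batch_size, 100, 80) fbank
+    # features plus a (batch_size,) length tensor, the decoder takes (batch_size, 2) context
+    # tokens, and the joiner takes two (batch_size, 512) projected vectors.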
+    if opt.export_part == 'encoder':
+        inputs.append(mindietorch.Input([opt.batch_size, 100, 80], dtype = mindietorch.dtype.FLOAT))
+        inputs.append(mindietorch.Input([opt.batch_size], dtype = mindietorch.dtype.INT64))
+    elif opt.export_part == 'decoder':
+        inputs.append(mindietorch.Input([opt.batch_size, 2], dtype=mindietorch.dtype.INT64))
+    else:
+        inputs.append(mindietorch.Input([opt.batch_size, 512], dtype=mindietorch.dtype.FLOAT))
+        inputs.append(mindietorch.Input([opt.batch_size, 512], dtype=mindietorch.dtype.FLOAT))
 
-    torchaie_model = torch_aie.compile(
+    torchaie_model = mindietorch.compile(
         trace_model,
         inputs=inputs,
-        precision_policy=_enums.PrecisionPolicy.FP16,
-        # truncate_long_and_double=True,
-        # require_full_compilation=False,
-        # allow_tensor_replace_int=False,
-        # min_block_size=3,
-        # torch_executed_ops=[],
-        soc_version='Ascend310P3',
+        precision_policy=_enums.PrecisionPolicy.FP32,
+        soc_version=opt.soc_version,
         optimization_level=0
-        )
+    )
     suffix = os.path.splitext(opt.torch_script_path)[-1]
-    saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt.batch_size}" + suffix
+    saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_mindietorch_bs{opt.batch_size}" + suffix
+    if not os.path.exists(opt.save_path):
+        os.makedirs(opt.save_path)
     torchaie_model.save(os.path.join(opt.save_path, saved_name))
     print("torch aie tdnn compiled done. saved model is ", os.path.join(opt.save_path, saved_name))
 
@@ -51,6 +52,7 @@ def export_torch_aie(opt):
 def parse_opt():
     parser = argparse.ArgumentParser()
     parser.add_argument('--torch_script_path', type=str, default='../egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/encoder-epoch-12-avg-1.pt', help='trace model path')
+    parser.add_argument('--export_part', type=str, default='encoder', help='the part of the model (encoder, decoder, or joiner) to be exported')
     parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
     parser.add_argument('--batch_size', type=int, default=1, help='batch size')
     parser.add_argument('--save_path', type=str, default='./pt_compiled_model/', help='compiled model path')
@@ -58,7 +60,7 @@ def parse_opt():
     return opt
 
 def main(opt):
-    export_torch_aie(opt)
+    export_mindietorch(opt)
 
 if __name__ == '__main__':
     opt = parse_opt()
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py
deleted file mode 100644
index bb341fca41..0000000000
--- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_dec.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import argparse
-
-import torch
-import torch_aie
-from torch_aie import _enums
-
-def export_torch_aie(opt):
-    trace_model = torch.jit.load(opt.torch_script_path)
-    trace_model.eval()
-
-    torch_aie.set_device(0)
-    inputs = []
-    # inputs.append(torch_aie.Input([10, 2], dtype = torch_aie.dtype.INT64))
-    inputs.append(torch_aie.Input([opt.batch_size, 2], dtype = torch_aie.dtype.INT64))
-
-    torchaie_model = torch_aie.compile(
-        trace_model,
-        inputs=inputs,
-        precision_policy=_enums.PrecisionPolicy.FP16,
-        truncate_long_and_double=True,
-        require_full_compilation=False,
-        allow_tensor_replace_int=False,
-        min_block_size=3,
-        torch_executed_ops=[],
-        soc_version='Ascend310P3',
-        optimization_level=0)
-    suffix = os.path.splitext(opt.torch_script_path)[-1]
-    saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt.batch_size}" + suffix
-    torchaie_model.save(os.path.join(opt.save_path, saved_name))
-    print("torch aie tdnn compiled done. saved model is ", os.path.join(opt.save_path, saved_name))
-
-
-def parse_opt():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--torch_script_path', type=str, default='../egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/decoder-epoch-12-avg-1.pt', help='trace model path')
-    parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
-    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
-    parser.add_argument('--save_path', type=str, default='./pt_compiled_model/', help='compiled model path')
-    opt = parser.parse_args()
-    return opt
-
-def main(opt):
-    export_torch_aie(opt)
-
-if __name__ == '__main__':
-    opt = parse_opt()
-    main(opt)
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py
deleted file mode 100644
index ae0b1e5b47..0000000000
--- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/export_torch_aie_ts_join.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import argparse
-
-import torch
-import torch_aie
-from torch_aie import _enums
-
-def export_torch_aie(opt):
-    trace_model = torch.jit.load(opt.torch_script_path)
-    trace_model.eval()
-
-    torch_aie.set_device(0)
-    inputs = []
-    # inputs.append(torch_aie.Input([11, 512], dtype = torch_aie.dtype.FLOAT))
-    # inputs.append(torch_aie.Input([11, 512], dtype = torch_aie.dtype.FLOAT))
-    inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT))
-    inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT))
-    # projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
-    # projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
-    # inputs.append(torch_aie.Input([10], dtype = torch_aie.dtype.INT64))
-
-    torchaie_model = torch_aie.compile(
-        trace_model,
-        inputs=inputs,
-        precision_policy=_enums.PrecisionPolicy.FP16,
-        truncate_long_and_double=True,
-        require_full_compilation=False,
-        allow_tensor_replace_int=False,
-        min_block_size=3,
-        torch_executed_ops=[],
-        soc_version='Ascend310P3',
-        optimization_level=0)
-    suffix = os.path.splitext(opt.torch_script_path)[-1]
-    saved_name = os.path.basename(opt.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt.batch_size}" + suffix
-    torchaie_model.save(os.path.join(opt.save_path, saved_name))
-    print("torch aie tdnn compiled done. saved model is ", os.path.join(opt.save_path, saved_name))
-
-
-def parse_opt():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--torch_script_path', type=str, default='../egs/librispeech/ASR/icefall-asr-zipformer-wenetspeech-20230615/exp/joiner-epoch-12-avg-1.pt', help='trace model path')
-    parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
-    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
-    parser.add_argument('--save_path', type=str, default='./pt_compiled_model/', help='compiled model path')
-    opt = parser.parse_args()
-    return opt
-
-def main(opt):
-    export_torch_aie(opt)
-
-if __name__ == '__main__':
-    opt = parse_opt()
-    main(opt)
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py
index d8934b5f29..9149479660 100644
--- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_dec.py
@@ -1,4 +1,4 @@
-# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@ import time
 from tqdm import tqdm
 
 import torch
-import torch_aie
+import mindietorch
 
 
 def forward_infer(model, dataloader, batchsize, device_id):
@@ -37,8 +37,8 @@ def forward_infer(model, dataloader, batchsize, device_id):
 def pt_infer(model, input_li_1, device_id, loop_num, inference_time):
 
     input_npu_li_1 = input_li_1.to("npu:" + str(device_id))
-    stream = torch_aie.npu.Stream("npu:" + str(device_id))
-    with torch_aie.npu.stream(stream):
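+    # Run the forward pass on a dedicated NPU stream and synchronize before the timer stops,
+    # so the measured latency covers the full device-side execution.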
+    stream = mindietorch.npu.Stream("npu:" + str(device_id))
+    with mindietorch.npu.stream(stream):
         inf_start = time.time()
         output_npu = model.forward(input_npu_li_1)
         stream.synchronize()
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py
index 61396050b6..c2197fc9e0 100644
--- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_enc.py
@@ -1,4 +1,4 @@
-# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,9 +14,8 @@
 
 import time
 from tqdm import tqdm
-import numpy as np
 import torch
-import torch_aie
+import mindietorch
 
 
 def forward_infer(model, dataloader, batchsize, device_id):
@@ -38,8 +37,8 @@ def pt_infer(model, input_li_1, input_li_2, device_id, loop_num, inference_time)
 
     input_npu_li_1 = input_li_1.to("npu:" + str(device_id))
     input_npu_li_2 = input_li_2.to("npu:" + str(device_id))
-    stream = torch_aie.npu.Stream("npu:" + str(device_id))
-    with torch_aie.npu.stream(stream):
+    stream = mindietorch.npu.Stream("npu:" + str(device_id))
+    with mindietorch.npu.stream(stream):
         inf_start = time.time()
         output_npu = model.forward(input_npu_li_1, input_npu_li_2)
         stream.synchronize()
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py
index 45b8ec0f93..858046037d 100644
--- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/model_pt_join.py
@@ -1,4 +1,4 @@
-# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@ import time
 from tqdm import tqdm
 
 import torch
-import torch_aie
+import mindietorch
 
 
 def forward_infer(model, dataloader, batchsize, device_id):
@@ -38,8 +38,8 @@ def pt_infer(model, input_li_1, input_li_2, device_id, loop_num, inference_time)
 
     input_npu_li_1 = input_li_1.to("npu:" + str(device_id))
     input_npu_li_2 = input_li_2.to("npu:" + str(device_id))
-    stream = torch_aie.npu.Stream("npu:" + str(device_id))
-    with torch_aie.npu.stream(stream):
+    stream = mindietorch.npu.Stream("npu:" + str(device_id))
+    with mindietorch.npu.stream(stream):
         inf_start = time.time()
         output_npu = model.forward(input_npu_li_1, input_npu_li_2)
         stream.synchronize()
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py
index 3106753822..8363fe5177 100644
--- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_dec.py
@@ -1,4 +1,4 @@
-# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,9 +17,9 @@ import copy
 import argparse
 
 import torch
-import torch_aie
+import mindietorch
 import numpy as np
-from torch_aie import _enums
+from mindietorch import _enums
 from torch.utils.data import dataloader
 
 from model_pt_dec import forward_infer
@@ -58,56 +58,12 @@ class _RepeatSampler:
         while True:
             yield from iter(self.sampler)
 
-# def collate_fn(batch):
-#     """
-#     data preprocessing
-#     """
-#     def func(p):
-#         """
-#         data size
-#         """
-#         return p[0].size(1)
-
-#     batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
-#     longest_sample = max(batch, key=func)[0]
-#     freq_size = longest_sample.size(0)
-#     minibatch_size = len(batch)
-#     max_seqlength = longest_sample.size(1)
-#     inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength)
-#     input_percentages = torch.FloatTensor(minibatch_size)
-#     for x in range(minibatch_size):
-#         sample = batch[x]
-#         tensor = sample[0]
-#         seq_length = tensor.size(1)
-#         inputs[x][0].narrow(1, 0, seq_length).copy_(tensor)
-#         input_percentages[x] = seq_length / float(max_seqlength)
-#     return inputs, input_percentages, []
-
 
 def get_dataloader(opt):
-    # with open(to_absolute_path(opt.label_file)) as label_file:
-    #     labels = json.load(label_file)
-
-    # dataset = SpectrogramDataset(
-    #     audio_conf=DataConfig.spect,
-    #     input_path=opt.data_file,
-    #     labels=labels,
-    #     normalize=True,
-    #     aug_cfg=DataConfig.augmentation
-    # )
-    # inputs, input_percentages, _ = collate_fn(dataset)
-    # input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
-    # print(inputs[0])
-    # print(input_sizes.tolist())
-
-    # datasets = [[inputs[i], input_sizes[i]] for i in range(len(input_sizes))]
     x = torch.zeros(2, dtype=torch.int64)
-    # x_lens = torch.tensor([100], dtype=torch.int64)
-    # x_lens = 100
     datasets = []
     for i in range(20):
         datasets.append([copy.deepcopy(x)])
-    # print(datasets)
     while len(datasets) % opt.batch_size != 0:
         datasets.append(datasets[-1])
     m = 1
@@ -136,66 +92,40 @@ def save_tensor_arr_to_file(arr, file_path):
     with open(file_path, "w", encoding='utf-8') as f:
         f.write(write_sen)
 
-def save_size_to_file(size, file_path):
-    write_sen = "" + str(size) + " "
-    with open(file_path, "w", encoding='utf-8') as f:
-        f.write(write_sen)
 
 def main(opt):
     # load model
     model = torch.jit.load(opt.model)
     batch_size = opt.batch_size
-    torch_aie.set_device(opt.device_id)
+    mindietorch.set_device(opt.device_id)
     if opt.need_compile:
         inputs = []
-        inputs.append(torch_aie.Input([10, 2], dtype=torch_aie.dtype.INT64))
-        # inputs.append(torch_aie.Input([opt.batch_size], dtype=torch_aie.dtype.INT32))
-
-        model = torch_aie.compile(
+        inputs.append(mindietorch.Input([batch_size, 2], dtype=mindietorch.dtype.INT64))
+        model = mindietorch.compile(
             model,
             inputs=inputs,
-            precision_policy=_enums.PrecisionPolicy.FP16,
-            truncate_long_and_double=True,
-            require_full_compilation=False,
-            allow_tensor_replace_int=False,
-            min_block_size=3,
-            torch_executed_ops=[],
-            soc_version='Ascend310P3',
-            optimization_level=0)
-
+            precision_policy=_enums.PrecisionPolicy.FP32,
+            soc_version=opt.soc_version,
+            optimization_level=0
+        )
     dataloader = get_dataloader(opt)
     pred_results = forward_infer(model, dataloader, batch_size, opt.device_id)
-    # for index, res in enumerate(pred_results):
-    #     print(index, "  ", res)
-    #     print(res[0].shape)
 
     if opt.batch_size == 1 and opt.multi == 1:
         result_path = opt.result_path
         if(os.path.exists(result_path) == False):
-                os.makedirs(result_path)
+            os.makedirs(result_path)
         for index, res in enumerate(pred_results):
-            # for i in range(batch_size):
-            # result_fname_0 = 'data' + str(index * batch_size + i + 1) + '_0.txt'
-            # result_fname_1 = 'data' + str(index * batch_size + i + 1) + '_1.txt'
             result_fname_0 = 'data' + str(index) + '_0.txt'
-            # result_fname_1 = 'data' + str(index) + '_1.txt'
-            # res = np.array(res)
-            # save_tensor_arr_to_file(np.array(res[0][i]), os.path.join(result_path, result_fname_0))
-            # save_size_to_file(res[1].numpy()[i], os.path.join(result_path, result_fname_1))
             save_tensor_arr_to_file(np.array(res), os.path.join(result_path, result_fname_0))
-            # save_size_to_file(res[1].numpy()[0], os.path.join(result_path, result_fname_1))
-
-
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='DeepSpeech2 offline model inference.')
+    parser = argparse.ArgumentParser(description='Zipformer offline model decoder inference.')
     parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
-    parser.add_argument('--model', type=str, default="./pt_compiled_model/decoder-epoch-12-avg-1_torch_aie_bs1.pt", help='ts model path')
+    parser.add_argument('--model', type=str, default="./pt_compiled_model/decoder-epoch-12-avg-1_mindietorch_bs1.pt", help='ts model path')
     parser.add_argument('--need_compile', action="store_true", help='if the loaded model needs to be compiled or not')
     parser.add_argument('--batch_size', type=int, default=1, help='batch size')
     parser.add_argument('--device_id', type=int, default=0, help='device id')
-    parser.add_argument('--data_file', default='./deepspeech.pytorch/data/an4_test_manifest.json')
-    parser.add_argument('--label_file', default='./deepspeech.pytorch/labels.json')
     parser.add_argument('--result_path', default='result/decoder')
     parser.add_argument('--multi', type=int, default=1, help='multiples of dataset replication for enough infer loop. if multi != 1, the pred result will not be stored.')
     opt = parser.parse_args()
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py
index 496e767fe0..55af8f7239 100644
--- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_enc.py
@@ -1,4 +1,4 @@
-# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,9 +17,9 @@ import copy
 import argparse
 
 import torch
-import torch_aie
+import mindietorch
 import numpy as np
-from torch_aie import _enums
+from mindietorch import _enums
 from torch.utils.data import dataloader
 
 from model_pt_enc import forward_infer
@@ -58,56 +58,12 @@ class _RepeatSampler:
         while True:
             yield from iter(self.sampler)
 
-# def collate_fn(batch):
-#     """
-#     data preprocessing
-#     """
-#     def func(p):
-#         """
-#         data size
-#         """
-#         return p[0].size(1)
-
-#     batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
-#     longest_sample = max(batch, key=func)[0]
-#     freq_size = longest_sample.size(0)
-#     minibatch_size = len(batch)
-#     max_seqlength = longest_sample.size(1)
-#     inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength)
-#     input_percentages = torch.FloatTensor(minibatch_size)
-#     for x in range(minibatch_size):
-#         sample = batch[x]
-#         tensor = sample[0]
-#         seq_length = tensor.size(1)
-#         inputs[x][0].narrow(1, 0, seq_length).copy_(tensor)
-#         input_percentages[x] = seq_length / float(max_seqlength)
-#     return inputs, input_percentages, []
-
-
 def get_dataloader(opt):
-    # with open(to_absolute_path(opt.label_file)) as label_file:
-    #     labels = json.load(label_file)
-
-    # dataset = SpectrogramDataset(
-    #     audio_conf=DataConfig.spect,
-    #     input_path=opt.data_file,
-    #     labels=labels,
-    #     normalize=True,
-    #     aug_cfg=DataConfig.augmentation
-    # )
-    # inputs, input_percentages, _ = collate_fn(dataset)
-    # input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
-    # print(inputs[0])
-    # print(input_sizes.tolist())
-
-    # datasets = [[inputs[i], input_sizes[i]] for i in range(len(input_sizes))]
     x = torch.zeros(100, 80, dtype=torch.float32)
-    # x_lens = torch.tensor([100], dtype=torch.int64)
     x_lens = 100
     datasets = []
     for i in range(20):
         datasets.append([copy.deepcopy(x), copy.deepcopy(x_lens)])
-    # print(datasets)
     while len(datasets) % opt.batch_size != 0:
         datasets.append(datasets[-1])
     m = 1
@@ -116,7 +72,7 @@ def get_dataloader(opt):
         datasets += datasets_orig
         m += 1
 
-    loader =  InfiniteDataLoader  # only DataLoader allows for attribute updates
+    loader = InfiniteDataLoader  # only DataLoader allows for attribute updates
     print("OPT_BATCHSIZE: ", opt.batch_size)
     return loader(datasets,
                   batch_size=opt.batch_size,
@@ -145,57 +101,41 @@ def main(opt):
     # load model
     model = torch.jit.load(opt.model)
     batch_size = opt.batch_size
-    torch_aie.set_device(opt.device_id)
+    mindietorch.set_device(opt.device_id)
     if opt.need_compile:
         inputs = []
-        inputs.append(torch_aie.Input([opt.batch_size, 100, 80], dtype=torch_aie.dtype.FLOAT))
-        inputs.append(torch_aie.Input([opt.batch_size], dtype=torch_aie.dtype.INT32))
+        inputs.append(mindietorch.Input([opt.batch_size, 100, 80], dtype=mindietorch.dtype.FLOAT))
+        inputs.append(mindietorch.Input([opt.batch_size], dtype=mindietorch.dtype.INT32))
 
-        model = torch_aie.compile(
+        model = mindietorch.compile(
             model,
             inputs=inputs,
-            precision_policy=_enums.PrecisionPolicy.FP16,
-            truncate_long_and_double=True,
-            require_full_compilation=False,
-            allow_tensor_replace_int=False,
-            min_block_size=3,
-            torch_executed_ops=[],
-            soc_version='Ascend310P3',
-            optimization_level=0)
+            precision_policy=_enums.PrecisionPolicy.FP32,
+            soc_version=opt.soc_version,
+            optimization_level=0
+        )
 
     dataloader = get_dataloader(opt)
     pred_results = forward_infer(model, dataloader, batch_size, opt.device_id)
-    for index, res in enumerate(pred_results):
-        print(index, "  ", res)
-        print(res[0].shape)
 
     if opt.batch_size == 1 and opt.multi == 1:
         result_path = opt.result_path
         if(os.path.exists(result_path) == False):
-                os.makedirs(result_path)
+            os.makedirs(result_path)
         for index, res in enumerate(pred_results):
-            # for i in range(batch_size):
-            # result_fname_0 = 'data' + str(index * batch_size + i + 1) + '_0.txt'
-            # result_fname_1 = 'data' + str(index * batch_size + i + 1) + '_1.txt'
             result_fname_0 = 'data' + str(index) + '_0.txt'
             result_fname_1 = 'data' + str(index) + '_1.txt'
-            # res = np.array(res)
-            # save_tensor_arr_to_file(np.array(res[0][i]), os.path.join(result_path, result_fname_0))
-            # save_size_to_file(res[1].numpy()[i], os.path.join(result_path, result_fname_1))
             save_tensor_arr_to_file(np.array(res[0]), os.path.join(result_path, result_fname_0))
             save_size_to_file(res[1].numpy()[0], os.path.join(result_path, result_fname_1))
 
 
-
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='DeepSpeech2 offline model inference.')
+    parser = argparse.ArgumentParser(description='Zipformer offline model encoder inference.')
     parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
-    parser.add_argument('--model', type=str, default="./pt_compiled_model/encoder-epoch-12-avg-1_torch_aie_bs1.pt", help='ts model path')
+    parser.add_argument('--model', type=str, default="./pt_compiled_model/encoder-epoch-12-avg-1_mindietorch_bs1.pt", help='ts model path')
     parser.add_argument('--need_compile', action="store_true", help='whether the loaded model still needs to be compiled')
     parser.add_argument('--batch_size', type=int, default=1, help='batch size')
     parser.add_argument('--device_id', type=int, default=0, help='device id')
-    parser.add_argument('--data_file', default='./deepspeech.pytorch/data/an4_test_manifest.json')
-    parser.add_argument('--label_file', default='./deepspeech.pytorch/labels.json')
     parser.add_argument('--result_path', default='result/encoder')
     parser.add_argument('--multi', type=int, default=1, help='number of times to replicate the dataset so there are enough inference iterations; if multi != 1, prediction results are not stored.')
     opt = parser.parse_args()
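As a standalone reference, the `--need_compile` path above amounts to roughly the following. The traced-encoder path `./encoder-epoch-12-avg-1.ts` is a placeholder, and saving the compiled module is left to the accompanying `export_torch_aie_ts_enc.py`; the input shapes match the dummy batches built in `get_dataloader`.

```python
# Sketch of compiling the traced encoder for batch size 1 (input path is a placeholder).
import torch
import mindietorch
from mindietorch import _enums

mindietorch.set_device(0)
ts_encoder = torch.jit.load("./encoder-epoch-12-avg-1.ts")

compile_inputs = [
    mindietorch.Input([1, 100, 80], dtype=mindietorch.dtype.FLOAT),  # features: (batch, frames, mel bins)
    mindietorch.Input([1], dtype=mindietorch.dtype.INT32),           # valid frame count per utterance
]

compiled_encoder = mindietorch.compile(
    ts_encoder,
    inputs=compile_inputs,
    precision_policy=_enums.PrecisionPolicy.FP32,  # this patch switches from FP16 to FP32
    soc_version="Ascend310P3",
    optimization_level=0,
)
```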
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py
index a4fe80f4de..9f5b31ae6c 100644
--- a/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/icefall_pt/pt_val_join.py
@@ -1,4 +1,4 @@
-# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+# Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,9 +17,9 @@ import copy
 import argparse
 
 import torch
-import torch_aie
+import mindietorch
 import numpy as np
-from torch_aie import _enums
+from mindietorch import _enums
 from torch.utils.data import dataloader
 
 from model_pt_join import forward_infer
@@ -58,57 +58,13 @@ class _RepeatSampler:
         while True:
             yield from iter(self.sampler)
 
-# def collate_fn(batch):
-#     """
-#     data preprocessing
-#     """
-#     def func(p):
-#         """
-#         data size
-#         """
-#         return p[0].size(1)
-
-#     batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
-#     longest_sample = max(batch, key=func)[0]
-#     freq_size = longest_sample.size(0)
-#     minibatch_size = len(batch)
-#     max_seqlength = longest_sample.size(1)
-#     inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength)
-#     input_percentages = torch.FloatTensor(minibatch_size)
-#     for x in range(minibatch_size):
-#         sample = batch[x]
-#         tensor = sample[0]
-#         seq_length = tensor.size(1)
-#         inputs[x][0].narrow(1, 0, seq_length).copy_(tensor)
-#         input_percentages[x] = seq_length / float(max_seqlength)
-#     return inputs, input_percentages, []
-
 
 def get_dataloader(opt):
-    # with open(to_absolute_path(opt.label_file)) as label_file:
-    #     labels = json.load(label_file)
-
-    # dataset = SpectrogramDataset(
-    #     audio_conf=DataConfig.spect,
-    #     input_path=opt.data_file,
-    #     labels=labels,
-    #     normalize=True,
-    #     aug_cfg=DataConfig.augmentation
-    # )
-    # inputs, input_percentages, _ = collate_fn(dataset)
-    # input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
-    # print(inputs[0])
-    # print(input_sizes.tolist())
-
-    # datasets = [[inputs[i], input_sizes[i]] for i in range(len(input_sizes))]
     x = torch.zeros(512, dtype=torch.float32)
     y = torch.zeros(512, dtype=torch.float32)
-    # x_lens = torch.tensor([100], dtype=torch.int64)
-    # x_lens = 100
     datasets = []
     for i in range(20):
         datasets.append([copy.deepcopy(x), copy.deepcopy(y)])
-    # print(datasets)
     while len(datasets) % opt.batch_size != 0:
         datasets.append(datasets[-1])
     m = 1
@@ -137,66 +93,44 @@ def save_tensor_arr_to_file(arr, file_path):
     with open(file_path, "w", encoding='utf-8') as f:
         f.write(write_sen)
 
-def save_size_to_file(size, file_path):
-    write_sen = "" + str(size) + " "
-    with open(file_path, "w", encoding='utf-8') as f:
-        f.write(write_sen)
 
 def main(opt):
     # load model
     model = torch.jit.load(opt.model)
     batch_size = opt.batch_size
-    torch_aie.set_device(opt.device_id)
+    mindietorch.set_device(opt.device_id)
     if opt.need_compile:
         inputs = []
-        inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT))
-        inputs.append(torch_aie.Input([opt.batch_size, 512], dtype = torch_aie.dtype.FLOAT))
+        inputs.append(mindietorch.Input([opt.batch_size, 512], dtype=mindietorch.dtype.FLOAT))
+        inputs.append(mindietorch.Input([opt.batch_size, 512], dtype=mindietorch.dtype.FLOAT))
 
-        model = torch_aie.compile(
+        model = mindietorch.compile(
             model,
             inputs=inputs,
-            precision_policy=_enums.PrecisionPolicy.FP16,
-            truncate_long_and_double=True,
-            require_full_compilation=False,
-            allow_tensor_replace_int=False,
-            min_block_size=3,
-            torch_executed_ops=[],
-            soc_version='Ascend310P3',
-            optimization_level=0)
+            precision_policy=_enums.PrecisionPolicy.FP32,
+            soc_version=opt.soc_version,
+            optimization_level=0
+        )
 
     dataloader = get_dataloader(opt)
     pred_results = forward_infer(model, dataloader, batch_size, opt.device_id)
-    # for index, res in enumerate(pred_results):
-    #     print(index, "  ", res)
-    #     print(res[0].shape)
 
     if opt.batch_size == 1 and opt.multi == 1:
         result_path = opt.result_path
         if(os.path.exists(result_path) == False):
-                os.makedirs(result_path)
+            os.makedirs(result_path)
         for index, res in enumerate(pred_results):
-            # for i in range(batch_size):
-            # result_fname_0 = 'data' + str(index * batch_size + i + 1) + '_0.txt'
-            # result_fname_1 = 'data' + str(index * batch_size + i + 1) + '_1.txt'
             result_fname_0 = 'data' + str(index) + '_0.txt'
-            # result_fname_1 = 'data' + str(index) + '_1.txt'
-            # res = np.array(res)
-            # save_tensor_arr_to_file(np.array(res[0][i]), os.path.join(result_path, result_fname_0))
-            # save_size_to_file(res[1].numpy()[i], os.path.join(result_path, result_fname_1))
             save_tensor_arr_to_file(np.array(res), os.path.join(result_path, result_fname_0))
-            # save_size_to_file(res[1].numpy()[0], os.path.join(result_path, result_fname_1))
-
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='DeepSpeech2 offline model inference.')
+    parser = argparse.ArgumentParser(description='Zipformer offline model joiner inference.')
     parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
-    parser.add_argument('--model', type=str, default="./pt_compiled_model/joiner-epoch-12-avg-1_torch_aie_bs1.pt", help='ts model path')
+    parser.add_argument('--model', type=str, default="./pt_compiled_model/joiner-epoch-12-avg-1_mindietorch_bs1.pt", help='ts model path')
     parser.add_argument('--need_compile', action="store_true", help='whether the loaded model still needs to be compiled')
     parser.add_argument('--batch_size', type=int, default=1, help='batch size')
     parser.add_argument('--device_id', type=int, default=0, help='device id')
-    parser.add_argument('--data_file', default='./deepspeech.pytorch/data/an4_test_manifest.json')
-    parser.add_argument('--label_file', default='./deepspeech.pytorch/labels.json')
     parser.add_argument('--result_path', default='result/joiner')
     parser.add_argument('--multi', type=int, default=1, help='number of times to replicate the dataset so there are enough inference iterations; if multi != 1, prediction results are not stored.')
     opt = parser.parse_args()
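Similarly, a single joiner call with the dummy 512-dimensional inputs used above can be sketched as follows. The `"npu:0"` device-string convention and the reading of the two inputs as projected encoder/decoder representations are assumptions; the real loop is in the accompanying `model_pt_join.py`.

```python
# Sketch of one joiner invocation with zero inputs matching get_dataloader() above.
import torch
import mindietorch

mindietorch.set_device(0)
joiner = torch.jit.load("./pt_compiled_model/joiner-epoch-12-avg-1_mindietorch_bs1.pt")
joiner.eval()

encoder_frame = torch.zeros(1, 512, dtype=torch.float32)  # assumed: one projected encoder frame
decoder_frame = torch.zeros(1, 512, dtype=torch.float32)  # assumed: one projected decoder output

with torch.no_grad():
    out = joiner(encoder_frame.to("npu:0"), decoder_frame.to("npu:0"))
print(out.shape)
```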
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py
deleted file mode 100644
index 866216820f..0000000000
--- a/AscendIE/TorchAIE/built-in/audio/Zipformer/modify_decoder.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from argparse import ArgumentParser
-
-from auto_optimizer import OnnxGraph
-
-
-def main():
-    parser = ArgumentParser()
-    parser.add_argument("--onnx", type=str, required=True)
-    args = parser.parse_args()
-
-    graph = OnnxGraph.parse(args.onnx)
-    graph.remove("/decoder/Clip")
-    gather = graph["/decoder/embedding/Gather"]
-    gather.inputs[1] = "y"
-    graph.update_map()
-    graph.infershape()
-
-    g_sim = graph.simplify()
-    save_path = args.onnx.replace(".onnx", "_modified.onnx")
-    g_sim.save(save_path)
-    print("Modified model saved to ", save_path)
-
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.patch b/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.patch
new file mode 100644
index 0000000000..d5a19abee9
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.patch
@@ -0,0 +1,59 @@
+--- zipformer.py	2024-03-08 15:18:44.384000000 +0800
++++ zipformer.py	2024-03-08 15:18:42.056000000 +0800
+@@ -1415,7 +1415,7 @@
+         self.length_factor = length_factor
+         self.extend_pe(torch.tensor(0.0).expand(max_len))
+ 
+-    def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
++    def extend_pe(self, x: Tensor, left_context_len: int = 0) -> Tensor:
+         """Reset the positional encodings."""
+         T = x.size(0) + left_context_len
+ 
+@@ -1423,8 +1423,7 @@
+             # self.pe contains both positive and negative parts
+             # the length of self.pe is 2 * input_len - 1
+             if self.pe.size(0) >= T * 2 - 1:
+-                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+-                return
++                return self.pe.to(dtype=x.dtype, device=x.device)
+ 
+         # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
+         x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
+@@ -1458,12 +1457,13 @@
+         cosines = (x_atan * freqs).cos()
+         sines = (x_atan * freqs).sin()
+ 
+-        pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
+-        pe[:, 0::2] = cosines
+-        pe[:, 1::2] = sines
+-        pe[:, -1] = 1.0  # for bias.
++        cos_shape0 = cosines.shape[0]
++        bias_one = torch.ones(cos_shape0, 1, device=cosines.device)
++        pe = torch.cat((cosines.unsqueeze(2), sines.unsqueeze(2)), dim=2)
++        pe = pe.reshape(cos_shape0, -1)
++        pe = torch.cat((pe[:, :-1], bias_one), dim=1)
+ 
+-        self.pe = pe.to(dtype=x.dtype)
++        return pe.to(dtype=x.dtype)
+ 
+     def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
+         """Create positional encoding.
+@@ -1475,14 +1475,14 @@
+         Returns:
+             positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
+         """
+-        self.extend_pe(x, left_context_len)
++        pe = self.extend_pe(x, left_context_len)
+         x_size_left = x.size(0) + left_context_len
+         # length of positive side: x.size(0) + left_context_len
+         # length of negative side: x.size(0)
+-        pos_emb = self.pe[
+-            self.pe.size(0) // 2
++        pos_emb = pe[
++            pe.size(0) // 2
+             - x_size_left
+-            + 1 : self.pe.size(0) // 2  # noqa E203
++            + 1 : pe.size(0) // 2  # noqa E203
+             + x.size(0),
+             :,
+         ]
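The key change in this patch is that `extend_pe` now returns the positional-encoding tensor instead of mutating `self.pe`, and builds it with `cat`/`reshape` rather than in-place slice assignment, which is generally friendlier to tracing and export. A quick standalone check (arbitrary small sizes; not part of the patch) that the two constructions produce the same layout:

```python
# Verify that cat/reshape interleaving matches the original even/odd slice assignment.
import torch

T, embed_dim = 5, 8
cosines = torch.randn(T, embed_dim // 2)
sines = torch.randn(T, embed_dim // 2)

# Original construction: write into a preallocated buffer.
pe_ref = torch.zeros(T, embed_dim)
pe_ref[:, 0::2] = cosines
pe_ref[:, 1::2] = sines
pe_ref[:, -1] = 1.0  # bias column

# Patched construction: interleave, then replace the last column with ones.
pe_new = torch.cat((cosines.unsqueeze(2), sines.unsqueeze(2)), dim=2).reshape(T, -1)
pe_new = torch.cat((pe_new[:, :-1], torch.ones(T, 1)), dim=1)

assert torch.equal(pe_ref, pe_new)
```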
diff --git a/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py b/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py
deleted file mode 100644
index 163eb5f6b2..0000000000
--- a/AscendIE/TorchAIE/built-in/audio/Zipformer/zipformer.py
+++ /dev/null
@@ -1,2438 +0,0 @@
-#!/usr/bin/env python3
-# Copyright    2022-2023  Xiaomi Corp.        (authors: Daniel Povey,
-#                                                       Zengwei Yao)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import copy
-import logging
-import math
-import random
-import warnings
-from typing import List, Optional, Tuple, Union
-
-import torch
-from encoder_interface import EncoderInterface
-from scaling import (
-    Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
-)
-from scaling import (
-    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
-)
-from scaling import (
-    ActivationDropoutAndLinear,
-    Balancer,
-    BiasNorm,
-    ChunkCausalDepthwiseConv1d,
-    Dropout2,
-    FloatLike,
-    ScheduledFloat,
-    Whiten,
-    convert_num_channels,
-    limit_param_value,
-    penalize_abs_values_gt,
-    softmax,
-)
-from torch import Tensor, nn
-
-
-class Zipformer2(EncoderInterface):
-    """
-    Args:
-
-    Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length
-    as downsampling_factor if they are single ints or one-element tuples.  The length of
-    downsampling_factor defines the number of stacks.
-
-        output_downsampling_factor (int): how much to downsample at the output.  Note:
-            we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
-            You should probably leave this at 2.
-        downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
-           Note: this is in addition to the downsampling factor of 2 that is applied in
-           the frontend (self.encoder_embed).
-        encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
-           encoder stack.
-        num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
-        encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
-            the encoder stacks for purposes of per-frame dropout (recommend 256 for
-            now).
-        query_head_dim (int or Tuple[int]): dimension of query and key per attention
-           head: per stack, if a tuple..
-        pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per
-           attention head
-        value_head_dim (int or Tuple[int]): dimension of value in each attention head
-        num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
-              Must be at least 4.
-        feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
-        cnn_module_kernel (int or Tuple[int]): Kernel size of convolution module
-
-        pos_dim (int): the dimension of each positional-encoding vector prior to projection,
-            e.g. 128.
-
-        dropout (float): dropout rate
-        warmup_batches (float): number of batches to warm up over; this controls
-          dropout of encoder layers.
-        causal (bool): if True, support chunkwise causal convolution.  This should
-          not hurt WER as no modeling power is lost, but the convolution modules will be
-          slightly slower and use more memory.  Enables use of the chunk_size and
-          left_context_chunks options in forward(), which simulates streaming
-          decoding.
-        chunk_size: (list of int): only set this to other than [-1] if causal;
-           the chunk size will be randomly chosen from this list.  -1 means no chunking.
-        left_context_frames: (list of int): determines the number of left-
-           context chunks for causal training; will be rounded to a number of
-           chunks.  Must not be less than cnn_module_kernel (after factoring in
-           rounding and downsampling); an error will be thrown if this is violated.
-    """
-
-    def __init__(
-        self,
-        output_downsampling_factor: int = 2,
-        downsampling_factor: Tuple[int] = (2, 4),
-        encoder_dim: Union[int, Tuple[int]] = 384,
-        num_encoder_layers: Union[int, Tuple[int]] = 4,
-        encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
-        query_head_dim: Union[int, Tuple[int]] = 24,
-        pos_head_dim: Union[int, Tuple[int]] = 4,
-        value_head_dim: Union[int, Tuple[int]] = 12,
-        num_heads: Union[int, Tuple[int]] = 8,
-        feedforward_dim: Union[int, Tuple[int]] = 1536,
-        cnn_module_kernel: Union[int, Tuple[int]] = 31,
-        pos_dim: int = 192,
-        dropout: FloatLike = None,  # see code below for default
-        warmup_batches: float = 4000.0,
-        causal: bool = False,
-        chunk_size: Tuple[int] = [-1],
-        left_context_frames: Tuple[int] = [-1],
-    ) -> None:
-        super(Zipformer2, self).__init__()
-
-        if dropout is None:
-            dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
-
-        def _to_tuple(x):
-            """Converts a single int or a 1-tuple of an int to a tuple with the same length
-            as downsampling_factor"""
-            if isinstance(x, int):
-                x = (x,)
-            if len(x) == 1:
-                x = x * len(downsampling_factor)
-            else:
-                assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
-            return x
-
-        self.output_downsampling_factor = output_downsampling_factor  # int
-        self.downsampling_factor = downsampling_factor  # tuple
-        self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
-        self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
-            encoder_unmasked_dim
-        )  # tuple
-        num_encoder_layers = _to_tuple(num_encoder_layers)
-        self.num_encoder_layers = num_encoder_layers
-        self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
-        self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
-        pos_head_dim = _to_tuple(pos_head_dim)
-        self.num_heads = num_heads = _to_tuple(num_heads)
-        feedforward_dim = _to_tuple(feedforward_dim)
-        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
-
-        self.causal = causal
-        self.chunk_size = chunk_size
-        self.left_context_frames = left_context_frames
-
-        for u, d in zip(encoder_unmasked_dim, encoder_dim):
-            assert u <= d
-
-        # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
-        encoders = []
-
-        num_encoders = len(downsampling_factor)
-        for i in range(num_encoders):
-            encoder_layer = Zipformer2EncoderLayer(
-                embed_dim=encoder_dim[i],
-                pos_dim=pos_dim,
-                num_heads=num_heads[i],
-                query_head_dim=query_head_dim[i],
-                pos_head_dim=pos_head_dim[i],
-                value_head_dim=value_head_dim[i],
-                feedforward_dim=feedforward_dim[i],
-                dropout=dropout,
-                cnn_module_kernel=cnn_module_kernel[i],
-                causal=causal,
-            )
-
-            # For the segment of the warmup period, we let the Conv2dSubsampling
-            # layer learn something.  Then we start to warm up the other encoders.
-            encoder = Zipformer2Encoder(
-                encoder_layer,
-                num_encoder_layers[i],
-                pos_dim=pos_dim,
-                dropout=dropout,
-                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
-                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
-                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
-            )
-
-            if downsampling_factor[i] != 1:
-                encoder = DownsampledZipformer2Encoder(
-                    encoder,
-                    dim=encoder_dim[i],
-                    downsample=downsampling_factor[i],
-                    dropout=dropout,
-                )
-
-            encoders.append(encoder)
-
-        self.encoders = nn.ModuleList(encoders)
-
-        self.downsample_output = SimpleDownsample(
-            max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout
-        )
-
-    def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
-        """
-        In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
-        randomized feature masks, one per encoder.
-        On e.g. 15% of frames, these masks will zero out all encoder dims larger than
-        some supplied number, e.g. >256, so in effect on those frames we are using
-        a smaller encoder dim.
-
-        We generate the random masks at this level because we want the 2 masks to 'agree'
-        all the way up the encoder stack. This will mean that the 1st mask will have
-        mask values repeated self.zipformer_subsampling_factor times.
-
-        Args:
-           x: the embeddings (needed for the shape and dtype and device), of shape
-             (1, batch_size, encoder_dims0)
-        """
-        num_encoders = len(self.encoder_dim)
-        if not self.training:
-            return [1.0] * num_encoders
-
-        (num_frames0, batch_size, _encoder_dims0) = x.shape
-
-        assert self.encoder_dim[0] == _encoder_dims0, (
-            self.encoder_dim[0],
-            _encoder_dims0,
-        )
-
-        feature_mask_dropout_prob = 0.125
-
-        # mask1 shape: (1, batch_size, 1)
-        mask1 = (
-            torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob
-        ).to(x.dtype)
-
-        # mask2 has additional sequences masked, about twice the number.
-        mask2 = torch.logical_and(
-            mask1,
-            (
-                torch.rand(1, batch_size, 1, device=x.device)
-                > feature_mask_dropout_prob
-            ).to(x.dtype),
-        )
-
-        # dim: (1, batch_size, 2)
-        mask = torch.cat((mask1, mask2), dim=-1)
-
-        feature_masks = []
-        for i in range(num_encoders):
-            channels = self.encoder_dim[i]
-            feature_mask = torch.ones(
-                1, batch_size, channels, dtype=x.dtype, device=x.device
-            )
-            u1 = self.encoder_unmasked_dim[i]
-            u2 = u1 + (channels - u1) // 2
-
-            feature_mask[:, :, u1:u2] *= mask[..., 0:1]
-            feature_mask[:, :, u2:] *= mask[..., 1:2]
-
-            feature_masks.append(feature_mask)
-
-        return feature_masks
-
-    def get_chunk_info(self) -> Tuple[int, int]:
-        """
-        Returns chunk_size and left_context_chunks.
-        """
-        if not self.causal:
-            return -1, -1
-
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            assert len(self.chunk_size) == 1, self.chunk_size
-            chunk_size = self.chunk_size[0]
-        else:
-            chunk_size = random.choice(self.chunk_size)
-
-        if chunk_size == -1:
-            left_context_chunks = -1
-        else:
-            if torch.jit.is_scripting() or torch.jit.is_tracing():
-                assert len(self.left_context_frames) == 1, self.left_context_frames
-                left_context_frames = self.left_context_frames[0]
-            else:
-                left_context_frames = random.choice(self.left_context_frames)
-            # Note: in Python, -1 // n == -1 for n > 0
-            left_context_chunks = left_context_frames // chunk_size
-            if left_context_chunks == 0:
-                left_context_chunks = 1
-
-        return chunk_size, left_context_chunks
-
-    def forward(
-        self,
-        x: Tensor,
-        x_lens: Tensor,
-        src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Tensor]:
-        """
-        Args:
-          x:
-            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
-          x_lens:
-            A tensor of shape (batch_size,) containing the number of frames in
-            `x` before padding.
-          src_key_padding_mask:
-            The mask for padding, of shape (batch_size, seq_len); True means
-            masked position. May be None.
-        Returns:
-          Return a tuple containing 2 tensors:
-            - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
-            - lengths, a tensor of shape (batch_size,) containing the number
-              of frames in `embeddings` before padding.
-        """
-        outputs = []
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            feature_masks = [1.0] * len(self.encoder_dim)
-        else:
-            feature_masks = self.get_feature_masks(x)
-
-        chunk_size, left_context_chunks = self.get_chunk_info()
-
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            # Not support exporting a model for simulating streaming decoding
-            attn_mask = None
-        else:
-            attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
-
-        for i, module in enumerate(self.encoders):
-            ds = self.downsampling_factor[i]
-            x = convert_num_channels(x, self.encoder_dim[i])
-
-            x = module(
-                x,
-                chunk_size=chunk_size,
-                feature_mask=feature_masks[i],
-                src_key_padding_mask=(
-                    None
-                    if src_key_padding_mask is None
-                    else src_key_padding_mask[..., ::ds]
-                ),
-                attn_mask=attn_mask,
-            )
-            outputs.append(x)
-
-        # if the last output has the largest dimension, x will be unchanged,
-        # it will be the same as outputs[-1].  Otherwise it will be concatenated
-        # from different pieces of 'outputs', taking each dimension from the
-        # most recent output that has it present.
-        x = self._get_full_dim_output(outputs)
-        x = self.downsample_output(x)
-        # class Downsample has this rounding behavior..
-        assert self.output_downsampling_factor == 2, self.output_downsampling_factor
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            lengths = (x_lens + 1) // 2
-        else:
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore")
-                lengths = (x_lens + 1) // 2
-
-        return x, lengths
-
-    def _get_attn_mask(
-        self, x: Tensor, chunk_size: int, left_context_chunks: int
-    ) -> Optional[Tensor]:
-        """
-        Return None if chunk_size == -1, else return attention mask of shape
-          (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len).  True
-           means a masked position.
-        Args:
-           x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
-          chunk_size: chunk size, must divide
-        """
-        if chunk_size <= 0:
-            return None
-        assert all(chunk_size % d == 0 for d in self.downsampling_factor)
-        if left_context_chunks >= 0:
-            num_encoders = len(self.encoder_dim)
-            assert all(
-                chunk_size * left_context_chunks
-                >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
-                for i in range(num_encoders)
-            )
-        else:
-            left_context_chunks = 1000000
-
-        seq_len = x.shape[0]
-
-        # t is frame index, shape (seq_len,)
-        t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
-        # c is chunk index for each frame, shape (seq_len,)
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            c = t // chunk_size
-        else:
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore")
-                c = t // chunk_size
-        src_c = c
-        tgt_c = c.unsqueeze(-1)
-
-        attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
-        if __name__ == "__main__":
-            logging.info(f"attn_mask = {attn_mask}")
-        return attn_mask
-
-    def _get_full_dim_output(self, outputs: List[Tensor]):
-        num_encoders = len(self.encoder_dim)
-        assert len(outputs) == num_encoders
-        output_dim = max(self.encoder_dim)
-        output_pieces = [outputs[-1]]
-        cur_dim = self.encoder_dim[-1]
-        for i in range(num_encoders - 2, -1, -1):
-            d = self.encoder_dim[i]
-            if d > cur_dim:
-                this_output = outputs[i]
-                output_pieces.append(this_output[..., cur_dim:d])
-                cur_dim = d
-        assert cur_dim == output_dim
-        return torch.cat(output_pieces, dim=-1)
-
-    def streaming_forward(
-        self,
-        x: Tensor,
-        x_lens: Tensor,
-        states: List[Tensor],
-        src_key_padding_mask: Tensor,
-    ) -> Tuple[Tensor, Tensor, List[Tensor]]:
-        """
-        Args:
-          x:
-            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
-          x_lens:
-            A tensor of shape (batch_size,) containing the number of frames in
-            `x` before padding.
-          states: list of cached tensors of all encoder layers. For layer-i,
-            states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
-            cached_conv1, cached_conv2).
-          src_key_padding_mask:
-            The mask for padding, of shape (batch_size, seq_len); True means
-            masked position. May be None.
-        Returns:
-          Return a tuple containing 2 tensors:
-            - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
-            - lengths, a tensor of shape (batch_size,) containing the number
-              of frames in `embeddings` before padding.
-            - updated states
-        """
-        outputs = []
-        new_states = []
-        layer_offset = 0
-
-        for i, module in enumerate(self.encoders):
-            num_layers = module.num_layers
-            ds = self.downsampling_factor[i]
-            x = convert_num_channels(x, self.encoder_dim[i])
-
-            x, new_layer_states = module.streaming_forward(
-                x,
-                states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
-                left_context_len=self.left_context_frames[0] // ds,
-                src_key_padding_mask=src_key_padding_mask[..., ::ds],
-            )
-            layer_offset += num_layers
-            outputs.append(x)
-            new_states += new_layer_states
-
-        # if the last output has the largest dimension, x will be unchanged,
-        # it will be the same as outputs[-1].  Otherwise it will be concatenated
-        # from different pieces of 'outputs', taking each dimension from the
-        # most recent output that has it present.
-        x = self._get_full_dim_output(outputs)
-        x = self.downsample_output(x)
-        # class Downsample has this rounding behavior..
-        assert self.output_downsampling_factor == 2
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            lengths = (x_lens + 1) // 2
-        else:
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore")
-                lengths = (x_lens + 1) // 2
-
-        return x, lengths, new_states
-
-    @torch.jit.export
-    def get_init_states(
-        self,
-        batch_size: int = 1,
-        device: torch.device = torch.device("cpu"),
-    ) -> List[Tensor]:
-        """Get initial states.
-
-        A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
-        is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
-        """
-        states = []
-        for i, module in enumerate(self.encoders):
-            num_layers = module.num_layers
-            embed_dim = self.encoder_dim[i]
-            ds = self.downsampling_factor[i]
-            num_heads = self.num_heads[i]
-            key_dim = self.query_head_dim[i] * num_heads
-            value_dim = self.value_head_dim[i] * num_heads
-            downsample_left = self.left_context_frames[0] // ds
-            nonlin_attn_head_dim = 3 * embed_dim // 4
-            conv_left_pad = self.cnn_module_kernel[i] // 2
-            for layer in range(num_layers):
-                cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(
-                    device
-                )
-                cached_nonlin_attn = torch.zeros(
-                    1, batch_size, downsample_left, nonlin_attn_head_dim
-                ).to(device)
-                cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(
-                    device
-                )
-                cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(
-                    device
-                )
-                cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
-                    device
-                )
-                cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
-                    device
-                )
-                states += [
-                    cached_key,
-                    cached_nonlin_attn,
-                    cached_val1,
-                    cached_val2,
-                    cached_conv1,
-                    cached_conv2,
-                ]
-
-        return states
-
-
-def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
-    return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x)
-
-
-def _balancer_schedule(min_prob: float):
-    return ScheduledFloat((0.0, 0.4), (8000.0, min_prob))
-
-
-class Zipformer2EncoderLayer(nn.Module):
-    """
-    Args:
-        embed_dim: the number of expected features in the input (required).
-        nhead: the number of heads in the multiheadattention models (required).
-        feedforward_dim: the dimension of the feedforward network model (default=2048).
-        dropout: the dropout value (default=0.1).
-        cnn_module_kernel (int): Kernel size of convolution module.
-
-    Examples::
-        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
-        >>> src = torch.rand(10, 32, 512)
-        >>> pos_emb = torch.rand(32, 19, 512)
-        >>> out = encoder_layer(src, pos_emb)
-    """
-
-    def __init__(
-        self,
-        embed_dim: int,
-        pos_dim: int,
-        num_heads: int,
-        query_head_dim: int,
-        pos_head_dim: int,
-        value_head_dim: int,
-        feedforward_dim: int,
-        dropout: FloatLike = 0.1,
-        cnn_module_kernel: int = 31,
-        causal: bool = False,
-        attention_skip_rate: FloatLike = ScheduledFloat(
-            (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
-        ),
-        conv_skip_rate: FloatLike = ScheduledFloat(
-            (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
-        ),
-        const_attention_rate: FloatLike = ScheduledFloat(
-            (0.0, 0.25), (4000.0, 0.025), default=0
-        ),
-        ff2_skip_rate: FloatLike = ScheduledFloat(
-            (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
-        ),
-        ff3_skip_rate: FloatLike = ScheduledFloat(
-            (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
-        ),
-        bypass_skip_rate: FloatLike = ScheduledFloat(
-            (0.0, 0.5), (4000.0, 0.02), default=0
-        ),
-    ) -> None:
-        super(Zipformer2EncoderLayer, self).__init__()
-        self.embed_dim = embed_dim
-
-        # self.bypass implements layer skipping as well as bypass; see its default values.
-        self.bypass = BypassModule(
-            embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
-        )
-        # bypass_mid is bypass used in the middle of the layer.
-        self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
-
-        # skip probability for dynamic modules (meaning: anything but feedforward).
-        self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
-        # an additional skip probability that applies to ConvModule to stop it from
-        # contributing too much early on.
-        self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
-
-        # ff2_skip_rate is to prevent the ff2 module from having output that's too big
-        # compared to its residual.
-        self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
-        self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
-
-        self.const_attention_rate = copy.deepcopy(const_attention_rate)
-
-        self.self_attn_weights = RelPositionMultiheadAttentionWeights(
-            embed_dim,
-            pos_dim=pos_dim,
-            num_heads=num_heads,
-            query_head_dim=query_head_dim,
-            pos_head_dim=pos_head_dim,
-            dropout=0.0,
-        )
-
-        self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
-
-        self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)
-
-        self.feed_forward1 = FeedforwardModule(
-            embed_dim, (feedforward_dim * 3) // 4, dropout
-        )
-
-        self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)
-
-        self.feed_forward3 = FeedforwardModule(
-            embed_dim, (feedforward_dim * 5) // 4, dropout
-        )
-
-        self.nonlin_attention = NonlinAttention(
-            embed_dim, hidden_channels=3 * embed_dim // 4
-        )
-
-        self.conv_module1 = ConvolutionModule(
-            embed_dim, cnn_module_kernel, causal=causal
-        )
-
-        self.conv_module2 = ConvolutionModule(
-            embed_dim, cnn_module_kernel, causal=causal
-        )
-
-        # TODO: remove it
-        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
-
-        self.norm = BiasNorm(embed_dim)
-
-        self.balancer1 = Balancer(
-            embed_dim,
-            channel_dim=-1,
-            min_positive=0.45,
-            max_positive=0.55,
-            min_abs=0.2,
-            max_abs=4.0,
-        )
-
-        # balancer for output of NonlinAttentionModule
-        self.balancer_na = Balancer(
-            embed_dim,
-            channel_dim=-1,
-            min_positive=0.3,
-            max_positive=0.7,
-            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
-            prob=0.05,  # out of concern for memory usage
-        )
-
-        # balancer for output of feedforward2, prevent it from staying too
-        # small.  give this a very small probability, even at the start of
-        # training, it's to fix a rare problem and it's OK to fix it slowly.
-        self.balancer_ff2 = Balancer(
-            embed_dim,
-            channel_dim=-1,
-            min_positive=0.3,
-            max_positive=0.7,
-            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
-            max_abs=2.0,
-            prob=0.05,
-        )
-
-        self.balancer_ff3 = Balancer(
-            embed_dim,
-            channel_dim=-1,
-            min_positive=0.3,
-            max_positive=0.7,
-            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
-            max_abs=4.0,
-            prob=0.05,
-        )
-
-        self.whiten = Whiten(
-            num_groups=1,
-            whitening_limit=_whitening_schedule(4.0, ratio=3.0),
-            prob=(0.025, 0.25),
-            grad_scale=0.01,
-        )
-
-        self.balancer2 = Balancer(
-            embed_dim,
-            channel_dim=-1,
-            min_positive=0.45,
-            max_positive=0.55,
-            min_abs=0.1,
-            max_abs=4.0,
-        )
-
-    def get_sequence_dropout_mask(
-        self, x: Tensor, dropout_rate: float
-    ) -> Optional[Tensor]:
-        if (
-            dropout_rate == 0.0
-            or not self.training
-            or torch.jit.is_scripting()
-            or torch.jit.is_tracing()
-        ):
-            return None
-        batch_size = x.shape[1]
-        mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
-        return mask
-
-    def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
-        """
-        Apply sequence-level dropout to x.
-        x shape: (seq_len, batch_size, embed_dim)
-        """
-        dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
-        if dropout_mask is None:
-            return x
-        else:
-            return x * dropout_mask
-
-    def forward(
-        self,
-        src: Tensor,
-        pos_emb: Tensor,
-        chunk_size: int = -1,
-        attn_mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        """
-            Pass the input through the encoder layer.
-            Args:
-                src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-             pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
-             chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
-           feature_mask: something that broadcasts with src, that we'll multiply `src`
-                  by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
-             attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
-                    interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-                   True means masked position. May be None.
-        src_key_padding_mask:  the mask for padding, of shape (batch_size, seq_len); True means
-                 masked position.  May be None.
-
-            Returns:
-               A tensor which has the same shape as src
-        """
-        src_orig = src
-
-        # dropout rate for non-feedforward submodules
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            attention_skip_rate = 0.0
-        else:
-            attention_skip_rate = (
-                float(self.attention_skip_rate) if self.training else 0.0
-            )
-
-        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-        attn_weights = self.self_attn_weights(
-            src,
-            pos_emb=pos_emb,
-            attn_mask=attn_mask,
-            key_padding_mask=src_key_padding_mask,
-        )
-
-        src = src + self.feed_forward1(src)
-
-        self_attn_dropout_mask = self.get_sequence_dropout_mask(
-            src, attention_skip_rate
-        )
-
-        selected_attn_weights = attn_weights[0:1]
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            pass
-        elif not self.training and random.random() < float(self.const_attention_rate):
-            # Make attention weights constant.  The intention is to
-            # encourage these modules to do something similar to an
-            # averaging-over-time operation.
-            # only need the mask, can just use the 1st one and expand later
-            selected_attn_weights = selected_attn_weights[0:1]
-            selected_attn_weights = (selected_attn_weights > 0.0).to(
-                selected_attn_weights.dtype
-            )
-            selected_attn_weights = selected_attn_weights * (
-                1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
-            )
-
-        na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
-
-        src = src + (
-            na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
-        )
-
-        self_attn = self.self_attn1(src, attn_weights)
-
-        src = src + (
-            self_attn
-            if self_attn_dropout_mask is None
-            else self_attn * self_attn_dropout_mask
-        )
-
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            conv_skip_rate = 0.0
-        else:
-            conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
-        src = src + self.sequence_dropout(
-            self.conv_module1(
-                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
-            ),
-            conv_skip_rate,
-        )
-
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            ff2_skip_rate = 0.0
-        else:
-            ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
-        src = src + self.sequence_dropout(
-            self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
-        )
-
-        # bypass in the middle of the layer.
-        src = self.bypass_mid(src_orig, src)
-
-        self_attn = self.self_attn2(src, attn_weights)
-
-        src = src + (
-            self_attn
-            if self_attn_dropout_mask is None
-            else self_attn * self_attn_dropout_mask
-        )
-
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            conv_skip_rate = 0.0
-        else:
-            conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
-        src = src + self.sequence_dropout(
-            self.conv_module2(
-                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
-            ),
-            conv_skip_rate,
-        )
-
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            ff3_skip_rate = 0.0
-        else:
-            ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
-        src = src + self.sequence_dropout(
-            self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
-        )
-
-        src = self.balancer1(src)
-        src = self.norm(src)
-
-        src = self.bypass(src_orig, src)
-
-        src = self.balancer2(src)
-        src = self.whiten(src)
-
-        return src
-
-    def streaming_forward(
-        self,
-        src: Tensor,
-        pos_emb: Tensor,
-        cached_key: Tensor,
-        cached_nonlin_attn: Tensor,
-        cached_val1: Tensor,
-        cached_val2: Tensor,
-        cached_conv1: Tensor,
-        cached_conv2: Tensor,
-        left_context_len: int,
-        src_key_padding_mask: Tensor,
-    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
-        """Pass the input through the encoder layer in streaming forward mode.
-
-        Args:
-            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-            pos_emb: (1, left_context_len+2*seq_len-1, pos_emb_dim) or
-              (batch_size, left_context_len+2*seq_len-1, pos_emb_dim)
-            cached_key: cached attention key tensor of left context,
-              of shape (left_context_len, batch_size, key_dim)
-            cached_nonlin_attn: left context for nonlin_attention module, a Tensor of shape
-              (num_heads, batch_size, left_context_len, head_dim)
-            cached_val1: cached left context for the first attention module,
-              of shape (left_context_len, batch_size, value_dim)
-            cached_val2: cached left context for the second attention module,
-              of shape (left_context_len, batch_size, value_dim)
-            cached_conv1: cached left context for the first convolution module,
-              of shape (batch_size, channels, left_pad)
-            cached_conv2: cached left context for the second convolution module,
-              of shape (batch_size, channels, left_pad)
-            left_context_len: number of left context frames.
-            src_key_padding_mask:  the mask for padding, of shape
-              (batch_size, left_context_len + seq_len); True means masked position.
-              May be None.
-
-        Returns:
-            - x, with the same shape as src
-            - updated cached_key
-            - updated cached_nonlin_attn
-            - updated cached_val1
-            - updated cached_val2
-            - updated cached_conv1
-            - updated cached_conv2
-        """
-        src_orig = src
-
-        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-        attn_weights, cached_key = self.self_attn_weights.streaming_forward(
-            src,
-            pos_emb=pos_emb,
-            cached_key=cached_key,
-            left_context_len=left_context_len,
-            key_padding_mask=src_key_padding_mask,
-        )
-
-        src = src + self.feed_forward1(src)
-
-        na, cached_nonlin_attn = self.nonlin_attention.streaming_forward(
-            src,
-            attn_weights[0:1],
-            cached_x=cached_nonlin_attn,
-            left_context_len=left_context_len,
-        )
-        src = src + na
-
-        self_attn, cached_val1 = self.self_attn1.streaming_forward(
-            src,
-            attn_weights=attn_weights,
-            cached_val=cached_val1,
-            left_context_len=left_context_len,
-        )
-        src = src + self_attn
-
-        src_conv, cached_conv1 = self.conv_module1.streaming_forward(
-            src,
-            cache=cached_conv1,
-            src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
-        )
-        src = src + src_conv
-
-        src = src + self.feed_forward2(src)
-
-        # bypass in the middle of the layer.
-        src = self.bypass_mid(src_orig, src)
-
-        self_attn, cached_val2 = self.self_attn2.streaming_forward(
-            src,
-            attn_weights=attn_weights,
-            cached_val=cached_val2,
-            left_context_len=left_context_len,
-        )
-        src = src + self_attn
-
-        src_conv, cached_conv2 = self.conv_module2.streaming_forward(
-            src,
-            cache=cached_conv2,
-            src_key_padding_mask=src_key_padding_mask[:, left_context_len:],
-        )
-        src = src + src_conv
-
-        src = src + self.feed_forward3(src)
-
-        src = self.norm(src)
-
-        src = self.bypass(src_orig, src)
-
-        return (
-            src,
-            cached_key,
-            cached_nonlin_attn,
-            cached_val1,
-            cached_val2,
-            cached_conv1,
-            cached_conv2,
-        )
-
-
-class Zipformer2Encoder(nn.Module):
-    r"""Zipformer2Encoder is a stack of N encoder layers
-
-    Args:
-        encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
-        num_layers: the number of sub-encoder-layers in the encoder (required).
-       pos_dim: the dimension for the relative positional encoding
-
-    Examples::
-        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
-        >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6)
-        >>> src = torch.rand(10, 32, 512)
-        >>> out = zipformer_encoder(src)
-    """
-
-    def __init__(
-        self,
-        encoder_layer: nn.Module,
-        num_layers: int,
-        pos_dim: int,
-        dropout: float,
-        warmup_begin: float,
-        warmup_end: float,
-        initial_layerdrop_rate: float = 0.5,
-        final_layerdrop_rate: float = 0.05,
-    ) -> None:
-        super().__init__()
-        self.encoder_pos = CompactRelPositionalEncoding(
-            pos_dim, dropout_rate=0.15, length_factor=1.0
-        )
-
-        self.layers = nn.ModuleList(
-            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
-        )
-        self.num_layers = num_layers
-
-        assert 0 <= warmup_begin <= warmup_end
-
-        delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
-        cur_begin = warmup_begin  # interpreted as a training batch index
-        for i in range(num_layers):
-            cur_end = cur_begin + delta
-            self.layers[i].bypass.skip_rate = ScheduledFloat(
-                (cur_begin, initial_layerdrop_rate),
-                (cur_end, final_layerdrop_rate),
-                default=0.0,
-            )
-            cur_begin = cur_end
-
-    def forward(
-        self,
-        src: Tensor,
-        chunk_size: int = -1,
-        feature_mask: Union[Tensor, float] = 1.0,
-        attn_mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        r"""Pass the input through the encoder layers in turn.
-
-        Args:
-            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-            chunk_size: the number of frames per chunk, if >= 0; if -1, no chunking.
-            feature_mask: something that broadcasts with src, that we'll multiply `src`
-               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
-            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
-                 interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-                 True means masked position. May be None.
-            src_key_padding_mask:  the mask for padding, of shape (batch_size, seq_len); True means
-                 masked position.  May be None.
-
-        Returns: a Tensor with the same shape as src.
-        """
-        pos_emb = self.encoder_pos(src)
-        output = src
-
-        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-            output = output * feature_mask
-
-        for i, mod in enumerate(self.layers):
-            output = mod(
-                output,
-                pos_emb,
-                chunk_size=chunk_size,
-                attn_mask=attn_mask,
-                src_key_padding_mask=src_key_padding_mask,
-            )
-
-            if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-                output = output * feature_mask
-
-        return output
-
-    def streaming_forward(
-        self,
-        src: Tensor,
-        states: List[Tensor],
-        left_context_len: int,
-        src_key_padding_mask: Tensor,
-    ) -> Tuple[Tensor, List[Tensor]]:
-        r"""Pass the input through the encoder layers in turn.
-
-        Args:
-            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-            states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
-              (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
-            left_context_len: Number of left context frames.
-            src_key_padding_mask:  the mask for padding, of shape
-              (batch_size, left_context_len + seq_len); True means masked position.
-              May be None.
-
-        Returns:
-          - output, a Tensor with the same shape as src.
-          - updated states
-        """
-        pos_emb = self.encoder_pos(src, left_context_len)
-        output = src
-
-        new_states = []
-        for i, mod in enumerate(self.layers):
-            (
-                cached_key,
-                cached_nonlin_attn,
-                cached_val1,
-                cached_val2,
-                cached_conv1,
-                cached_conv2,
-            ) = states[i * 6 : (i + 1) * 6]
-            (
-                output,
-                new_cached_key,
-                new_cached_nonlin_attn,
-                new_cached_val1,
-                new_cached_val2,
-                new_cached_conv1,
-                new_cached_conv2,
-            ) = mod.streaming_forward(
-                output,
-                pos_emb,
-                cached_key=cached_key,
-                cached_nonlin_attn=cached_nonlin_attn,
-                cached_val1=cached_val1,
-                cached_val2=cached_val2,
-                cached_conv1=cached_conv1,
-                cached_conv2=cached_conv2,
-                left_context_len=left_context_len,
-                src_key_padding_mask=src_key_padding_mask,
-            )
-            new_states += [
-                new_cached_key,
-                new_cached_nonlin_attn,
-                new_cached_val1,
-                new_cached_val2,
-                new_cached_conv1,
-                new_cached_conv2,
-            ]
-
-        return output, new_states
-
-
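-# Illustrative sketch (added for clarity, not part of the model code): the flat
-# `states` list consumed by Zipformer2Encoder.streaming_forward() holds 6 cached
-# tensors per layer, sliced as states[i * 6 : (i + 1) * 6].  Empty tensors are used
-# here as placeholders; the real cache shapes depend on the model configuration.
-def _sketch_streaming_state_layout(num_layers: int = 3) -> None:
-    names = [
-        "cached_key",
-        "cached_nonlin_attn",
-        "cached_val1",
-        "cached_val2",
-        "cached_conv1",
-        "cached_conv2",
-    ]
-    states = [torch.empty(0) for _ in range(6 * num_layers)]
-    for i in range(num_layers):
-        layer_state = states[i * 6 : (i + 1) * 6]
-        assert len(layer_state) == len(names)
-
-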
-class BypassModule(nn.Module):
-    """
-    An nn.Module that implements a learnable bypass scale, and also randomized per-sequence
-    layer-skipping.  The bypass is limited during early stages of training to be close to
-    "straight-through", i.e. to not do the bypass operation much initially, in order to
-    force all the modules to learn something.
-    """
-
-    def __init__(
-        self,
-        embed_dim: int,
-        skip_rate: FloatLike = 0.0,
-        straight_through_rate: FloatLike = 0.0,
-        scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
-        scale_max: FloatLike = 1.0,
-    ):
-        super().__init__()
-        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
-        self.skip_rate = copy.deepcopy(skip_rate)
-        self.straight_through_rate = copy.deepcopy(straight_through_rate)
-        self.scale_min = copy.deepcopy(scale_min)
-        self.scale_max = copy.deepcopy(scale_max)
-
-    def _get_bypass_scale(self, batch_size: int):
-        # returns bypass-scale of shape (num_channels,),
-        # or (batch_size, num_channels,).  This is actually the
-        # scale on the non-residual term, so 0 corresponds to bypassing
-        # this module.
-        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
-            return self.bypass_scale
-        else:
-            ans = limit_param_value(
-                self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max)
-            )
-            skip_rate = float(self.skip_rate)
-            if skip_rate != 0.0:
-                mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
-                ans = ans * mask
-                # now ans is of shape (batch_size, num_channels), and is zero for sequences
-                # on which we have randomly chosen to do layer-skipping.
-            straight_through_rate = float(self.straight_through_rate)
-            if straight_through_rate != 0.0:
-                mask = (
-                    torch.rand((batch_size, 1), device=ans.device)
-                    < straight_through_rate
-                )
-                ans = torch.maximum(ans, mask.to(ans.dtype))
-            return ans
-
-    def forward(self, src_orig: Tensor, src: Tensor):
-        """
-        Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
-        Returns: something with the same shape as src and src_orig
-        """
-        bypass_scale = self._get_bypass_scale(src.shape[1])
-        return src_orig + (src - src_orig) * bypass_scale
-
-
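-# Illustrative sketch (added for clarity, not part of the model code): the bypass
-# combination used by BypassModule.forward() is a per-channel interpolation
-# src_orig + (src - src_orig) * bypass_scale, so a scale of 0.0 bypasses the module
-# entirely and a scale of 1.0 keeps the module output unchanged.
-def _sketch_bypass_interpolation() -> None:
-    seq_len, batch_size, num_channels = 4, 2, 8
-    src_orig = torch.randn(seq_len, batch_size, num_channels)
-    src = torch.randn(seq_len, batch_size, num_channels)
-
-    def combine(scale: Tensor) -> Tensor:
-        return src_orig + (src - src_orig) * scale
-
-    assert torch.allclose(combine(torch.zeros(num_channels)), src_orig)
-    assert torch.allclose(combine(torch.ones(num_channels)), src)
-
-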
-class DownsampledZipformer2Encoder(nn.Module):
-    r"""
-    DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
-    after downsampling, and then upsampled again at the output, and combined
-    with the original input, so that the output has the same shape as the input.
-    """
-
-    def __init__(
-        self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike
-    ):
-        super(DownsampledZipformer2Encoder, self).__init__()
-        self.downsample_factor = downsample
-        self.downsample = SimpleDownsample(dim, downsample, dropout)
-        self.num_layers = encoder.num_layers
-        self.encoder = encoder
-        self.upsample = SimpleUpsample(dim, downsample)
-        self.out_combiner = BypassModule(dim, straight_through_rate=0)
-
-    def forward(
-        self,
-        src: Tensor,
-        chunk_size: int = -1,
-        feature_mask: Union[Tensor, float] = 1.0,
-        attn_mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        r"""Downsample, go through encoder, upsample.
-
-        Args:
-            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-            feature_mask: something that broadcasts with src, that we'll multiply `src`
-               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
-            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
-                 interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-                 True means masked position. May be None.
-            src_key_padding_mask:  the mask for padding, of shape (batch_size, seq_len); True means
-                 masked position.  May be None.
-
-        Returns: a Tensor with the same shape as src.
-        """
-        src_orig = src
-        src = self.downsample(src)
-        ds = self.downsample_factor
-        if attn_mask is not None:
-            attn_mask = attn_mask[::ds, ::ds]
-
-        src = self.encoder(
-            src,
-            chunk_size=chunk_size // ds,
-            feature_mask=feature_mask,
-            attn_mask=attn_mask,
-            src_key_padding_mask=src_key_padding_mask,
-        )
-        src = self.upsample(src)
-        # remove any extra frames that are not a multiple of downsample_factor
-        src = src[: src_orig.shape[0]]
-
-        return self.out_combiner(src_orig, src)
-
-    def streaming_forward(
-        self,
-        src: Tensor,
-        states: List[Tensor],
-        left_context_len: int,
-        src_key_padding_mask: Tensor,
-    ) -> Tuple[Tensor, List[Tensor]]:
-        r"""Downsample, go through encoder, upsample, in streaming forward mode.
-
-        Args:
-            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-            states: list of cached tensors of N encoder layers. For layer-i, states[i*6:(i+1)*6] is
-              (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
-            left_context_len: Number of left context frames.
-            src_key_padding_mask: the mask for padding, of shape (batch_size, left_context_len+seq_len);
-              True means masked position. May be None.
-
-        Returns:
-            - output, a Tensor with the same shape as src.
-            - updated states
-        """
-        src_orig = src
-        src = self.downsample(src)
-
-        src, new_states = self.encoder.streaming_forward(
-            src,
-            states=states,
-            left_context_len=left_context_len,
-            src_key_padding_mask=src_key_padding_mask,
-        )
-        src = self.upsample(src)
-        # remove any extra frames that are not a multiple of downsample_factor
-        src = src[: src_orig.shape[0]]
-
-        return self.out_combiner(src_orig, src), new_states
-
-
-class SimpleDownsample(torch.nn.Module):
-    """
-    Does downsampling by taking a learned softmax-weighted sum over each group of `downsample` frames.
-    """
-
-    def __init__(self, channels: int, downsample: int, dropout: FloatLike):
-        super(SimpleDownsample, self).__init__()
-
-        self.bias = nn.Parameter(torch.zeros(downsample))
-
-        self.name = None  # will be set from training code
-        self.dropout = copy.deepcopy(dropout)
-
-        self.downsample = downsample
-
-    def forward(self, src: Tensor) -> Tensor:
-        """
-        x: (seq_len, batch_size, in_channels)
-        Returns a tensor of shape
-           ( (seq_len+downsample-1)//downsample, batch_size, channels)
-        """
-        (seq_len, batch_size, in_channels) = src.shape
-        ds = self.downsample
-        d_seq_len = (seq_len + ds - 1) // ds
-
-        # Pad to an exact multiple of self.downsample
-        # right-pad src, repeating the last element.
-        pad = d_seq_len * ds - seq_len
-        src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
-        src = torch.cat((src, src_extra), dim=0)
-        assert src.shape[0] == d_seq_len * ds
-
-        src = src.reshape(d_seq_len, ds, batch_size, in_channels)
-
-        weights = self.bias.softmax(dim=0)
-        # weights: (downsample, 1, 1)
-        weights = weights.unsqueeze(-1).unsqueeze(-1)
-
-        # weighted sum over the `ds` frames within each group
-        ans = (src * weights).sum(dim=1)
-
-        return ans
-
-
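-# Illustrative sketch (added for clarity; sizes are made up): SimpleDownsample pads
-# the sequence to a multiple of `downsample` by repeating the last frame, then takes
-# a softmax-weighted sum over each group of `downsample` frames, so the output length
-# is (seq_len + downsample - 1) // downsample.
-def _sketch_simple_downsample() -> None:
-    seq_len, batch_size, channels, ds = 7, 2, 8, 2
-    m = SimpleDownsample(channels, downsample=ds, dropout=0.0)
-    src = torch.randn(seq_len, batch_size, channels)
-    out = m(src)
-    assert out.shape == ((seq_len + ds - 1) // ds, batch_size, channels)
-
-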
-class SimpleUpsample(torch.nn.Module):
-    """
-    A very simple form of upsampling that just repeats each input frame `upsample`
-    times along the time axis.
-    """
-
-    def __init__(self, num_channels: int, upsample: int):
-        super(SimpleUpsample, self).__init__()
-        self.upsample = upsample
-
-    def forward(self, src: Tensor) -> Tensor:
-        """
-        x: (seq_len, batch_size, num_channels)
-        Returns a tensor of shape
-           ( (seq_len*upsample), batch_size, num_channels)
-        """
-        upsample = self.upsample
-        (seq_len, batch_size, num_channels) = src.shape
-        src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
-        src = src.reshape(seq_len * upsample, batch_size, num_channels)
-        return src
-
-
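-# Illustrative sketch (added for clarity; sizes are made up): SimpleUpsample just
-# repeats every frame `upsample` times.  DownsampledZipformer2Encoder pairs it with
-# SimpleDownsample and then truncates to the original length, so the round trip
-# preserves the input shape.
-def _sketch_downsample_upsample_round_trip() -> None:
-    seq_len, batch_size, channels, ds = 7, 2, 8, 2
-    down = SimpleDownsample(channels, downsample=ds, dropout=0.0)
-    up = SimpleUpsample(channels, upsample=ds)
-    src = torch.randn(seq_len, batch_size, channels)
-    y = up(down(src))      # length rounded up to a multiple of ds
-    y = y[: src.shape[0]]  # truncate extra frames, as DownsampledZipformer2Encoder does
-    assert y.shape == src.shape
-
-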
-class CompactRelPositionalEncoding(torch.nn.Module):
-    """
-    Relative positional encoding module.  This version is "compact" meaning it is able to encode
-    the important information about the relative position in a relatively small number of dimensions.
-    The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001)
-    make very little difference to the embedding.   Such differences were potentially important
-    when encoding absolute position, but not important when encoding relative position because there
-    is now no need to compare two large offsets with each other.
-
-    Our embedding works by projecting the interval [-infinity,infinity] onto a finite interval
-    using the atan() function, before doing the Fourier transform of that fixed interval.  On its
-    own, the atan() function would compress the "long tails" too severely,
-    making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic
-    function to compress large offsets to a smaller range before applying atan().
-    Scalings are chosen in such a way that the embedding can clearly distinguish individual offsets as long
-    as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim)
-
-
-    Args:
-        embed_dim: Embedding dimension.
-        dropout_rate: Dropout rate.
-        max_len: Maximum input length: just a heuristic for initialization.
-        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
-           less weight to small differences of offset near the origin.
-    """
-
-    def __init__(
-        self,
-        embed_dim: int,
-        dropout_rate: FloatLike,
-        max_len: int = 1000,
-        length_factor: float = 1.0,
-    ) -> None:
-        """Construct a CompactRelPositionalEncoding object."""
-        super(CompactRelPositionalEncoding, self).__init__()
-        self.embed_dim = embed_dim
-        assert embed_dim % 2 == 0
-        self.dropout = Dropout2(dropout_rate)
-        self.pe = None
-        assert length_factor >= 1.0
-        self.length_factor = length_factor
-        self.extend_pe(torch.tensor(0.0).expand(max_len))
-
-    def extend_pe(self, x: Tensor, left_context_len: int = 0):
-        """Reset the positional encodings."""
-        T = x.size(0) + left_context_len
-
-        if self.pe is not None:
-            # self.pe contains both positive and negative parts
-            # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(0) >= T * 2 - 1:
-                # self.pe = self.pe.to(dtype=x.dtype, device=x.device)
-                return self.pe.to(dtype=x.dtype, device=x.device)
-
-        # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
-        x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
-
-        freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
-
-        # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution
-        # for small time offsets but less resolution for large time offsets.
-        compression_length = self.embed_dim**0.5
-        # x_compressed, like x, goes from -infinity to infinity as x goes from -infinity to infinity;
-        # but it does so more slowly than x for large absolute values of x.
-        # The formula is chosen so that d(x_compressed)/dx is 1 around x == 0, which
-        # is important.
-        x_compressed = (
-            compression_length
-            * x.sign()
-            * ((x.abs() + compression_length).log() - math.log(compression_length))
-        )
-
-        # if self.length_factor == 1.0, then length_scale is chosen so that the
-        # FFT can exactly separate points close to the origin (T == 0).  So this
-        # part of the formulation is not really heuristic.
-        # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
-        length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
-
-        # note for machine implementations: if atan is not available, we can use:
-        #   x.sign() * ((1 / (x.abs() + 1)) - 1)  * (-math.pi/2)
-        #  check on wolframalpha.com: plot(sign(x) *  (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
-        x_atan = (x_compressed / length_scale).atan()  # results between -pi and pi
-
-        cosines = (x_atan * freqs).cos()
-        sines = (x_atan * freqs).sin()
-
-        pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
-        pe[:, 0::2] = cosines
-        pe[:, 1::2] = sines
-        pe[:, -1] = 1.0  # for bias.
-
-        # self.pe = pe.to(dtype=x.dtype)
-        return pe.to(dtype=x.dtype)
-
-    def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
-        """Create positional encoding.
-
-        Args:
-            x (Tensor): Input tensor (time, batch, `*`).
-            left_context_len: (int): Length of cached left context.
-
-        Returns:
-            positional embedding, of shape (1, left_context_len + 2*time-1, `*`).
-        """
-        pe = self.extend_pe(x, left_context_len)
-        x_size_left = x.size(0) + left_context_len
-        # length of positive side: x.size(0) + left_context_len
-        # length of negative side: x.size(0)
-        pos_emb = pe[
-            pe.size(0) // 2
-            - x_size_left
-            + 1 : pe.size(0) // 2  # noqa E203
-            + x.size(0),
-            :,
-        ]
-        pos_emb = pos_emb.unsqueeze(0)
-        return self.dropout(pos_emb)
-
-
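-# Illustrative sketch (added for clarity): the offset compression used in
-# CompactRelPositionalEncoding.extend_pe() above -- a log-style squashing followed by
-# atan() -- maps offsets in (-inf, inf) to a bounded range while keeping roughly unit
-# slope near offset 0, so nearby offsets stay distinguishable.
-def _sketch_compact_rel_pos_compression(embed_dim: int = 64, length_factor: float = 1.0) -> None:
-    offsets = torch.tensor([-1000.0, -10.0, -1.0, 0.0, 1.0, 10.0, 1000.0])
-    compression_length = embed_dim**0.5
-    x_compressed = (
-        compression_length
-        * offsets.sign()
-        * ((offsets.abs() + compression_length).log() - math.log(compression_length))
-    )
-    length_scale = length_factor * embed_dim / (2.0 * math.pi)
-    x_atan = (x_compressed / length_scale).atan()  # bounded in (-pi/2, pi/2)
-    assert x_atan.abs().max() < math.pi / 2
-
-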
-class RelPositionMultiheadAttentionWeights(nn.Module):
-    r"""Module that computes multi-head attention weights with relative position encoding.
-    Various other modules consume the resulting attention weights: see, for example, the
-    SimpleAttention module which allows you to compute conventional attention.
-
-    This is quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context";
-    we still have to write up the differences.
-
-
-    Args:
-           embed_dim: number of channels at the input to this module, e.g. 256
-             pos_dim: dimension of the positional encoding vectors, e.g. 128.
-           num_heads:  number of heads to compute weights for, e.g. 8
-     query_head_dim: dimension of the query (and key), per head.  e.g. 24.
-       pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
-            dropout: dropout probability for attn_output_weights. Default: 0.0.
-       pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
-                     any given call to forward(), in training time.
-    """
-
-    def __init__(
-        self,
-        embed_dim: int,
-        pos_dim: int,
-        num_heads: int,
-        query_head_dim: int,
-        pos_head_dim: int,
-        dropout: float = 0.0,
-        pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
-    ) -> None:
-        super().__init__()
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-        self.query_head_dim = query_head_dim
-        self.pos_head_dim = pos_head_dim
-        self.dropout = dropout
-        self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
-        self.name = None  # will be overwritten in training code; for diagnostics.
-
-        key_head_dim = query_head_dim
-        in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
-
-        # the initial_scale is supposed to take over the "scaling" factor of
-        # head_dim ** -0.5 that has been used in previous forms of attention,
-        # dividing it between the query and key.   Note: this module is intended
-        # to be used with the ScaledAdam optimizer; with most other optimizers,
-        # it would be necessary to apply the scaling factor in the forward function.
-        self.in_proj = ScaledLinear(
-            embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
-        )
-
-        self.whiten_keys = Whiten(
-            num_groups=num_heads,
-            whitening_limit=_whitening_schedule(3.0),
-            prob=(0.025, 0.25),
-            grad_scale=0.025,
-        )
-
-        # add a balancer for the keys that runs with very small probability, and
-        # tries to enforce that all dimensions have mean around zero.  The
-        # weights produced by this module are invariant to adding a constant to
-        # the keys, so the derivative of the bias is mathematically zero; but
-        # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
-        # bias because the small numerical roundoff tends to have a non-random
-        # sign.  This module is intended to prevent that.  Use a very small
-        # probability; that should be sufficient to fix the problem.
-        self.balance_keys = Balancer(
-            key_head_dim * num_heads,
-            channel_dim=-1,
-            min_positive=0.4,
-            max_positive=0.6,
-            min_abs=0.0,
-            max_abs=100.0,
-            prob=0.025,
-        )
-
-        # linear transformation for positional encoding.
-        self.linear_pos = ScaledLinear(
-            pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
-        )
-
-        # the following are for diagnostics only, see --print-diagnostics option
-        self.copy_pos_query = Identity()
-        self.copy_query = Identity()
-
-    def forward(
-        self,
-        x: Tensor,
-        pos_emb: Tensor,
-        key_padding_mask: Optional[Tensor] = None,
-        attn_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        r"""
-        Args:
-            x: input of shape (seq_len, batch_size, embed_dim)
-            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
-            key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
-               are True in this mask will be ignored as sources in the attention weighting.
-            attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
-               interpreted as ([batch_size,] tgt_seq_len, src_seq_len)
-               saying which positions are allowed to attend to which other positions.
-        Returns:
-           a tensor of attention weights, of shape (num_heads, batch_size, seq_len, seq_len)
-           interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
-        """
-        x = self.in_proj(x)
-        query_head_dim = self.query_head_dim
-        pos_head_dim = self.pos_head_dim
-        num_heads = self.num_heads
-
-        seq_len, batch_size, _ = x.shape
-
-        query_dim = query_head_dim * num_heads
-
-        # self-attention
-        q = x[..., 0:query_dim]
-        k = x[..., query_dim : 2 * query_dim]
-        # p is the position-encoding query
-        p = x[..., 2 * query_dim :]
-        assert p.shape[-1] == num_heads * pos_head_dim
-
-        q = self.copy_query(q)  # for diagnostics only, does nothing.
-        k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
-        p = self.copy_pos_query(p)  # for diagnostics only, does nothing.
-
-        q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
-        p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
-        k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
-
-        # time1 refers to target, time2 refers to source.
-        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
-        p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
-        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
-
-        attn_scores = torch.matmul(q, k)
-
-        use_pos_scores = False
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            # We can't put the random.random() check in the same condition as the scripting/tracing check.
-            use_pos_scores = True
-        elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
-            use_pos_scores = True
-
-        if use_pos_scores:
-            pos_emb = self.linear_pos(pos_emb)
-            seq_len2 = 2 * seq_len - 1
-            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
-                2, 0, 3, 1
-            )
-            # pos_emb shape now: (head, {1 or batch_size}, pos_head_dim, seq_len2)
-
-            # (head, batch, time1, pos_head_dim) x (head, 1, pos_head_dim, seq_len2) -> (head, batch, time1, seq_len2)
-            #  [where seq_len2 represents relative position.]
-            pos_scores = torch.matmul(p, pos_emb)
-            # the following .as_strided() expression converts the last axis of pos_scores from relative
-            # to absolute position.  I don't know whether I might have got the time-offsets backwards or
-            # not, but let this code define which way round it is supposed to be.
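-            # (A standalone check of this relative->absolute conversion is sketched after this class.)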
-            if torch.jit.is_tracing():
-                (num_heads, batch_size, time1, n) = pos_scores.shape
-                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-                cols = torch.arange(seq_len)
-                rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
-                indexes = rows + cols
-                pos_scores = pos_scores.reshape(-1, n)
-                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-                pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
-            else:
-                pos_scores = pos_scores.as_strided(
-                    (num_heads, batch_size, seq_len, seq_len),
-                    (
-                        pos_scores.stride(0),
-                        pos_scores.stride(1),
-                        pos_scores.stride(2) - pos_scores.stride(3),
-                        pos_scores.stride(3),
-                    ),
-                    storage_offset=pos_scores.stride(3) * (seq_len - 1),
-                )
-
-            attn_scores = attn_scores + pos_scores
-
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            pass
-        elif self.training and random.random() < 0.1:
-            # This is a harder way of limiting the attention scores to not be
-            # too large.  It incurs a penalty if any of them has an absolute
-            # value greater than 25.0.  This should be outside the normal range
-            # of the attention scores.  We use this mechanism instead of, say,
-            # something added to the loss function involving the entropy,
-            # because once the entropy gets very small gradients through the
-            # softmax can become very small, and we'd get zero derivatives.  The
-            # choice of 1.0e-04 as the scale on the penalty makes this
-            # mechanism vulnerable to the absolute scale of the loss function,
-            # but we view this as a failsafe to avoid "implausible" parameter
-            # values rather than a regularization method that should be active
-            # under normal circumstances.
-            attn_scores = penalize_abs_values_gt(
-                attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
-            )
-
-        assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
-
-        if attn_mask is not None:
-            assert attn_mask.dtype == torch.bool
-            # use -1000 to avoid nan's where attn_mask and key_padding_mask make
-            # all scores zero.  It's important that this be large enough that exp(-1000)
-            # is exactly zero, for reasons related to const_attention_rate, which
-            # compares the final weights with zero.
-            attn_scores = attn_scores.masked_fill(attn_mask, -1000)
-
-        if key_padding_mask is not None:
-            assert key_padding_mask.shape == (
-                batch_size,
-                seq_len,
-            ), key_padding_mask.shape
-            attn_scores = attn_scores.masked_fill(
-                key_padding_mask.unsqueeze(1),
-                -1000,
-            )
-
-        # We use our own version of softmax, defined in scaling.py, which should
-        # save a little of the memory used in backprop: if we are in
-        # automatic mixed precision mode (amp / autocast), it only stores the
-        # half-precision output for backprop purposes.
-        attn_weights = softmax(attn_scores, dim=-1)
-
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            pass
-        elif random.random() < 0.001 and not self.training:
-            self._print_attn_entropy(attn_weights)
-
-        attn_weights = nn.functional.dropout(
-            attn_weights, p=self.dropout, training=self.training
-        )
-
-        return attn_weights
-
-    def streaming_forward(
-        self,
-        x: Tensor,
-        pos_emb: Tensor,
-        cached_key: Tensor,
-        left_context_len: int,
-        key_padding_mask: Tensor,
-    ) -> Tuple[Tensor, Tensor]:
-        r"""
-        Args:
-            x: input of shape (seq_len, batch_size, embed_dim)
-            pos_emb: Positional embedding tensor, of shape (1, left_context_len+2*seq_len-1, pos_dim)
-            cached_key: cached attention key tensor of left context,
-              of shape (left_context_len, batch_size, key_dim)
-            left_context_len: number of left context frames.
-            key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
-              are True in this mask will be ignored as sources in the attention weighting.
-
-        Returns:
-           - attention weights, of shape (num_heads, batch_size, seq_len, seq_len2),
-             interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
-           - updated cached attention key tensor of left context.
-        """
-        x = self.in_proj(x)
-        query_head_dim = self.query_head_dim
-        pos_head_dim = self.pos_head_dim
-        num_heads = self.num_heads
-
-        seq_len, batch_size, _ = x.shape
-
-        query_dim = query_head_dim * num_heads
-
-        # self-attention
-        q = x[..., 0:query_dim]
-        k = x[..., query_dim : 2 * query_dim]
-        # p is the position-encoding query
-        p = x[..., 2 * query_dim :]
-        assert p.shape[-1] == num_heads * pos_head_dim
-
-        # Pad cached left contexts
-        assert cached_key.shape[0] == left_context_len, (
-            cached_key.shape[0],
-            left_context_len,
-        )
-        k = torch.cat([cached_key, k], dim=0)
-        # Update cached left contexts
-        cached_key = k[-left_context_len:, ...]
-
-        # The length of key
-        k_len = k.shape[0]
-
-        q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
-        p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
-        k = k.reshape(k_len, batch_size, num_heads, query_head_dim)
-
-        # time1 refers to target, time2 refers to source.
-        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
-        p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
-        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
-
-        attn_scores = torch.matmul(q, k)
-
-        pos_emb = self.linear_pos(pos_emb)
-        seq_len2 = 2 * seq_len - 1 + left_context_len
-        pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
-            2, 0, 3, 1
-        )
-        # pos_emb shape now: (head, {1 or batch_size}, pos_head_dim, seq_len2)
-
-        # (head, batch, time1, pos_head_dim) x (head, 1, pos_head_dim, seq_len2) -> (head, batch, time1, seq_len2)
-        #  [where seq_len2 represents relative position.]
-        pos_scores = torch.matmul(p, pos_emb)
-
-        # the following .as_strided() expression converts the last axis of pos_scores from relative
-        # to absolute position.  I don't know whether I might have got the time-offsets backwards or
-        # not, but let this code define which way round it is supposed to be.
-        if torch.jit.is_tracing():
-            (num_heads, batch_size, time1, n) = pos_scores.shape
-            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-            cols = torch.arange(k_len)
-            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
-            indexes = rows + cols
-            pos_scores = pos_scores.reshape(-1, n)
-            pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-            pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
-        else:
-            pos_scores = pos_scores.as_strided(
-                (num_heads, batch_size, seq_len, k_len),
-                (
-                    pos_scores.stride(0),
-                    pos_scores.stride(1),
-                    pos_scores.stride(2) - pos_scores.stride(3),
-                    pos_scores.stride(3),
-                ),
-                storage_offset=pos_scores.stride(3) * (seq_len - 1),
-            )
-
-        attn_scores = attn_scores + pos_scores
-
-        assert attn_scores.shape == (
-            num_heads,
-            batch_size,
-            seq_len,
-            k_len,
-        ), attn_scores.shape
-
-        if key_padding_mask is not None:
-            assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape
-            attn_scores = attn_scores.masked_fill(
-                key_padding_mask.unsqueeze(1),
-                -1000,
-            )
-
-        attn_weights = attn_scores.softmax(dim=-1)
-
-        return attn_weights, cached_key
-
-    def _print_attn_entropy(self, attn_weights: Tensor):
-        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-        (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
-
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(enabled=False):
-                attn_weights = attn_weights.to(torch.float32)
-                attn_weights_entropy = (
-                    -((attn_weights + 1.0e-20).log() * attn_weights)
-                    .sum(dim=-1)
-                    .mean(dim=(1, 2))
-                )
-                logging.info(
-                    f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
-                )
-
-
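-# Illustrative sketch (added for clarity): a standalone check of the .as_strided()
-# trick used in RelPositionMultiheadAttentionWeights above, which converts the last
-# axis of pos_scores from relative offsets (length 2*seq_len-1) to absolute source
-# positions: abs[..., i, j] reads rel[..., i, (j - i) + (seq_len - 1)].
-def _sketch_rel_to_abs_as_strided(num_heads: int = 2, batch_size: int = 3, seq_len: int = 5) -> None:
-    rel = torch.randn(num_heads, batch_size, seq_len, 2 * seq_len - 1)
-    abs_scores = rel.as_strided(
-        (num_heads, batch_size, seq_len, seq_len),
-        (rel.stride(0), rel.stride(1), rel.stride(2) - rel.stride(3), rel.stride(3)),
-        storage_offset=rel.stride(3) * (seq_len - 1),
-    )
-    # Reference implementation by explicit slicing.
-    ref = torch.stack(
-        [rel[:, :, i, seq_len - 1 - i : 2 * seq_len - 1 - i] for i in range(seq_len)],
-        dim=2,
-    )
-    assert torch.equal(abs_scores, ref)
-
-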
-class SelfAttention(nn.Module):
-    """
-    The simplest possible attention module.  This one works with already-computed attention
-    weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
-
-    Args:
-          embed_dim: the input and output embedding dimension
-          num_heads: the number of attention heads
-          value_head_dim: the value dimension per head
-    """
-
-    def __init__(
-        self,
-        embed_dim: int,
-        num_heads: int,
-        value_head_dim: int,
-    ) -> None:
-        super().__init__()
-        self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
-
-        self.out_proj = ScaledLinear(
-            num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
-        )
-
-        self.whiten = Whiten(
-            num_groups=1,
-            whitening_limit=_whitening_schedule(7.5, ratio=3.0),
-            prob=(0.025, 0.25),
-            grad_scale=0.01,
-        )
-
-    def forward(
-        self,
-        x: Tensor,
-        attn_weights: Tensor,
-    ) -> Tensor:
-        """
-        Args:
-          x: input tensor, of shape (seq_len, batch_size, embed_dim)
-         attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
-          with seq_len being interpreted as (tgt_seq_len, src_seq_len).  Expect
-          attn_weights.sum(dim=-1) == 1.
-        Returns:
-           a tensor with the same shape as x.
-        """
-        (seq_len, batch_size, embed_dim) = x.shape
-        num_heads = attn_weights.shape[0]
-        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
-
-        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
-        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
-        # now x: (num_heads, batch_size, seq_len, value_head_dim)
-        value_head_dim = x.shape[-1]
-
-        # todo: see whether there is benefit in overriding matmul
-        x = torch.matmul(attn_weights, x)
-        # now x: (num_heads, batch_size, seq_len, value_head_dim)
-
-        x = (
-            x.permute(2, 1, 0, 3)
-            .contiguous()
-            .view(seq_len, batch_size, num_heads * value_head_dim)
-        )
-
-        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
-        x = self.out_proj(x)
-        x = self.whiten(x)
-
-        return x
-
-    def streaming_forward(
-        self,
-        x: Tensor,
-        attn_weights: Tensor,
-        cached_val: Tensor,
-        left_context_len: int,
-    ) -> Tuple[Tensor, Tensor]:
-        """
-        Args:
-            x: input tensor, of shape (seq_len, batch_size, embed_dim)
-            attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
-              with seq_len being interpreted as (tgt_seq_len, src_seq_len).  Expect
-              attn_weights.sum(dim=-1) == 1.
-            cached_val: cached attention value tensor of left context,
-              of shape (left_context_len, batch_size, value_dim)
-            left_context_len: number of left context frames.
-
-        Returns:
-           - attention weighted output, a tensor with the same shape as x.
-           - updated cached attention value tensor of left context.
-        """
-        (seq_len, batch_size, embed_dim) = x.shape
-        num_heads = attn_weights.shape[0]
-        seq_len2 = seq_len + left_context_len
-        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)
-
-        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
-
-        # Pad cached left contexts
-        assert cached_val.shape[0] == left_context_len, (
-            cached_val.shape[0],
-            left_context_len,
-        )
-        x = torch.cat([cached_val, x], dim=0)
-        # Update cached left contexts
-        cached_val = x[-left_context_len:, ...]
-
-        x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
-        # now x: (num_heads, batch_size, seq_len2, value_head_dim)
-        value_head_dim = x.shape[-1]
-
-        # todo: see whether there is benefit in overriding matmul
-        x = torch.matmul(attn_weights, x)
-        # now x: (num_heads, batch_size, seq_len, value_head_dim)
-
-        x = (
-            x.permute(2, 1, 0, 3)
-            .contiguous()
-            .view(seq_len, batch_size, num_heads * value_head_dim)
-        )
-
-        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
-        x = self.out_proj(x)
-
-        return x, cached_val
-
-
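-# Illustrative usage sketch (added for clarity; the dimensions below are made up and
-# are not the real model configuration): RelPositionMultiheadAttentionWeights only
-# produces attention weights, which are then consumed by SelfAttention (above) and,
-# later in this file, by NonlinAttention.  Assumes the helpers imported from
-# scaling.py at the top of this file are available.
-def _sketch_attention_weight_sharing() -> None:
-    embed_dim, pos_dim, num_heads = 192, 48, 4
-    seq_len, batch_size = 10, 2
-
-    attn_weights_mod = RelPositionMultiheadAttentionWeights(
-        embed_dim=embed_dim,
-        pos_dim=pos_dim,
-        num_heads=num_heads,
-        query_head_dim=32,
-        pos_head_dim=4,
-        dropout=0.0,
-    ).eval()
-    self_attn = SelfAttention(embed_dim=embed_dim, num_heads=num_heads, value_head_dim=12).eval()
-    pos_enc = CompactRelPositionalEncoding(pos_dim, dropout_rate=0.0).eval()
-
-    x = torch.randn(seq_len, batch_size, embed_dim)
-    with torch.no_grad():
-        pos_emb = pos_enc(x)  # (1, 2*seq_len-1, pos_dim)
-        attn_weights = attn_weights_mod(x, pos_emb)  # (num_heads, batch_size, seq_len, seq_len)
-        y = self_attn(x, attn_weights)  # same shape as x
-    assert y.shape == x.shape
-
-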
-class FeedforwardModule(nn.Module):
-    """Feedforward module in Zipformer2 model."""
-
-    def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
-        super(FeedforwardModule, self).__init__()
-        self.in_proj = nn.Linear(embed_dim, feedforward_dim)
-
-        self.hidden_balancer = Balancer(
-            feedforward_dim,
-            channel_dim=-1,
-            min_positive=0.3,
-            max_positive=1.0,
-            min_abs=0.75,
-            max_abs=5.0,
-        )
-
-        # shared_dim=0 means we share the dropout mask along the time axis
-        self.out_proj = ActivationDropoutAndLinear(
-            feedforward_dim,
-            embed_dim,
-            activation="SwooshL",
-            dropout_p=dropout,
-            dropout_shared_dim=0,
-            bias=True,
-            initial_scale=0.1,
-        )
-
-        self.out_whiten = Whiten(
-            num_groups=1,
-            whitening_limit=_whitening_schedule(7.5),
-            prob=(0.025, 0.25),
-            grad_scale=0.01,
-        )
-
-    def forward(self, x: Tensor):
-        x = self.in_proj(x)
-        x = self.hidden_balancer(x)
-        # out_proj contains SwooshL activation, then dropout, then linear.
-        x = self.out_proj(x)
-        x = self.out_whiten(x)
-        return x
-
-
-class NonlinAttention(nn.Module):
-    """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
-       from the attention module) in place of actual convolution.  We also took out the second nonlinearity, the
-       one after the attention mechanism.
-
-    Args:
-        channels (int): The number of channels of conv layers.
-    """
-
-    def __init__(
-        self,
-        channels: int,
-        hidden_channels: int,
-    ) -> None:
-        super().__init__()
-
-        self.hidden_channels = hidden_channels
-
-        self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
-
-        # balancer that goes before the sigmoid.  Have quite a large min_abs value, at 2.0,
-        # because we noticed that well-trained instances of this module have abs-value before the sigmoid
-        # starting from about 3, and poorly-trained instances of the module have smaller abs values
-        # before the sigmoid.
-        self.balancer = Balancer(
-            hidden_channels,
-            channel_dim=-1,
-            min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
-            max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
-            min_abs=0.5,
-            max_abs=5.0,
-        )
-        self.tanh = nn.Tanh()
-
-        self.identity1 = Identity()  # for diagnostics.
-        self.identity2 = Identity()  # for diagnostics.
-        self.identity3 = Identity()  # for diagnostics.
-
-        self.out_proj = ScaledLinear(
-            hidden_channels, channels, bias=True, initial_scale=0.05
-        )
-
-        self.whiten1 = Whiten(
-            num_groups=1,
-            whitening_limit=_whitening_schedule(5.0),
-            prob=(0.025, 0.25),
-            grad_scale=0.01,
-        )
-
-        self.whiten2 = Whiten(
-            num_groups=1,
-            whitening_limit=_whitening_schedule(5.0, ratio=3.0),
-            prob=(0.025, 0.25),
-            grad_scale=0.01,
-        )
-
-    def forward(
-        self,
-        x: Tensor,
-        attn_weights: Tensor,
-    ) -> Tensor:
-        """.
-                Args:
-                   x: a Tensor of shape (seq_len, batch_size, num_channels)
-        attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
-                Returns:
-                   a Tensor with the same shape as x
-        """
-        x = self.in_proj(x)
-
-        (seq_len, batch_size, _) = x.shape
-        hidden_channels = self.hidden_channels
-
-        s, x, y = x.chunk(3, dim=2)
-
-        # s will go through tanh.
-
-        s = self.balancer(s)
-        s = self.tanh(s)
-
-        s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
-        x = self.whiten1(x)
-        x = x * s
-        x = self.identity1(x)  # diagnostics only, it's the identity.
-
-        (seq_len, batch_size, embed_dim) = x.shape
-        num_heads = attn_weights.shape[0]
-        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
-
-        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
-        # now x: (num_heads, batch_size, seq_len, head_dim)
-        x = torch.matmul(attn_weights, x)
-        # now x: (num_heads, batch_size, seq_len, head_dim)
-        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
-
-        y = self.identity2(y)
-        x = x * y
-        x = self.identity3(x)
-
-        x = self.out_proj(x)
-        x = self.whiten2(x)
-        return x
-
-    def streaming_forward(
-        self,
-        x: Tensor,
-        attn_weights: Tensor,
-        cached_x: Tensor,
-        left_context_len: int,
-    ) -> Tuple[Tensor, Tensor]:
-        """.
-        Args:
-            x: a Tensor of shape (seq_len, batch_size, num_channels)
-            attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
-            cached_x: left context, a Tensor of shape
-              (num_heads, batch_size, left_context_len, head_dim)
-            left_context_len: number of left context frames.
-        Returns:
-            - a Tensor with the same shape as x
-            - updated left context with same shape as cached_x
-        """
-        x = self.in_proj(x)
-
-        (seq_len, batch_size, _) = x.shape
-        hidden_channels = self.hidden_channels
-
-        s, x, y = x.chunk(3, dim=2)
-
-        # s will go through tanh.
-        s = self.tanh(s)
-
-        s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
-        x = x * s
-
-        (seq_len, batch_size, embed_dim) = x.shape
-        num_heads = attn_weights.shape[0]
-        assert attn_weights.shape == (
-            num_heads,
-            batch_size,
-            seq_len,
-            left_context_len + seq_len,
-        )
-
-        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
-        # now x: (num_heads, batch_size, seq_len, head_dim)
-
-        # Pad cached tensor
-        assert cached_x.shape[2] == left_context_len, (
-            cached_x.shape[2],
-            left_context_len,
-        )
-        x_pad = torch.cat([cached_x, x], dim=2)
-        # Update cached tensor
-        cached_x = x_pad[:, :, -left_context_len:, :]
-
-        x = torch.matmul(attn_weights, x_pad)
-        # now x: (num_heads, batch_size, seq_len, head_dim)
-        x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
-
-        x = x * y
-
-        x = self.out_proj(x)
-        return x, cached_x
-
-
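-# Illustrative sketch (added for clarity; sizes are made up): the core of
-# NonlinAttention.forward() above.  The input projection is split into three chunks
-# (s, x, y); x is gated by tanh(s), mixed across time by the shared attention
-# weights, then gated again by y.  Plain tensors are used here, so the balancer and
-# whitening hooks of the real module are omitted.
-def _sketch_nonlin_attention_core(num_heads: int = 4, seq_len: int = 6, batch_size: int = 2, hidden: int = 16) -> None:
-    s = torch.randn(seq_len, batch_size, hidden)
-    x = torch.randn(seq_len, batch_size, hidden)
-    y = torch.randn(seq_len, batch_size, hidden)
-    attn_weights = torch.softmax(torch.randn(num_heads, batch_size, seq_len, seq_len), dim=-1)
-
-    x = x * s.tanh()
-    x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
-    x = torch.matmul(attn_weights, x)  # mix frames with the shared attention weights
-    x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
-    x = x * y
-    assert x.shape == (seq_len, batch_size, hidden)
-
-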
-class ConvolutionModule(nn.Module):
-    """ConvolutionModule in Zipformer2 model.
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
-
-    Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernel size of conv layers.
-        causal (bool): Whether to use causal (chunk-wise causal) convolution.
-
-    """
-
-    def __init__(
-        self,
-        channels: int,
-        kernel_size: int,
-        causal: bool,
-    ) -> None:
-        """Construct a ConvolutionModule object."""
-        super(ConvolutionModule, self).__init__()
-        # kernel_size should be an odd number for 'SAME' padding
-        assert (kernel_size - 1) % 2 == 0
-
-        bottleneck_dim = channels
-        self.causal = causal
-
-        self.in_proj = nn.Linear(
-            channels,
-            2 * bottleneck_dim,
-        )
-        # the gradients on in_proj are a little noisy, likely to do with the
-        # sigmoid in glu.
-
-        # after in_proj we put x through a gated linear unit (nn.functional.glu).
-        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
-        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
-        # between 50 and 100 for different channels.  This will cause very peaky and
-        # sparse derivatives for the sigmoid gating function, which will tend to make
-        # the loss function not learn effectively.  (for most layers the average absolute values
-        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
-        # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
-        # layers, which likely breaks down as 0.5 for the "linear" half and
-        # 0.2 to 0.3 for the part that goes into the sigmoid.  The idea is that if we
-        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
-        # it will be in a better position to start learning something, i.e. to latch onto
-        # the correct range.
-        self.balancer1 = Balancer(
-            bottleneck_dim,
-            channel_dim=-1,
-            min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
-            max_positive=1.0,
-            min_abs=1.5,
-            max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
-        )
-
-        self.activation1 = Identity()  # for diagnostics
-
-        self.sigmoid = nn.Sigmoid()
-
-        self.activation2 = Identity()  # for diagnostics
-
-        assert kernel_size % 2 == 1
-
-        self.depthwise_conv = (
-            ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size)
-            if causal
-            else nn.Conv1d(
-                in_channels=bottleneck_dim,
-                out_channels=bottleneck_dim,
-                groups=bottleneck_dim,
-                kernel_size=kernel_size,
-                padding=kernel_size // 2,
-            )
-        )
-
-        self.balancer2 = Balancer(
-            bottleneck_dim,
-            channel_dim=1,
-            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
-            max_positive=1.0,
-            min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
-            max_abs=10.0,
-        )
-
-        self.whiten = Whiten(
-            num_groups=1,
-            whitening_limit=_whitening_schedule(7.5),
-            prob=(0.025, 0.25),
-            grad_scale=0.01,
-        )
-
-        self.out_proj = ActivationDropoutAndLinear(
-            bottleneck_dim,
-            channels,
-            activation="SwooshR",
-            dropout_p=0.0,
-            initial_scale=0.05,
-        )
-
-    def forward(
-        self,
-        x: Tensor,
-        src_key_padding_mask: Optional[Tensor] = None,
-        chunk_size: int = -1,
-    ) -> Tensor:
-        """Compute convolution module.
-
-        Args:
-            x: Input tensor (#time, batch, channels).
-           src_key_padding_mask: the mask for the src keys per batch (optional):
-               (batch, #time), contains True in masked positions.
-
-        Returns:
-            Tensor: Output tensor (#time, batch, channels).
-
-        """
-
-        x = self.in_proj(x)  # (time, batch, 2*channels)
-
-        x, s = x.chunk(2, dim=2)
-        s = self.balancer1(s)
-        s = self.sigmoid(s)
-        x = self.activation1(x)  # identity.
-        x = x * s
-        x = self.activation2(x)  # identity
-
-        # (time, batch, channels)
-
-        # exchange the temporal dimension and the feature dimension
-        x = x.permute(1, 2, 0)  # (#batch, channels, time).
-
-        if src_key_padding_mask is not None:
-            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
-
-        if (
-            not torch.jit.is_scripting()
-            and not torch.jit.is_tracing()
-            and chunk_size >= 0
-        ):
-            # Exporting a model for simulated streaming decoding is not supported
-            assert (
-                self.causal
-            ), "Must initialize model with causal=True if you use chunk_size"
-            x = self.depthwise_conv(x, chunk_size=chunk_size)
-        else:
-            x = self.depthwise_conv(x)
-
-        x = self.balancer2(x)
-        x = x.permute(2, 0, 1)  # (time, batch, channels)
-
-        x = self.whiten(x)  # (time, batch, channels)
-        x = self.out_proj(x)  # (time, batch, channels)
-
-        return x
-
-    def streaming_forward(
-        self,
-        x: Tensor,
-        cache: Tensor,
-        src_key_padding_mask: Tensor,
-    ) -> Tuple[Tensor, Tensor]:
-        """Compute convolution module in streaming forward mode.
-
-        Args:
-            x: Input tensor (#time, batch, channels).
-            cache: cached left context for depthwise_conv of shape
-              (#batch, channels, left_pad)
-            src_key_padding_mask: the mask for the src keys per batch (optional):
-              (batch, #time), contains True in masked positions.
-
-        Returns:
-            - Output tensor (#time, batch, channels).
-            - Updated cache (#batch, channels, left_pad)
-        """
-
-        x = self.in_proj(x)  # (time, batch, 2*channels)
-
-        x, s = x.chunk(2, dim=2)
-        s = self.sigmoid(s)
-        x = x * s
-        # (time, batch, channels)
-
-        # exchange the temporal dimension and the feature dimension
-        x = x.permute(1, 2, 0)  # (#batch, channels, time).
-
-        if src_key_padding_mask is not None:
-            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
-
-        x, cache = self.depthwise_conv.streaming_forward(x, cache=cache)
-
-        x = x.permute(2, 0, 1)  # (time, batch, channels)
-
-        x = self.out_proj(x)  # (time, batch, channels)
-
-        return x, cache
-
-
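-# Illustrative usage sketch (added for clarity; sizes are made up): the non-causal
-# ConvolutionModule keeps the (time, batch, channels) shape.  Assumes the scaling.py
-# helpers imported at the top of this file (Balancer, Whiten, etc.) are available.
-def _sketch_convolution_module() -> None:
-    conv = ConvolutionModule(channels=64, kernel_size=31, causal=False).eval()
-    x = torch.randn(20, 2, 64)  # (time, batch, channels)
-    with torch.no_grad():
-        y = conv(x)
-    assert y.shape == x.shape
-
-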
-class ScalarMultiply(nn.Module):
-    def __init__(self, scale: float):
-        super().__init__()
-        self.scale = scale
-
-    def forward(self, x):
-        return x * self.scale
-
-
-def _test_zipformer_main(causal: bool = False):
-    batch_size = 5
-    seq_len = 20
-    # Just make sure the forward pass runs.
-
-    c = Zipformer2(
-        encoder_dim=(64, 96),
-        encoder_unmasked_dim=(48, 64),
-        num_heads=(4, 4),
-        causal=causal,
-        chunk_size=(4,) if causal else (-1,),
-        left_context_frames=(64,),
-    )
-    # Just make sure the forward pass runs.
-    f = c(
-        torch.randn(seq_len, batch_size, 64),
-        torch.full((batch_size,), seq_len, dtype=torch.int64),
-    )
-    f[0].sum().backward()
-    c.eval()
-    f = c(
-        torch.randn(seq_len, batch_size, 64),
-        torch.full((batch_size,), seq_len, dtype=torch.int64),
-    )
-    f  # to remove flake8 warnings
-
-
-if __name__ == "__main__":
-    logging.getLogger().setLevel(logging.INFO)
-    torch.set_num_threads(1)
-    torch.set_num_interop_threads(1)
-    _test_zipformer_main(False)
-    _test_zipformer_main(True)
\ No newline at end of file
-- 
Gitee