diff --git a/speech/speech_recognition/transformer/pytorch/README.md b/speech/speech_recognition/transformer/pytorch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6ac6c24367001dca65a9b91e8b7b494ec3a8a62b
--- /dev/null
+++ b/speech/speech_recognition/transformer/pytorch/README.md
@@ -0,0 +1,44 @@
+# Transformer
+
+## Step 1: Installing packages
+
+```
+pip3 install -r requirements.txt
+```
+
+## Step 2: Training
+
+The dataset consists of data_aishell.tgz and resource_aishell.tgz (AISHELL-1) from WeNet.
+You can simply run the whole script, which will download the dataset automatically.
+```
+bash run.sh --stage -1 --stop-stage 6
+```
+Or you can run each stage one by one manually and check the results to understand the whole process.
+```
+# Download data
+bash run.sh --stage -1 --stop-stage -1
+# Prepare training data
+bash run.sh --stage 0 --stop-stage 0
+# Extract optional CMVN features
+bash run.sh --stage 1 --stop-stage 1
+# Generate label token dictionary
+bash run.sh --stage 2 --stop-stage 2
+# Prepare WeNet data format
+bash run.sh --stage 3 --stop-stage 3
+# Neural network training
+bash run.sh --stage 4 --stop-stage 4
+# Recognize wav using the trained model
+bash run.sh --stage 5 --stop-stage 5
+# Export the trained model
+bash run.sh --stage 6 --stop-stage 6
+```
+
+## Results on BI-V100
+
+| GPUs | FP16  | QPS |
+|------|-------|-----|
+| 1x8  | False | 597 |
+
+
+## Reference
+https://github.com/wenet-e2e/wenet
diff --git a/speech/speech_recognition/transformer/pytorch/conf/train_transformer.yaml b/speech/speech_recognition/transformer/pytorch/conf/train_transformer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b7d7eee83ace095b4c7a09e61fd63776cb50b2d6
--- /dev/null
+++ b/speech/speech_recognition/transformer/pytorch/conf/train_transformer.yaml
@@ -0,0 +1,72 @@
+# network architecture
+# encoder related
+encoder: transformer
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder architecture type
+    normalize_before: true
+
+# decoder related
+decoder: transformer
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+
+# hybrid CTC/attention
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+
+dataset_conf:
+    filter_conf:
+        max_length: 40960
+        min_length: 0
+        token_max_length: 200
+        token_min_length: 1
+    resample_conf:
+        resample_rate: 16000
+    speed_perturb: true
+    fbank_conf:
+        num_mel_bins: 80
+        frame_shift: 10
+        frame_length: 25
+        dither: 0.1
+    spec_aug: true
+    spec_aug_conf:
+        num_t_mask: 2
+        num_f_mask: 2
+        max_t: 50
+        max_f: 10
+    shuffle: true
+    shuffle_conf:
+        shuffle_size: 1500
+    sort: true
+    sort_conf:
+        sort_size: 500  # sort_size should be less than shuffle_size
+    batch_conf:
+        batch_type: 'static' # static or dynamic
+        batch_size: 26
+
+grad_clip: 5
+accum_grad: 1
+max_epoch: 240
+log_interval: 100
+
+optim: adam
+optim_conf:
+    lr: 0.002
+scheduler: warmuplr     # pytorch v1.1.0+ required
+scheduler_conf:
+    warmup_steps: 25000
diff --git a/speech/speech_recognition/transformer/pytorch/local/aishell_data_prep.sh b/speech/speech_recognition/transformer/pytorch/local/aishell_data_prep.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fb4d5fb0adefb9e3e3ebeaa5ccb1a92562eb77c1
--- /dev/null
+++ b/speech/speech_recognition/transformer/pytorch/local/aishell_data_prep.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+# Copyright 2017 Xingyu Na
+# Apache 2.0
+
+. ./path.sh || exit 1;
+
+if [ $# != 2 ]; then
+  echo "Usage: $0 <audio-path> <text-path>"
+  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript"
+  exit 1;
+fi
+
+aishell_audio_dir=$1
+aishell_text=$2/aishell_transcript_v0.8.txt
+
+train_dir=data/local/train
+dev_dir=data/local/dev
+test_dir=data/local/test
+tmp_dir=data/local/tmp
+
+mkdir -p $train_dir
+mkdir -p $dev_dir
+mkdir -p $test_dir
+mkdir -p $tmp_dir
+
+# data directory check
+if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
+  echo "Error: $0 requires two directory arguments"
+  exit 1;
+fi
+
+# find wav audio file for train, dev and test resp.
+find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
+n=`cat $tmp_dir/wav.flist | wc -l`
+[ $n -ne 141925 ] && \
+  echo Warning: expected 141925 data files, found $n
+
+grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
+grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
+grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;
+
+rm -r $tmp_dir
+
+# Transcriptions preparation
+for dir in $train_dir $dev_dir $test_dir; do
+  echo Preparing $dir transcriptions
+  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
+  paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
+  tools/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
+  awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
+  tools/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
+  sort -u $dir/transcripts.txt > $dir/text
+done
+
+mkdir -p data/train data/dev data/test
+
+for f in wav.scp text; do
+  cp $train_dir/$f data/train/$f || exit 1;
+  cp $dev_dir/$f data/dev/$f || exit 1;
+  cp $test_dir/$f data/test/$f || exit 1;
+done
+
+echo "$0: AISHELL data preparation succeeded"
+exit 0;
diff --git a/speech/speech_recognition/transformer/pytorch/local/aishell_train_lms.sh b/speech/speech_recognition/transformer/pytorch/local/aishell_train_lms.sh
new file mode 100644
index 0000000000000000000000000000000000000000..30ffb7973b3ddec4ef4c0f09c8184837cad768d6
--- /dev/null
+++ b/speech/speech_recognition/transformer/pytorch/local/aishell_train_lms.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+
+
+# To be run from one directory above this script.
+. ./path.sh
+
+text=data/local/lm/text
+lexicon=data/local/dict/lexicon.txt
+
+for f in "$text" "$lexicon"; do
+  [ ! -f $f ] && echo "$0: No such file $f" && exit 1;
+done
+
+# Check SRILM tools
+if ! which ngram-count > /dev/null; then
+  echo "srilm tools are not found, please download and install them from: "
+  echo "http://www.speech.sri.com/projects/srilm/download.html"
+  echo "Then add the tools to your PATH"
+  exit 1
+fi
+
+# This script takes no arguments. It assumes you have already run
+# aishell_data_prep.sh.
+# It takes as input the files +# data/local/lm/text +# data/local/dict/lexicon.txt +dir=data/local/lm +mkdir -p $dir + + +cleantext=$dir/text.no_oov + +cat $text | awk -v lex=$lexicon 'BEGIN{while((getline0){ seen[$1]=1; } } + {for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf(" ");} } printf("\n");}' \ + > $cleantext || exit 1; + +cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort | uniq -c | \ + sort -nr > $dir/word.counts || exit 1; + +# Get counts from acoustic training transcripts, and add one-count +# for each word in the lexicon (but not silence, we don't want it +# in the LM-- we'll add it optionally later). +cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \ + cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \ + sort | uniq -c | sort -nr > $dir/unigram.counts || exit 1; + +cat $dir/unigram.counts | awk '{print $2}' | cat - <(echo ""; echo "" ) > $dir/wordlist + +heldout_sent=10000 # Don't change this if you want result to be comparable with + # kaldi_lm results +mkdir -p $dir +cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n $dir/heldout +cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n $dir/train + +ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \ + -map-unk "" -kndiscount -interpolate -lm $dir/lm.arpa +ngram -lm $dir/lm.arpa -ppl $dir/heldout diff --git a/speech/speech_recognition/transformer/pytorch/local/download_and_untar.sh b/speech/speech_recognition/transformer/pytorch/local/download_and_untar.sh new file mode 100644 index 0000000000000000000000000000000000000000..58a278241d75caeba25ba4b17d186912d0d724ec --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/local/download_and_untar.sh @@ -0,0 +1,105 @@ +#!/bin/bash + +# Copyright 2014 Johns Hopkins University (author: Daniel Povey) +# 2017 Xingyu Na +# Apache 2.0 + +remove_archive=false + +if [ "$1" == --remove-archive ]; then + remove_archive=true + shift +fi + +if [ $# -ne 3 ]; then + echo "Usage: $0 [--remove-archive] " + echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell" + echo "With --remove-archive it will remove the archive after successfully un-tarring it." + echo " can be one of: data_aishell, resource_aishell." +fi + +data=$1 +url=$2 +part=$3 + +if [ ! -d "$data" ]; then + echo "$0: no such directory $data" + exit 1; +fi + +part_ok=false +list="data_aishell resource_aishell" +for x in $list; do + if [ "$part" == $x ]; then part_ok=true; fi +done +if ! $part_ok; then + echo "$0: expected to be one of $list, but got '$part'" + exit 1; +fi + +if [ -z "$url" ]; then + echo "$0: empty URL base." + exit 1; +fi + +if [ -f $data/$part/.complete ]; then + echo "$0: data part $part was already successfully extracted, nothing to do." + exit 0; +fi + +# sizes of the archive files in bytes. +sizes="15582913665 1246920" + +if [ -f $data/$part.tgz ]; then + size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}') + size_ok=false + for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done + if ! $size_ok; then + echo "$0: removing existing file $data/$part.tgz because its size in bytes $size" + echo "does not equal the size of one of the archives." + rm $data/$part.tgz + else + echo "$data/$part.tgz exists and appears to be complete." + fi +fi + +if [ ! -f $data/$part.tgz ]; then + if ! which wget >/dev/null; then + echo "$0: wget is not installed." + exit 1; + fi + full_url=$url/$part.tgz + echo "$0: downloading data from $full_url. This may take some time, please be patient." 
+ + cd $data + if ! wget --no-check-certificate $full_url; then + echo "$0: error executing wget $full_url" + exit 1; + fi +fi + +cd $data + +if ! tar -xvzf $part.tgz; then + echo "$0: error un-tarring archive $data/$part.tgz" + exit 1; +fi + +touch $data/$part/.complete + +if [ $part == "data_aishell" ]; then + cd $data/$part/wav + for wav in ./*.tar.gz; do + echo "Extracting wav from $wav" + tar -zxf $wav && rm $wav + done +fi + +echo "$0: Successfully downloaded and un-tarred $data/$part.tgz" + +if $remove_archive; then + echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied." + rm $data/$part.tgz +fi + +exit 0; diff --git a/speech/speech_recognition/transformer/pytorch/path.sh b/speech/speech_recognition/transformer/pytorch/path.sh new file mode 100644 index 0000000000000000000000000000000000000000..0805a5e7cfc5dc1797d4c86775c98d5b608a03c7 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/path.sh @@ -0,0 +1,8 @@ +export WENET_DIR=$PWD/../../.. +export BUILD_DIR=${WENET_DIR}/runtime/server/x86/build +export OPENFST_PREFIX_DIR=${BUILD_DIR}/../fc_base/openfst-subbuild/openfst-populate-prefix +export PATH=$PWD:${BUILD_DIR}:${BUILD_DIR}/kaldi:${OPENFST_PREFIX_DIR}/bin:$PATH + +# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=../../../:$PYTHONPATH \ No newline at end of file diff --git a/speech/speech_recognition/transformer/pytorch/requirements.txt b/speech/speech_recognition/transformer/pytorch/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2ea79d3bfb5b2015d89432cebe9f547af060d459 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/requirements.txt @@ -0,0 +1,6 @@ +pyyaml>=5.1 +sentencepiece==0.1.86 +tensorboard +tensorboardX +typeguard +textgrid \ No newline at end of file diff --git a/speech/speech_recognition/transformer/pytorch/run.sh b/speech/speech_recognition/transformer/pytorch/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..a79e6e5d5d0ef5179c194307d2669c81003de57c --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/run.sh @@ -0,0 +1,242 @@ +#!/bin/bash + +# Copyright 2019 Mobvoi Inc. All Rights Reserved. +. ./path.sh || exit 1; + +# Use this to control how many gpu you use, It's 1-gpu training if you specify +# just 1gpu, otherwise it's is multiple gpu training based on DDP in pytorch +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +# The NCCL_SOCKET_IFNAME variable specifies which IP interface to use for nccl +# communication. More details can be found in +# https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html +# export NCCL_SOCKET_IFNAME=ens4f1 +export NCCL_DEBUG=INFO +stage=0 # start from 0 if you need to start from data preparation +stop_stage=5 + +# The num of machines(nodes) for multi-machine training, 1 is for one machine. +# NFS is required if num_nodes > 1. +num_nodes=1 + +# The rank of each node or machine, which ranges from 0 to `num_nodes - 1`. +# You should set the node_rank=0 on the first machine, set the node_rank=1 +# on the second machine, and so on. +node_rank=0 +# The aishell dataset location, please change this to your own path +# make sure of using absolute path. DO-NOT-USE relatvie path! +data=/export/data/asr-data/OpenSLR/33/ +data_url=www.openslr.org/resources/33 + +nj=16 +dict=data/dict/lang_char.txt + +# data_type can be `raw` or `shard`. 
Typically, raw is used for small dataset, +# `shard` is used for large dataset which is over 1k hours, and `shard` is +# faster on reading data and training. +data_type=raw +num_utts_per_shard=1000 + +train_set=train +# Optional train_config +# 1. conf/train_transformer.yaml: Standard transformer +# 2. conf/train_conformer.yaml: Standard conformer +# 3. conf/train_unified_conformer.yaml: Unified dynamic chunk causal conformer +# 4. conf/train_unified_transformer.yaml: Unified dynamic chunk transformer +# 5. conf/train_u2++_conformer.yaml: U2++ conformer +# 6. conf/train_u2++_transformer.yaml: U2++ transformer +train_config=conf/train_transformer.yaml +cmvn=true +dir=exp/transformer +checkpoint= + +# use average_checkpoint will get better result +average_checkpoint=true +decode_checkpoint=$dir/final.pt +average_num=30 +decode_modes="ctc_greedy_search ctc_prefix_beam_search attention attention_rescoring" + +. tools/parse_options.sh || exit 1; + +if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then + echo "stage -1: Data Download" + local/download_and_untar.sh ${data} ${data_url} data_aishell + local/download_and_untar.sh ${data} ${data_url} resource_aishell +fi + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # Data preparation + local/aishell_data_prep.sh ${data}/data_aishell/wav \ + ${data}/data_aishell/transcript +fi + + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # remove the space between the text labels for Mandarin dataset + for x in train dev test; do + cp data/${x}/text data/${x}/text.org + paste -d " " <(cut -f 1 -d" " data/${x}/text.org) \ + <(cut -f 2- -d" " data/${x}/text.org | tr -d " ") \ + > data/${x}/text + rm data/${x}/text.org + done + + tools/compute_cmvn_stats.py --num_workers 16 --train_config $train_config \ + --in_scp data/${train_set}/wav.scp \ + --out_cmvn data/$train_set/global_cmvn +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "Make a dictionary" + mkdir -p $(dirname $dict) + echo " 0" > ${dict} # 0 is for "blank" in CTC + echo " 1" >> ${dict} # must be 1 + tools/text2token.py -s 1 -n 1 data/train/text | cut -f 2- -d" " \ + | tr " " "\n" | sort | uniq | grep -a -v -e '^\s*$' | \ + awk '{print $0 " " NR+1}' >> ${dict} + num_token=$(cat $dict | wc -l) + echo " $num_token" >> $dict +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "Prepare data, prepare required format" + for x in dev test ${train_set}; do + if [ $data_type == "shard" ]; then + tools/make_shard_list.py --num_utts_per_shard $num_utts_per_shard \ + --num_threads 16 data/$x/wav.scp data/$x/text \ + $(realpath data/$x/shards) data/$x/data.list + else + tools/make_raw_list.py data/$x/wav.scp data/$x/text \ + data/$x/data.list + fi + done +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + mkdir -p $dir + # You have to rm `INIT_FILE` manually when you resume or restart a + # multi-machine training. 
+ # INIT_FILE=$dir/ddp_init + # init_method=file://$(readlink -f $INIT_FILE) + # echo "$0: init method is $init_method" + num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') + # Use "nccl" if it works, otherwise use "gloo" + dist_backend="nccl" + world_size=`expr $num_gpus \* $num_nodes` + echo "total gpus is: $world_size" + cmvn_opts= + $cmvn && cp data/${train_set}/global_cmvn $dir + $cmvn && cmvn_opts="--cmvn ${dir}/global_cmvn" + export MASTER_ADDR="0.0.0.0" + export MASTER_PORT=11118 + # train.py rewrite $train_config to $dir/train.yaml with model input + # and output dimension, and $dir/train.yaml will be used for inference + # and export. + for ((i = 0; i < $num_gpus; ++i)); do + { + gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1]) + # Rank of each gpu/process used for knowing whether it is + # the master of a worker. + rank=`expr $node_rank \* $num_gpus + $i` + python3 wenet/bin/train.py --gpu $gpu_id \ + --config $train_config \ + --data_type $data_type \ + --symbol_table $dict \ + --train_data data/$train_set/data.list \ + --cv_data data/dev/data.list \ + ${checkpoint:+--checkpoint $checkpoint} \ + --model_dir $dir \ + --ddp.world_size $world_size \ + --ddp.rank $rank \ + --ddp.dist_backend $dist_backend \ + --num_workers 1 \ + $cmvn_opts \ + --pin_memory + } & + done + wait +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # Test model, please specify the model you want to test by --checkpoint + if [ ${average_checkpoint} == true ]; then + decode_checkpoint=$dir/avg_${average_num}.pt + echo "do model average and final checkpoint is $decode_checkpoint" + python3 wenet/bin/average_model.py \ + --dst_model $decode_checkpoint \ + --src_path $dir \ + --num ${average_num} \ + --val_best + fi + # Please specify decoding_chunk_size for unified streaming and + # non-streaming model. The default value is -1, which is full chunk + # for non-streaming inference. + decoding_chunk_size= + ctc_weight=0.5 + reverse_weight=0.0 + for mode in ${decode_modes}; do + { + test_dir=$dir/test_${mode} + mkdir -p $test_dir + python3 wenet/bin/recognize.py --gpu 0 \ + --mode $mode \ + --config $dir/train.yaml \ + --data_type $data_type \ + --test_data data/test/data.list \ + --checkpoint $decode_checkpoint \ + --beam_size 10 \ + --batch_size 1 \ + --penalty 0.0 \ + --dict $dict \ + --ctc_weight $ctc_weight \ + --reverse_weight $reverse_weight \ + --result_file $test_dir/text \ + ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size} + python3 tools/compute-wer.py --char=1 --v=1 \ + data/test/text $test_dir/text > $test_dir/wer + } & + done + wait +fi + + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + # Export the best model you want + python3 wenet/bin/export_jit.py \ + --config $dir/train.yaml \ + --checkpoint $dir/avg_${average_num}.pt \ + --output_file $dir/final.zip \ + --output_quant_file $dir/final_quant.zip +fi + +# Optionally, you can add LM and test it with runtime. 
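+# Example (assuming stages -1..6 above have already completed and SRILM is
+# installed): this optional LM + runtime stage can be run on its own by
+# overriding the stage bounds on the command line:
+#   bash run.sh --stage 7 --stop-stage 7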
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then + # 7.1 Prepare dict + unit_file=$dict + mkdir -p data/local/dict + cp $unit_file data/local/dict/units.txt + tools/fst/prepare_dict.py $unit_file ${data}/resource_aishell/lexicon.txt \ + data/local/dict/lexicon.txt + # 7.2 Train lm + lm=data/local/lm + mkdir -p $lm + tools/filter_scp.pl data/train/text \ + $data/data_aishell/transcript/aishell_transcript_v0.8.txt > $lm/text + local/aishell_train_lms.sh + # 7.3 Build decoding TLG + tools/fst/compile_lexicon_token_fst.sh \ + data/local/dict data/local/tmp data/local/lang + tools/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1; + # 7.4 Decoding with runtime + chunk_size=-1 + ./tools/decode.sh --nj 16 \ + --beam 15.0 --lattice_beam 7.5 --max_active 7000 \ + --blank_skip_thresh 0.98 --ctc_weight 0.5 --rescoring_weight 1.0 \ + --chunk_size $chunk_size \ + --fst_path data/lang_test/TLG.fst \ + --dict_path data/lang_test/words.txt \ + data/test/wav.scp data/test/text $dir/final.zip \ + data/lang_test/units.txt $dir/lm_with_runtime + # Please see $dir/lm_with_runtime for wer +fi + + diff --git a/speech/speech_recognition/transformer/pytorch/tools/compute-wer.py b/speech/speech_recognition/transformer/pytorch/tools/compute-wer.py new file mode 100644 index 0000000000000000000000000000000000000000..a3eefc0dc7b67f252e685da71a5189312e74ef85 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/tools/compute-wer.py @@ -0,0 +1,500 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +import re, sys, unicodedata +import codecs + +remove_tag = True +spacelist= [' ', '\t', '\r', '\n'] +puncts = ['!', ',', '?', + '、', '。', '!', ',', ';', '?', + ':', '「', '」', '︰', '『', '』', '《', '》'] + +def characterize(string) : + res = [] + i = 0 + while i < len(string): + char = string[i] + if char in puncts: + i += 1 + continue + cat1 = unicodedata.category(char) + #https://unicodebook.readthedocs.io/unicode.html#unicode-categories + if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned + i += 1 + continue + if cat1 == 'Lo': # letter-other + res.append(char) + i += 1 + else: + # some input looks like: , we want to separate it to two words. 
+ sep = ' ' + if char == '<': sep = '>' + j = i+1 + while j < len(string): + c = string[j] + if ord(c) >= 128 or (c in spacelist) or (c==sep): + break + j += 1 + if j < len(string) and string[j] == '>': + j += 1 + res.append(string[i:j]) + i = j + return res + +def stripoff_tags(x): + if not x: return '' + chars = [] + i = 0; T=len(x) + while i < T: + if x[i] == '<': + while i < T and x[i] != '>': + i += 1 + i += 1 + else: + chars.append(x[i]) + i += 1 + return ''.join(chars) + + +def normalize(sentence, ignore_words, cs, split=None): + """ sentence, ignore_words are both in unicode + """ + new_sentence = [] + for token in sentence: + x = token + if not cs: + x = x.upper() + if x in ignore_words: + continue + if remove_tag: + x = stripoff_tags(x) + if not x: + continue + if split and x in split: + new_sentence += split[x] + else: + new_sentence.append(x) + return new_sentence + +class Calculator : + def __init__(self) : + self.data = {} + self.space = [] + self.cost = {} + self.cost['cor'] = 0 + self.cost['sub'] = 1 + self.cost['del'] = 1 + self.cost['ins'] = 1 + def calculate(self, lab, rec) : + # Initialization + lab.insert(0, '') + rec.insert(0, '') + while len(self.space) < len(lab) : + self.space.append([]) + for row in self.space : + for element in row : + element['dist'] = 0 + element['error'] = 'non' + while len(row) < len(rec) : + row.append({'dist' : 0, 'error' : 'non'}) + for i in range(len(lab)) : + self.space[i][0]['dist'] = i + self.space[i][0]['error'] = 'del' + for j in range(len(rec)) : + self.space[0][j]['dist'] = j + self.space[0][j]['error'] = 'ins' + self.space[0][0]['error'] = 'non' + for token in lab : + if token not in self.data and len(token) > 0 : + self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0} + for token in rec : + if token not in self.data and len(token) > 0 : + self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0} + # Computing edit distance + for i, lab_token in enumerate(lab) : + for j, rec_token in enumerate(rec) : + if i == 0 or j == 0 : + continue + min_dist = sys.maxsize + min_error = 'none' + dist = self.space[i-1][j]['dist'] + self.cost['del'] + error = 'del' + if dist < min_dist : + min_dist = dist + min_error = error + dist = self.space[i][j-1]['dist'] + self.cost['ins'] + error = 'ins' + if dist < min_dist : + min_dist = dist + min_error = error + if lab_token == rec_token : + dist = self.space[i-1][j-1]['dist'] + self.cost['cor'] + error = 'cor' + else : + dist = self.space[i-1][j-1]['dist'] + self.cost['sub'] + error = 'sub' + if dist < min_dist : + min_dist = dist + min_error = error + self.space[i][j]['dist'] = min_dist + self.space[i][j]['error'] = min_error + # Tracing back + result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} + i = len(lab) - 1 + j = len(rec) - 1 + while True : + if self.space[i][j]['error'] == 'cor' : # correct + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 + result['all'] = result['all'] + 1 + result['cor'] = result['cor'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'sub' : # substitution + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 + result['all'] = result['all'] + 1 + result['sub'] = result['sub'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + 
j = j - 1 + elif self.space[i][j]['error'] == 'del' : # deletion + if len(lab[i]) > 0 : + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 + result['all'] = result['all'] + 1 + result['del'] = result['del'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, "") + i = i - 1 + elif self.space[i][j]['error'] == 'ins' : # insertion + if len(rec[j]) > 0 : + self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 + result['ins'] = result['ins'] + 1 + result['lab'].insert(0, "") + result['rec'].insert(0, rec[j]) + j = j - 1 + elif self.space[i][j]['error'] == 'non' : # starting point + break + else : # shouldn't reach here + print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error'])) + return result + def overall(self) : + result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} + for token in self.data : + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + def cluster(self, data) : + result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} + for token in data : + if token in self.data : + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + def keys(self) : + return list(self.data.keys()) + +def width(string): + return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) + +def default_cluster(word) : + unicode_names = [ unicodedata.name(char) for char in word ] + for i in reversed(range(len(unicode_names))) : + if unicode_names[i].startswith('DIGIT') : # 1 + unicode_names[i] = 'Number' # 'DIGIT' + elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or + unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) : + # 明 / 郎 + unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' + elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or + unicode_names[i].startswith('LATIN SMALL LETTER')) : + # A / a + unicode_names[i] = 'English' # 'LATIN LETTER' + elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め + unicode_names[i] = 'Japanese' # 'GANA LETTER' + elif (unicode_names[i].startswith('AMPERSAND') or + unicode_names[i].startswith('APOSTROPHE') or + unicode_names[i].startswith('COMMERCIAL AT') or + unicode_names[i].startswith('DEGREE CELSIUS') or + unicode_names[i].startswith('EQUALS SIGN') or + unicode_names[i].startswith('FULL STOP') or + unicode_names[i].startswith('HYPHEN-MINUS') or + unicode_names[i].startswith('LOW LINE') or + unicode_names[i].startswith('NUMBER SIGN') or + unicode_names[i].startswith('PLUS SIGN') or + unicode_names[i].startswith('SEMICOLON')) : + # & / ' / @ / ℃ / = / . 
/ - / _ / # / + / ; + del unicode_names[i] + else : + return 'Other' + if len(unicode_names) == 0 : + return 'Other' + if len(unicode_names) == 1 : + return unicode_names[0] + for i in range(len(unicode_names)-1) : + if unicode_names[i] != unicode_names[i+1] : + return 'Other' + return unicode_names[0] + +def usage() : + print("compute-wer.py : compute word error rate (WER) and align recognition results and references.") + print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer") + +if __name__ == '__main__': + if len(sys.argv) == 1 : + usage() + sys.exit(0) + calculator = Calculator() + cluster_file = '' + ignore_words = set() + tochar = False + verbose= 1 + padding_symbol= ' ' + case_sensitive = False + max_words_per_line = sys.maxsize + split = None + while len(sys.argv) > 3: + a = '--maxw=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):] + del sys.argv[1] + max_words_per_line = int(b) + continue + a = '--rt=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + remove_tag = (b == 'true') or (b != '0') + continue + a = '--cs=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + case_sensitive = (b == 'true') or (b != '0') + continue + a = '--cluster=' + if sys.argv[1].startswith(a): + cluster_file = sys.argv[1][len(a):] + del sys.argv[1] + continue + a = '--splitfile=' + if sys.argv[1].startswith(a): + split_file = sys.argv[1][len(a):] + del sys.argv[1] + split = dict() + with codecs.open(split_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + words = line.strip().split() + if len(words) >= 2: + split[words[0]] = words[1:] + continue + a = '--ig=' + if sys.argv[1].startswith(a): + ignore_file = sys.argv[1][len(a):] + del sys.argv[1] + with codecs.open(ignore_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + line = line.strip() + if len(line) > 0: + ignore_words.add(line) + continue + a = '--char=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + tochar = (b == 'true') or (b != '0') + continue + a = '--v=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + verbose=0 + try: + verbose=int(b) + except: + if b == 'true' or b != '0': + verbose = 1 + continue + a = '--padding-symbol=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + if b == 'space': + padding_symbol= ' ' + elif b == 'underline': + padding_symbol= '_' + continue + if True or sys.argv[1].startswith('-'): + #ignore invalid switch + del sys.argv[1] + continue + + if not case_sensitive: + ig=set([w.upper() for w in ignore_words]) + ignore_words = ig + + default_clusters = {} + default_words = {} + + ref_file = sys.argv[1] + hyp_file = sys.argv[2] + rec_set = {} + if split and not case_sensitive: + newsplit = dict() + for w in split: + words = split[w] + for i in range(len(words)): + words[i] = words[i].upper() + newsplit[w.upper()] = words + split = newsplit + + with codecs.open(hyp_file, 'r', 'utf-8') as fh: + for line in fh: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array)==0: continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split) + + # compute error rate on the interaction of reference file and hyp file + for line in open(ref_file, 'r', encoding='utf-8') : + if tochar: + array = characterize(line) 
+ else: + array = line.rstrip('\n').split() + if len(array)==0: continue + fid = array[0] + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + if verbose: + print('\nutt: %s' % fid) + + for word in rec + lab : + if word not in default_words : + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters : + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name] : + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name + + result = calculator.calculate(lab, rec) + if verbose: + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('WER: %4.2f %%' % wer, end = ' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + space = {} + space['lab'] = [] + space['rec'] = [] + for idx in range(len(result['lab'])) : + len_lab = width(result['lab'][idx]) + len_rec = width(result['rec'][idx]) + length = max(len_lab, len_rec) + space['lab'].append(length-len_lab) + space['rec'].append(length-len_rec) + upper_lab = len(result['lab']) + upper_rec = len(result['rec']) + lab1, rec1 = 0, 0 + while lab1 < upper_lab or rec1 < upper_rec: + if verbose > 1: + print('lab(%s):' % fid.encode('utf-8'), end = ' ') + else: + print('lab:', end = ' ') + lab2 = min(upper_lab, lab1 + max_words_per_line) + for idx in range(lab1, lab2): + token = result['lab'][idx] + print('{token}'.format(token = token), end = '') + for n in range(space['lab'][idx]) : + print(padding_symbol, end = '') + print(' ',end='') + print() + if verbose > 1: + print('rec(%s):' % fid.encode('utf-8'), end = ' ') + else: + print('rec:', end = ' ') + rec2 = min(upper_rec, rec1 + max_words_per_line) + for idx in range(rec1, rec2): + token = result['rec'][idx] + print('{token}'.format(token = token), end = '') + for n in range(space['rec'][idx]) : + print(padding_symbol, end = '') + print(' ',end='') + print('\n', end='\n') + lab1 = lab2 + rec1 = rec2 + + if verbose: + print('===========================================================================') + print() + + result = calculator.overall() + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('Overall -> %4.2f %%' % wer, end = ' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + if not verbose: + print() + + if verbose: + for cluster_id in default_clusters : + result = calculator.cluster([ k for k in default_clusters[cluster_id] ]) + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + if len(cluster_file) > 0 : # compute separated WERs for word clusters + cluster_id = '' + cluster = [] + for line in open(cluster_file, 'r', encoding='utf-8') : + for token in line.decode('utf-8').rstrip('\n').split() : + # end of cluster reached, like + if token[0:2] == '' and \ + token.lstrip('') == cluster_id : + result = calculator.cluster(cluster) + if result['all'] != 0 : + wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] + else : + wer = 0.0 + 
print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) + cluster_id = '' + cluster = [] + # begin of cluster reached, like + elif token[0] == '<' and token[len(token)-1] == '>' and \ + cluster_id == '' : + cluster_id = token.lstrip('<').rstrip('>') + cluster = [] + # general terms, like WEATHER / CAR / ... + else : + cluster.append(token) + print() + print('===========================================================================') diff --git a/speech/speech_recognition/transformer/pytorch/tools/compute_cmvn_stats.py b/speech/speech_recognition/transformer/pytorch/tools/compute_cmvn_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..9c89789c47be0c855939469e86040f10398e9d89 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/tools/compute_cmvn_stats.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +import sys +import argparse +import json +import codecs +import yaml + +import torch +import torchaudio +import torchaudio.compliance.kaldi as kaldi +from torch.utils.data import Dataset, DataLoader + +torchaudio.set_audio_backend("sox_io") + + +class CollateFunc(object): + ''' Collate function for AudioDataset + ''' + + def __init__(self, feat_dim, resample_rate): + self.feat_dim = feat_dim + self.resample_rate = resample_rate + pass + + def __call__(self, batch): + mean_stat = torch.zeros(self.feat_dim) + var_stat = torch.zeros(self.feat_dim) + number = 0 + for item in batch: + value = item[1].strip().split(",") + assert len(value) == 3 or len(value) == 1 + wav_path = value[0] + sample_rate = torchaudio.backend.sox_io_backend.info(wav_path).sample_rate + resample_rate = sample_rate + # len(value) == 3 means segmented wav.scp, + # len(value) == 1 means original wav.scp + if len(value) == 3: + start_frame = int(float(value[1]) * sample_rate) + end_frame = int(float(value[2]) * sample_rate) + waveform, sample_rate = torchaudio.backend.sox_io_backend.load( + filepath=wav_path, + num_frames=end_frame - start_frame, + frame_offset=start_frame) + else: + waveform, sample_rate = torchaudio.load(item[1]) + + waveform = waveform * (1 << 15) + if self.resample_rate != 0 and self.resample_rate != sample_rate: + resample_rate = self.resample_rate + waveform = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + + mat = kaldi.fbank(waveform, + num_mel_bins=self.feat_dim, + dither=0.0, + energy_floor=0.0, + sample_frequency=resample_rate) + mean_stat += torch.sum(mat, axis=0) + var_stat += torch.sum(torch.square(mat), axis=0) + number += mat.shape[0] + return number, mean_stat, var_stat + + +class AudioDataset(Dataset): + def __init__(self, data_file): + self.items = [] + with codecs.open(data_file, 'r', encoding='utf-8') as f: + for line in f: + arr = line.strip().split() + self.items.append((arr[0], arr[1])) + + def __len__(self): + return len(self.items) + + def __getitem__(self, idx): + return self.items[idx] + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='extract CMVN stats') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for processing') + parser.add_argument('--train_config', + default='', + help='training yaml conf') + parser.add_argument('--in_scp', default=None, help='wav scp file') + parser.add_argument('--out_cmvn', + default='global_cmvn', + help='global cmvn file') + + doc = "Print log after every 
log_interval audios are processed." + parser.add_argument("--log_interval", type=int, default=1000, help=doc) + args = parser.parse_args() + + with open(args.train_config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + feat_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins'] + resample_rate = 0 + if 'resample_conf' in configs['dataset_conf']: + resample_rate = configs['dataset_conf']['resample_conf']['resample_rate'] + print('using resample and new sample rate is {}'.format(resample_rate)) + + collate_func = CollateFunc(feat_dim, resample_rate) + dataset = AudioDataset(args.in_scp) + batch_size = 20 + data_loader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, + sampler=None, + num_workers=args.num_workers, + collate_fn=collate_func) + + with torch.no_grad(): + all_number = 0 + all_mean_stat = torch.zeros(feat_dim) + all_var_stat = torch.zeros(feat_dim) + wav_number = 0 + for i, batch in enumerate(data_loader): + number, mean_stat, var_stat = batch + all_mean_stat += mean_stat + all_var_stat += var_stat + all_number += number + wav_number += batch_size + + if wav_number % args.log_interval == 0: + print(f'processed {wav_number} wavs, {all_number} frames', + file=sys.stderr, + flush=True) + + cmvn_info = { + 'mean_stat': list(all_mean_stat.tolist()), + 'var_stat': list(all_var_stat.tolist()), + 'frame_num': all_number + } + + with open(args.out_cmvn, 'w') as fout: + fout.write(json.dumps(cmvn_info)) diff --git a/speech/speech_recognition/transformer/pytorch/tools/filter_scp.pl b/speech/speech_recognition/transformer/pytorch/tools/filter_scp.pl new file mode 100644 index 0000000000000000000000000000000000000000..b76d37f41be0886470281978bfacf97f6b8ae976 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/tools/filter_scp.pl @@ -0,0 +1,87 @@ +#!/usr/bin/env perl +# Copyright 2010-2012 Microsoft Corporation +# Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This script takes a list of utterance-ids or any file whose first field +# of each line is an utterance-id, and filters an scp +# file (or any file whose "n-th" field is an utterance id), printing +# out only those lines whose "n-th" field is in id_list. The index of +# the "n-th" field is 1, by default, but can be changed by using +# the -f switch + +$exclude = 0; +$field = 1; +$shifted = 0; + +do { + $shifted=0; + if ($ARGV[0] eq "--exclude") { + $exclude = 1; + shift @ARGV; + $shifted=1; + } + if ($ARGV[0] eq "-f") { + $field = $ARGV[1]; + shift @ARGV; shift @ARGV; + $shifted=1 + } +} while ($shifted); + +if(@ARGV < 1 || @ARGV > 2) { + die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . + "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . + "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . + "only the lines that were *not* in id_list.\n" . 
+ "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: utils/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. + if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/speech/speech_recognition/transformer/pytorch/tools/make_raw_list.py b/speech/speech_recognition/transformer/pytorch/tools/make_raw_list.py new file mode 100644 index 0000000000000000000000000000000000000000..2f84f015542bb38da027b8ea61e8638f873cec33 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/tools/make_raw_list.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
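+# Usage sketch (the paths are examples mirroring how run.sh stage 3 invokes
+# this tool, not fixed defaults):
+#   python3 tools/make_raw_list.py data/train/wav.scp data/train/text data/train/data.list
+# Each output line is a JSON record like {"key": ..., "wav": ..., "txt": ...},
+# with extra "start"/"end" fields when a --segments file is given.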
+ +import argparse +import json + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='') + parser.add_argument('--segments', default=None, help='segments file') + parser.add_argument('wav_file', help='wav file') + parser.add_argument('text_file', help='text file') + parser.add_argument('output_file', help='output list file') + args = parser.parse_args() + + wav_table = {} + with open(args.wav_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + wav_table[arr[0]] = arr[1] + + if args.segments is not None: + segments_table = {} + with open(args.segments, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 4 + segments_table[arr[0]] = (arr[1], float(arr[2]), float(arr[3])) + + with open(args.text_file, 'r', encoding='utf8') as fin, \ + open(args.output_file, 'w', encoding='utf8') as fout: + for line in fin: + arr = line.strip().split(maxsplit=1) + key = arr[0] + txt = arr[1] if len(arr) > 1 else '' + if args.segments is None: + assert key in wav_table + wav = wav_table[key] + line = dict(key=key, wav=wav, txt=txt) + else: + assert key in segments_table + wav_key, start, end = segments_table[key] + wav = wav_table[wav_key] + line = dict(key=key, wav=wav, txt=txt, start=start, end=end) + json_line = json.dumps(line, ensure_ascii=False) + fout.write(json_line + '\n') diff --git a/speech/speech_recognition/transformer/pytorch/tools/parse_options.sh b/speech/speech_recognition/transformer/pytorch/tools/parse_options.sh new file mode 100644 index 0000000000000000000000000000000000000000..34476fdb37a4b14d5fe6e0edbebe97e760d2be5a --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/tools/parse_options.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### No we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. 
Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/speech/speech_recognition/transformer/pytorch/tools/text2token.py b/speech/speech_recognition/transformer/pytorch/tools/text2token.py new file mode 100644 index 0000000000000000000000000000000000000000..4f4dcc901d436650695f0b80e0cf99e1e99269ee --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/tools/text2token.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Copyright 2021 JD AI Lab. All Rights Reserved. (authors: Lu Fan) +# Copyright 2021 Mobvoi Inc. All Rights Reserved. 
(Di Wu) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +from __future__ import print_function +from __future__ import unicode_literals + +import argparse +import codecs +import re +import sys + +is_python2 = sys.version_info[0] == 2 + + +def exist_or_not(i, match_pos): + start_pos = None + end_pos = None + for pos in match_pos: + if pos[0] <= i < pos[1]: + start_pos = pos[0] + end_pos = pos[1] + break + + return start_pos, end_pos + +def seg_char(sent): + pattern = re.compile(r'([\u4e00-\u9fa5])') + chars = pattern.split(sent) + chars = [w for w in chars if len(w.strip()) > 0] + return chars + +def get_parser(): + parser = argparse.ArgumentParser( + description='convert raw text to tokenized text', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--nchar', + '-n', + default=1, + type=int, + help='number of characters to split, i.e., \ + aabb -> a a b b with -n 1 and aa bb with -n 2') + parser.add_argument('--skip-ncols', + '-s', + default=0, + type=int, + help='skip first n columns') + parser.add_argument('--space', + default='', + type=str, + help='space symbol') + parser.add_argument('--bpe-model', + '-m', + default=None, + type=str, + help='bpe model for english part') + parser.add_argument('--non-lang-syms', + '-l', + default=None, + type=str, + help='list of non-linguistic symobles,' + ' e.g., etc.') + parser.add_argument('text', + type=str, + default=False, + nargs='?', + help='input text') + parser.add_argument('--trans_type', + '-t', + type=str, + default="char", + choices=["char", "phn", "cn_char_en_bpe"], + help="""Transcript type. char/phn. e.g., for TIMIT + FADG0_SI1279 - + If trans_type is char, read from + SI1279.WRD file -> "bricks are an alternative" + Else if trans_type is phn, + read from SI1279.PHN file -> + "sil b r ih sil k s aa r er n aa l + sil t er n ih sil t ih v sil" """) + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + rs = [] + if args.non_lang_syms is not None: + with codecs.open(args.non_lang_syms, 'r', encoding="utf-8") as f: + nls = [x.rstrip() for x in f.readlines()] + rs = [re.compile(re.escape(x)) for x in nls] + + if args.bpe_model is not None: + import sentencepiece as spm + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + if args.text: + f = codecs.open(args.text, encoding="utf-8") + else: + f = codecs.getreader("utf-8")( + sys.stdin if is_python2 else sys.stdin.buffer) + + sys.stdout = codecs.getwriter("utf-8")( + sys.stdout if is_python2 else sys.stdout.buffer) + line = f.readline() + n = args.nchar + while line: + x = line.split() + print(' '.join(x[:args.skip_ncols]), end=" ") + a = ' '.join(x[args.skip_ncols:]) + + # get all matched positions + match_pos = [] + for r in rs: + i = 0 + while i >= 0: + m = r.search(a, i) + if m: + match_pos.append([m.start(), m.end()]) + i = m.end() + else: + break + + if len(match_pos) > 0: + chars = [] + i = 0 + while i < len(a): + start_pos, end_pos = exist_or_not(i, match_pos) + if start_pos is not None: + chars.append(a[start_pos:end_pos]) + i = end_pos + else: + chars.append(a[i]) + i += 1 + a = chars + + if (args.trans_type == "phn"): + a = a.split(" ") + elif args.trans_type == "cn_char_en_bpe": + b = seg_char(a) + a = [] + for j in b: + # we use "▁" to instead of blanks among english words + # warning: here is "▁", not "_" + for l in j.strip().split("▁"): + if not l.encode('UTF-8').isalpha(): + a.append(l) + else: + for k in sp.encode_as_pieces(l): + a.append(k) + else: + a = [a[j:j + n] for j in range(0, 
len(a), n)] + + a_flat = [] + for z in a: + a_flat.append("".join(z)) + + a_chars = [z.replace(' ', args.space) for z in a_flat] + if (args.trans_type == "phn"): + a_chars = [z.replace("sil", args.space) for z in a_chars] + print(' '.join(a_chars)) + line = f.readline() + + +if __name__ == '__main__': + main() diff --git a/speech/speech_recognition/transformer/pytorch/wenet/bin/alignment.py b/speech/speech_recognition/transformer/pytorch/wenet/bin/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..9e055e8802db13dedd6dbadcf0011cfb578b9a9e --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/bin/alignment.py @@ -0,0 +1,235 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Di Wu) +# 2022 Tinnove Inc (authors: Wei Ren) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import copy +import logging +import os +import sys + +import torch +import yaml +from torch.utils.data import DataLoader +from textgrid import TextGrid, IntervalTier + +from wenet.dataset.dataset import Dataset +from wenet.transformer.asr_model import init_asr_model +from wenet.utils.checkpoint import load_checkpoint +from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols +from wenet.utils.ctc_util import forced_align +from wenet.utils.common import get_subsample + + +def generator_textgrid(maxtime, lines, output): + # Download Praat: https://www.fon.hum.uva.nl/praat/ + interval = maxtime / (len(lines) + 1) + margin = 0.0001 + + tg = TextGrid(maxTime=maxtime) + linetier = IntervalTier(name="line", maxTime=maxtime) + + i = 0 + for l in lines: + s, e, w = l.split() + linetier.add(minTime=float(s) + margin, maxTime=float(e), mark=w) + + tg.append(linetier) + print("successfully generator {}".format(output)) + tg.write(output) + + +def get_frames_timestamp(alignment): + # convert alignment to a praat format, which is a doing phonetics + # by computer and helps analyzing alignment + timestamp = [] + # get frames level duration for each token + start = 0 + end = 0 + while end < len(alignment): + while end < len(alignment) and alignment[end] == 0: + end += 1 + if end == len(alignment): + timestamp[-1] += alignment[start:] + break + end += 1 + while end < len(alignment) and alignment[end - 1] == alignment[end]: + end += 1 + timestamp.append(alignment[start:end]) + start = end + return timestamp + + +def get_labformat(timestamp, subsample): + begin = 0 + duration = 0 + labformat = [] + for idx, t in enumerate(timestamp): + # 25ms frame_length,10ms hop_length, 1/subsample + subsample = get_subsample(configs) + # time duration + duration = len(t) * 0.01 * subsample + if idx < len(timestamp) - 1: + print("{:.2f} {:.2f} {}".format(begin, begin + duration, + char_dict[t[-1]])) + labformat.append("{:.2f} {:.2f} {}\n".format( + begin, begin + duration, char_dict[t[-1]])) + else: + non_blank = 0 + for i in t: + if i != 0: + token = i + break + print("{:.2f} {:.2f} {}".format(begin, begin + duration, + char_dict[token])) + 
labformat.append("{:.2f} {:.2f} {}\n".format( + begin, begin + duration, char_dict[token])) + begin = begin + duration + return labformat + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='use ctc to generate alignment') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--input_file', required=True, help='format data file') + parser.add_argument('--data_type', + default='raw', + choices=['raw', 'shard'], + help='train and cv data type') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--dict', required=True, help='dict file') + parser.add_argument('--non_lang_syms', + help="non-linguistic symbol file. One symbol per line.") + parser.add_argument('--result_file', + required=True, + help='alignment result file') + parser.add_argument('--batch_size', type=int, default=1, help='batch size') + parser.add_argument('--gen_praat', + action='store_true', + help='convert alignment to a praat format') + parser.add_argument('--bpe_model', + default=None, + type=str, + help='bpe model for english part') + + args = parser.parse_args() + print(args) + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + if args.batch_size > 1: + logging.fatal('alignment mode must be running with batch_size == 1') + sys.exit(1) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + # Load dict + char_dict = {} + with open(args.dict, 'r') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + char_dict[int(arr[1])] = arr[0] + eos = len(char_dict) - 1 + + symbol_table = read_symbol_table(args.dict) + + # Init dataset and data loader + ali_conf = copy.deepcopy(configs['dataset_conf']) + + ali_conf['filter_conf']['max_length'] = 102400 + ali_conf['filter_conf']['min_length'] = 0 + ali_conf['filter_conf']['token_max_length'] = 102400 + ali_conf['filter_conf']['token_min_length'] = 0 + ali_conf['filter_conf']['max_output_input_ratio'] = 102400 + ali_conf['filter_conf']['min_output_input_ratio'] = 0 + ali_conf['speed_perturb'] = False + ali_conf['spec_aug'] = False + ali_conf['shuffle'] = False + ali_conf['sort'] = False + ali_conf['fbank_conf']['dither'] = 0.0 + ali_conf['batch_conf']['batch_type'] = "static" + ali_conf['batch_conf']['batch_size'] = args.batch_size + non_lang_syms = read_non_lang_symbols(args.non_lang_syms) + + ali_dataset = Dataset(args.data_type, + args.input_file, + symbol_table, + ali_conf, + args.bpe_model, + non_lang_syms, + partition=False) + + ali_data_loader = DataLoader(ali_dataset, batch_size=None, num_workers=0) + + # Init asr model from configs + model = init_asr_model(configs) + + load_checkpoint(model, args.checkpoint) + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + + model.eval() + with torch.no_grad(), open(args.result_file, 'w', + encoding='utf-8') as fout: + for batch_idx, batch in enumerate(ali_data_loader): + print("#" * 80) + key, feat, target, feats_length, target_length = batch + print(key) + + feat = feat.to(device) + target = target.to(device) + feats_length = feats_length.to(device) + target_length = target_length.to(device) + # Let's assume B = batch_size and N = beam_size + # 1. 
Encoder + encoder_out, encoder_mask = model._forward_encoder( + feat, feats_length) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + ctc_probs = model.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + # print(ctc_probs.size(1)) + ctc_probs = ctc_probs.squeeze(0) + target = target.squeeze(0) + alignment = forced_align(ctc_probs, target) + print(alignment) + fout.write('{} {}\n'.format(key[0], alignment)) + + if args.gen_praat: + timestamp = get_frames_timestamp(alignment) + print(timestamp) + subsample = get_subsample(configs) + labformat = get_labformat(timestamp, subsample) + + lab_path = os.path.join(os.path.dirname(args.result_file), + key[0] + ".lab") + with open(lab_path, 'w', encoding='utf-8') as f: + f.writelines(labformat) + + textgrid_path = os.path.join(os.path.dirname(args.result_file), + key[0] + ".TextGrid") + generator_textgrid(maxtime=(len(alignment) + 1) * 0.01 * + subsample, + lines=labformat, + output=textgrid_path) diff --git a/speech/speech_recognition/transformer/pytorch/wenet/bin/alignment_deprecated.py b/speech/speech_recognition/transformer/pytorch/wenet/bin/alignment_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..94471d4cd93ca2cd8afd8c9cf2ed0dfb142c2ca3 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/bin/alignment_deprecated.py @@ -0,0 +1,216 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
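For reference, here is a minimal, standalone sketch of the timing arithmetic that `get_frames_timestamp()` and `get_labformat()` in these alignment scripts apply to a CTC forced alignment. The alignment, token table and subsampling factor below are toy values chosen only for illustration (blank id 0, 10 ms frame shift):

```
# Toy CTC forced alignment over 8 encoder frames; 0 is the CTC blank id.
alignment = [0, 0, 3, 3, 0, 0, 7, 0]
char_dict = {3: "ni", 7: "hao"}   # toy token table
subsample = 4                     # e.g. conv2d subsampling rate

# get_frames_timestamp() groups leading blanks with the token frames that
# follow them (trailing blanks are appended to the last group):
frames_per_token = [[0, 0, 3, 3], [0, 0, 7, 0]]

segments, begin = [], 0.0
for piece in frames_per_token:
    duration = len(piece) * 0.01 * subsample      # frames * 10 ms * subsample
    token = next(t for t in piece if t != 0)
    segments.append("{:.2f} {:.2f} {}".format(begin, begin + duration,
                                              char_dict[token]))
    begin += duration
print(segments)   # ['0.00 0.16 ni', '0.16 0.32 hao']
```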
+ +from __future__ import print_function + +import argparse +import copy +import logging +import os +import sys + +import torch +import yaml +from torch.utils.data import DataLoader +from textgrid import TextGrid, IntervalTier + +from wenet.dataset.dataset_deprecated import AudioDataset, CollateFunc +from wenet.transformer.asr_model import init_asr_model +from wenet.utils.checkpoint import load_checkpoint +from wenet.utils.ctc_util import forced_align +from wenet.utils.common import get_subsample + + +def generator_textgrid(maxtime, lines, output): + # Download Praat: https://www.fon.hum.uva.nl/praat/ + interval = maxtime / (len(lines) + 1) + margin = 0.0001 + + tg = TextGrid(maxTime=maxtime) + linetier = IntervalTier(name="line", maxTime=maxtime) + + i = 0 + for l in lines: + s, e, w = l.split() + linetier.add(minTime=float(s) + margin, maxTime=float(e), mark=w) + + tg.append(linetier) + print("successfully generator {}".format(output)) + tg.write(output) + + +def get_frames_timestamp(alignment): + # convert alignment to a praat format, which is a doing phonetics + # by computer and helps analyzing alignment + timestamp = [] + # get frames level duration for each token + start = 0 + end = 0 + while end < len(alignment): + while end < len(alignment) and alignment[end] == 0: + end += 1 + if end == len(alignment): + timestamp[-1] += alignment[start:] + break + end += 1 + while end < len(alignment) and alignment[end - 1] == alignment[end]: + end += 1 + timestamp.append(alignment[start:end]) + start = end + return timestamp + + +def get_labformat(timestamp, subsample): + begin = 0 + duration = 0 + labformat = [] + for idx, t in enumerate(timestamp): + # 25ms frame_length,10ms hop_length, 1/subsample + subsample = get_subsample(configs) + # time duration + duration = len(t) * 0.01 * subsample + if idx < len(timestamp) - 1: + print("{:.2f} {:.2f} {}".format(begin, begin + duration, + char_dict[t[-1]])) + labformat.append("{:.2f} {:.2f} {}\n".format( + begin, begin + duration, char_dict[t[-1]])) + else: + non_blank = 0 + for i in t: + if i != 0: + token = i + break + print("{:.2f} {:.2f} {}".format(begin, begin + duration, + char_dict[token])) + labformat.append("{:.2f} {:.2f} {}\n".format( + begin, begin + duration, char_dict[token])) + begin = begin + duration + return labformat + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='use ctc to generate alignment') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--input_file', required=True, help='format data file') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--dict', required=True, help='dict file') + parser.add_argument('--result_file', + required=True, + help='alignment result file') + parser.add_argument('--batch_size', type=int, default=1, help='batch size') + parser.add_argument('--gen_praat', + action='store_true', + help='convert alignment to a praat format') + + args = parser.parse_args() + print(args) + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + if args.batch_size > 1: + logging.fatal('alignment mode must be running with batch_size == 1') + sys.exit(1) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + # Load dict + char_dict = {} + with open(args.dict, 'r') as 
fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + char_dict[int(arr[1])] = arr[0] + eos = len(char_dict) - 1 + + raw_wav = configs['raw_wav'] + # Init dataset and data loader + ali_collate_conf = copy.deepcopy(configs['collate_conf']) + ali_collate_conf['spec_aug'] = False + ali_collate_conf['spec_sub'] = False + ali_collate_conf['feature_dither'] = False + ali_collate_conf['speed_perturb'] = False + if raw_wav: + ali_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0 + ali_collate_func = CollateFunc(**ali_collate_conf, raw_wav=raw_wav) + dataset_conf = configs.get('dataset_conf', {}) + dataset_conf['batch_size'] = args.batch_size + dataset_conf['batch_type'] = 'static' + dataset_conf['sort'] = False + ali_dataset = AudioDataset(args.input_file, + **dataset_conf, + raw_wav=raw_wav) + ali_data_loader = DataLoader(ali_dataset, + collate_fn=ali_collate_func, + shuffle=False, + batch_size=1, + num_workers=0) + + # Init asr model from configs + model = init_asr_model(configs) + + load_checkpoint(model, args.checkpoint) + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + + model.eval() + with torch.no_grad(), open(args.result_file, 'w', + encoding='utf-8') as fout: + for batch_idx, batch in enumerate(ali_data_loader): + print("#" * 80) + key, feat, target, feats_length, target_length = batch + print(key) + + feat = feat.to(device) + target = target.to(device) + feats_length = feats_length.to(device) + target_length = target_length.to(device) + # Let's assume B = batch_size and N = beam_size + # 1. Encoder + encoder_out, encoder_mask = model._forward_encoder( + feat, feats_length) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + ctc_probs = model.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + # print(ctc_probs.size(1)) + ctc_probs = ctc_probs.squeeze(0) + target = target.squeeze(0) + alignment = forced_align(ctc_probs, target) + print(alignment) + fout.write('{} {}\n'.format(key[0], alignment)) + + if args.gen_praat: + timestamp = get_frames_timestamp(alignment) + print(timestamp) + subsample = get_subsample(configs) + labformat = get_labformat(timestamp, subsample) + + lab_path = os.path.join(os.path.dirname(args.result_file), + key[0] + ".lab") + with open(lab_path, 'w', encoding='utf-8') as f: + f.writelines(labformat) + + textgrid_path = os.path.join(os.path.dirname(args.result_file), + key[0] + ".TextGrid") + generator_textgrid(maxtime=(len(alignment) + 1) * 0.01 * + subsample, + lines=labformat, + output=textgrid_path) diff --git a/speech/speech_recognition/transformer/pytorch/wenet/bin/average_model.py b/speech/speech_recognition/transformer/pytorch/wenet/bin/average_model.py new file mode 100644 index 0000000000000000000000000000000000000000..281e2c064634cb6ed5c1200a99ddc28b73df2ec6 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/bin/average_model.py @@ -0,0 +1,88 @@ +# Copyright 2020 Mobvoi Inc. All Rights Reserved. 
+# Author: di.wu@mobvoi.com (DI WU) +import os +import argparse +import glob + +import yaml +import numpy as np +import torch + + +def get_args(): + parser = argparse.ArgumentParser(description='average model') + parser.add_argument('--dst_model', required=True, help='averaged model') + parser.add_argument('--src_path', + required=True, + help='src model path for average') + parser.add_argument('--val_best', + action="store_true", + help='averaged model') + parser.add_argument('--num', + default=5, + type=int, + help='nums for averaged model') + parser.add_argument('--min_epoch', + default=0, + type=int, + help='min epoch used for averaging model') + parser.add_argument('--max_epoch', + default=65536, + type=int, + help='max epoch used for averaging model') + + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + checkpoints = [] + val_scores = [] + if args.val_best: + yamls = glob.glob('{}/[!train]*.yaml'.format(args.src_path)) + for y in yamls: + with open(y, 'r') as f: + dic_yaml = yaml.load(f, Loader=yaml.FullLoader) + loss = dic_yaml['cv_loss'] + epoch = dic_yaml['epoch'] + if epoch >= args.min_epoch and epoch <= args.max_epoch: + val_scores += [[epoch, loss]] + val_scores = np.array(val_scores) + sort_idx = np.argsort(val_scores[:, -1]) + sorted_val_scores = val_scores[sort_idx][::1] + print("best val scores = " + str(sorted_val_scores[:args.num, 1])) + print("selected epochs = " + + str(sorted_val_scores[:args.num, 0].astype(np.int64))) + path_list = [ + args.src_path + '/{}.pt'.format(int(epoch)) + for epoch in sorted_val_scores[:args.num, 0] + ] + else: + path_list = glob.glob('{}/[0-9]*.pt'.format(args.src_path)) + path_list = sorted(path_list, key=os.path.getmtime) + path_list = path_list[-args.num:] + print(path_list) + avg = None + num = args.num + assert num == len(path_list) + for path in path_list: + print('Processing {}'.format(path)) + states = torch.load(path, map_location=torch.device('cpu')) + if avg is None: + avg = states + else: + for k in avg.keys(): + avg[k] += states[k] + # average + for k in avg.keys(): + if avg[k] is not None: + # pytorch 1.6 use true_divide instead of /= + avg[k] = torch.true_divide(avg[k], num) + print('Saving to {}'.format(args.dst_model)) + torch.save(avg, args.dst_model) + + +if __name__ == '__main__': + main() diff --git a/speech/speech_recognition/transformer/pytorch/wenet/bin/export_jit.py b/speech/speech_recognition/transformer/pytorch/wenet/bin/export_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..41f9e3305818f55d74b7f16be2e5130be0d50e28 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/bin/export_jit.py @@ -0,0 +1,70 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
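Backing up to `average_model.py` above: stripped of argument handling, its averaging step reduces to the short loop below. The checkpoint paths are hypothetical and only illustrate the shape of the computation:

```
import torch

# Hypothetical paths to the last three epoch checkpoints.
path_list = ["exp/transformer/238.pt",
             "exp/transformer/239.pt",
             "exp/transformer/240.pt"]

avg = None
for path in path_list:
    states = torch.load(path, map_location="cpu")
    if avg is None:
        avg = states
    else:
        for k in avg.keys():
            avg[k] += states[k]

for k in avg.keys():
    # As in the script above: true_divide instead of /=, so integer buffers
    # are divided in floating point.
    avg[k] = torch.true_divide(avg[k], len(path_list))

torch.save(avg, "exp/transformer/avg_3.pt")
```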
+ +from __future__ import print_function + +import argparse +import os + +import torch +import yaml + +from wenet.transformer.asr_model import init_asr_model +from wenet.utils.checkpoint import load_checkpoint + + +def get_args(): + parser = argparse.ArgumentParser(description='export your script model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--output_file', default=None, help='output file') + parser.add_argument('--output_quant_file', + default=None, + help='output quantized model file') + args = parser.parse_args() + return args + + +def main(): + args = get_args() + # No need gpu for model export + os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + model = init_asr_model(configs) + print(model) + + load_checkpoint(model, args.checkpoint) + # Export jit torch script model + + if args.output_file: + script_model = torch.jit.script(model) + script_model.save(args.output_file) + print('Export model successfully, see {}'.format(args.output_file)) + + # Export quantized jit torch script model + if args.output_quant_file: + quantized_model = torch.quantization.quantize_dynamic( + model, {torch.nn.Linear}, dtype=torch.qint8 + ) + print(quantized_model) + script_quant_model = torch.jit.script(quantized_model) + script_quant_model.save(args.output_quant_file) + print('Export quantized model successfully, ' + 'see {}'.format(args.output_quant_file)) + + +if __name__ == '__main__': + main() diff --git a/speech/speech_recognition/transformer/pytorch/wenet/bin/export_onnx_cpu.py b/speech/speech_recognition/transformer/pytorch/wenet/bin/export_onnx_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..49934a1e426e9c77fdbc03e0abd4edd2228620d5 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/bin/export_onnx_cpu.py @@ -0,0 +1,394 @@ +#!/usr/bin/env python3 +# Copyright (c) 2022, Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
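The `export_jit.py` flow above can be exercised on a toy module when debugging scripting or quantization problems in isolation; the module and output file names below are made up for the example:

```
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(80, 256)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

model = Toy().eval()

# Float32 TorchScript export, as in the --output_file branch.
torch.jit.script(model).save("toy.zip")

# Dynamic int8 quantization of Linear layers, as in the --output_quant_file branch.
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8)
torch.jit.script(quantized).save("toy_quant.zip")
```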
+ +from __future__ import print_function + +import argparse +import os +import copy +import sys + +import torch +import yaml +import numpy as np + +from wenet.transformer.asr_model import init_asr_model +from wenet.utils.checkpoint import load_checkpoint + +try: + import onnx + import onnxruntime +except ImportError: + print('Please install onnx and onnxruntime!') + sys.exit(1) + + +def get_args(): + parser = argparse.ArgumentParser(description='export your script model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--output_dir', required=True, help='output directory') + parser.add_argument('--chunk_size', required=True, + type=int, help='decoding chunk size') + parser.add_argument('--num_decoding_left_chunks', required=True, + type=int, help='cache chunks') + parser.add_argument('--reverse_weight', default=0.5, + type=float, help='reverse_weight in attention_rescoing') + args = parser.parse_args() + return args + + +def to_numpy(tensor): + if tensor.requires_grad: + return tensor.detach().cpu().numpy() + else: + return tensor.cpu().numpy() + + +def print_input_output_info(onnx_model, name, prefix="\t\t"): + input_names = [node.name for node in onnx_model.graph.input] + input_shapes = [[d.dim_value for d in node.type.tensor_type.shape.dim] + for node in onnx_model.graph.input] + output_names = [node.name for node in onnx_model.graph.output] + output_shapes = [[d.dim_value for d in node.type.tensor_type.shape.dim] + for node in onnx_model.graph.output] + print("{}{} inputs : {}".format(prefix, name, input_names)) + print("{}{} input shapes : {}".format(prefix, name, input_shapes)) + print("{}{} outputs: {}".format(prefix, name, output_names)) + print("{}{} output shapes : {}".format(prefix, name, output_shapes)) + + +def export_encoder(asr_model, args): + print("Stage-1: export encoder") + encoder = asr_model.encoder + encoder.forward = encoder.forward_chunk + encoder_outpath = os.path.join(args['output_dir'], 'encoder.onnx') + + print("\tStage-1.1: prepare inputs for encoder") + chunk = torch.randn( + (args['batch'], args['decoding_window'], args['feature_size'])) + offset = 0 + # NOTE(xcsong): The uncertainty of `next_cache_start` only appears + # in the first few chunks, this is caused by dynamic att_cache shape, i,e + # (0, 0, 0, 0) for 1st chunk and (elayers, head, ?, d_k*2) for subsequent + # chunks. One way to ease the ONNX export is to keep `next_cache_start` + # as a fixed value. To do this, for the **first** chunk, if + # left_chunks > 0, we feed real cache & real mask to the model, otherwise + # fake cache & fake mask. In this way, we get: + # 1. 16/-1 mode: next_cache_start == 0 for all chunks + # 2. 16/4 mode: next_cache_start == chunk_size for all chunks + # 3. 16/0 mode: next_cache_start == chunk_size for all chunks + # 4. -1/-1 mode: next_cache_start == 0 for all chunks + # NO MORE DYNAMIC CHANGES!! 
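# Shapes of the dummy tensors constructed below:
#   att_cache: (num_blocks, head, cache_t, output_size // head * 2)
#   cnn_cache: (num_blocks, batch, output_size, cnn_module_kernel - 1)
#   att_mask : (batch, 1, required_cache_size + chunk_size) in 16/4 mode,
#              with the cache positions masked out (0) for the first chunk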
+ if args['left_chunks'] > 0: # 16/4 + required_cache_size = args['chunk_size'] * args['left_chunks'] + offset = required_cache_size + # Real cache + att_cache = torch.zeros( + (args['num_blocks'], args['head'], required_cache_size, + args['output_size'] // args['head'] * 2)) + # Real mask + att_mask = torch.ones( + (args['batch'], 1, required_cache_size + args['chunk_size']), + dtype=torch.bool) + att_mask[:, :, :required_cache_size] = 0 + elif args['left_chunks'] <= 0: # 16/-1, -1/-1, 16/0 + required_cache_size = -1 if args['left_chunks'] < 0 else 0 + # Fake cache + att_cache = torch.zeros( + (args['num_blocks'], args['head'], 0, + args['output_size'] // args['head'] * 2)) + # Fake mask + att_mask = torch.ones((0, 0, 0), dtype=torch.bool) + cnn_cache = torch.zeros( + (args['num_blocks'], args['batch'], + args['output_size'], args['cnn_module_kernel'] - 1)) + inputs = (chunk, offset, required_cache_size, + att_cache, cnn_cache, att_mask) + print("\t\tchunk.size(): {}\n".format(chunk.size()), + "\t\toffset: {}\n".format(offset), + "\t\trequired_cache: {}\n".format(required_cache_size), + "\t\tatt_cache.size(): {}\n".format(att_cache.size()), + "\t\tcnn_cache.size(): {}\n".format(cnn_cache.size()), + "\t\tatt_mask.size(): {}\n".format(att_mask.size())) + + print("\tStage-1.2: torch.onnx.export") + dynamic_axes = { + 'chunk': {1: 'T'}, + 'att_cache': {2: 'T_CACHE'}, + 'att_mask': {2: 'T_ADD_T_CACHE'}, + 'output': {1: 'T'}, + 'r_att_cache': {2: 'T_CACHE'}, + } + # NOTE(xcsong): We keep dynamic axes even if in 16/4 mode, this is + # to avoid padding the last chunk (which usually contains less + # frames than required). For users who want static axes, just pop + # out specific axis. + # if args['chunk_size'] > 0: # 16/4, 16/-1, 16/0 + # dynamic_axes.pop('chunk') + # dynamic_axes.pop('output') + # if args['left_chunks'] >= 0: # 16/4, 16/0 + # # NOTE(xsong): since we feed real cache & real mask into the + # # model when left_chunks > 0, the shape of cache will never + # # be changed. + # dynamic_axes.pop('att_cache') + # dynamic_axes.pop('r_att_cache') + torch.onnx.export( + encoder, inputs, encoder_outpath, opset_version=13, + export_params=True, do_constant_folding=True, + input_names=[ + 'chunk', 'offset', 'required_cache_size', + 'att_cache', 'cnn_cache', 'att_mask' + ], + output_names=['output', 'r_att_cache', 'r_cnn_cache'], + dynamic_axes=dynamic_axes, verbose=False) + onnx_encoder = onnx.load(encoder_outpath) + for (k, v) in args.items(): + meta = onnx_encoder.metadata_props.add() + meta.key, meta.value = str(k), str(v) + onnx.checker.check_model(onnx_encoder) + onnx.helper.printable_graph(onnx_encoder.graph) + # NOTE(xcsong): to add those metadatas we need to reopen + # the file and resave it. + onnx.save(onnx_encoder, encoder_outpath) + print_input_output_info(onnx_encoder, "onnx_encoder") + print('\t\tExport onnx_encoder, done! 
see {}'.format(encoder_outpath)) + + print("\tStage-1.3: check onnx_encoder and torch_encoder") + torch_output = [] + torch_chunk = copy.deepcopy(chunk) + torch_offset = copy.deepcopy(offset) + torch_required_cache_size = copy.deepcopy(required_cache_size) + torch_att_cache = copy.deepcopy(att_cache) + torch_cnn_cache = copy.deepcopy(cnn_cache) + torch_att_mask = copy.deepcopy(att_mask) + for i in range(10): + print("\t\ttorch chunk-{}: {}, offset: {}, att_cache: {}," + " cnn_cache: {}, att_mask: {}".format( + i, list(torch_chunk.size()), torch_offset, + list(torch_att_cache.size()), + list(torch_cnn_cache.size()), list(torch_att_mask.size()))) + # NOTE(xsong): att_mask of the first few batches need changes if + # we use 16/4 mode. + if args['left_chunks'] > 0: # 16/4 + torch_att_mask[:, :, -(args['chunk_size'] * (i + 1)):] = 1 + out, torch_att_cache, torch_cnn_cache = encoder( + torch_chunk, torch_offset, torch_required_cache_size, + torch_att_cache, torch_cnn_cache, torch_att_mask) + torch_output.append(out) + torch_offset += out.size(1) + torch_output = torch.cat(torch_output, dim=1) + + onnx_output = [] + onnx_chunk = to_numpy(chunk) + onnx_offset = np.array((offset)).astype(np.int64) + onnx_required_cache_size = np.array((required_cache_size)).astype(np.int64) + onnx_att_cache = to_numpy(att_cache) + onnx_cnn_cache = to_numpy(cnn_cache) + onnx_att_mask = to_numpy(att_mask) + ort_session = onnxruntime.InferenceSession(encoder_outpath) + input_names = [node.name for node in onnx_encoder.graph.input] + for i in range(10): + print("\t\tonnx chunk-{}: {}, offset: {}, att_cache: {}," + " cnn_cache: {}, att_mask: {}".format( + i, onnx_chunk.shape, onnx_offset, onnx_att_cache.shape, + onnx_cnn_cache.shape, onnx_att_mask.shape)) + # NOTE(xsong): att_mask of the first few batches need changes if + # we use 16/4 mode. + if args['left_chunks'] > 0: # 16/4 + onnx_att_mask[:, :, -(args['chunk_size'] * (i + 1)):] = 1 + ort_inputs = { + 'chunk': onnx_chunk, 'offset': onnx_offset, + 'required_cache_size': onnx_required_cache_size, + 'att_cache': onnx_att_cache, 'cnn_cache': onnx_cnn_cache, + 'att_mask': onnx_att_mask + } + # NOTE(xcsong): If we use 16/-1, -1/-1 or 16/0 mode, `next_cache_start` + # will be hardcoded to 0 or chunk_size by ONNX, thus + # required_cache_size and att_mask are no more needed and they will + # be removed by ONNX automatically. 
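# The loop below keeps only the inputs that survived export: in the 16/-1,
# -1/-1 and 16/0 modes, required_cache_size and att_mask are folded away at
# export time, and onnxruntime rejects feeds whose names are not graph inputs.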
+ for k in list(ort_inputs): + if k not in input_names: + ort_inputs.pop(k) + ort_outs = ort_session.run(None, ort_inputs) + onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2] + onnx_output.append(ort_outs[0]) + onnx_offset += ort_outs[0].shape[1] + onnx_output = np.concatenate(onnx_output, axis=1) + + np.testing.assert_allclose(to_numpy(torch_output), onnx_output, + rtol=1e-03, atol=1e-05) + meta = ort_session.get_modelmeta() + print("\t\tcustom_metadata_map={}".format(meta.custom_metadata_map)) + print("\t\tCheck onnx_encoder, pass!") + + +def export_ctc(asr_model, args): + print("Stage-2: export ctc") + ctc = asr_model.ctc + ctc.forward = ctc.log_softmax + ctc_outpath = os.path.join(args['output_dir'], 'ctc.onnx') + + print("\tStage-2.1: prepare inputs for ctc") + hidden = torch.randn( + (args['batch'], args['chunk_size'] if args['chunk_size'] > 0 else 16, + args['output_size'])) + + print("\tStage-2.2: torch.onnx.export") + dynamic_axes = {'hidden': {1: 'T'}, 'probs': {1: 'T'}} + torch.onnx.export( + ctc, hidden, ctc_outpath, opset_version=13, + export_params=True, do_constant_folding=True, + input_names=['hidden'], output_names=['probs'], + dynamic_axes=dynamic_axes, verbose=False) + onnx_ctc = onnx.load(ctc_outpath) + for (k, v) in args.items(): + meta = onnx_ctc.metadata_props.add() + meta.key, meta.value = str(k), str(v) + onnx.checker.check_model(onnx_ctc) + onnx.helper.printable_graph(onnx_ctc.graph) + onnx.save(onnx_ctc, ctc_outpath) + print_input_output_info(onnx_ctc, "onnx_ctc") + print('\t\tExport onnx_ctc, done! see {}'.format(ctc_outpath)) + + print("\tStage-2.3: check onnx_ctc and torch_ctc") + torch_output = ctc(hidden) + ort_session = onnxruntime.InferenceSession(ctc_outpath) + onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)}) + + np.testing.assert_allclose(to_numpy(torch_output), onnx_output[0], + rtol=1e-03, atol=1e-05) + print("\t\tCheck onnx_ctc, pass!") + + +def export_decoder(asr_model, args): + print("Stage-3: export decoder") + decoder = asr_model + # NOTE(lzhin): parameters of encoder will be automatically removed + # since they are not used during rescoring. + decoder.forward = decoder.forward_attention_decoder + decoder_outpath = os.path.join(args['output_dir'], 'decoder.onnx') + + print("\tStage-3.1: prepare inputs for decoder") + # hardcode time->200 nbest->10 len->20, they are dynamic axes. 
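# Dummy decoder inputs built below: encoder_out is (1, T=200, output_size);
# hyps holds nbest=10 padded hypotheses of length 20, with vocab_size - 1
# (used as the sos/eos id elsewhere in these scripts) written into column 0;
# hyps_lens samples lengths in [15, 21). All of these are dynamic axes.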
+ encoder_out = torch.randn((1, 200, args['output_size'])) + hyps = torch.randint(low=0, high=args['vocab_size'], + size=[10, 20]) + hyps[:, 0] = args['vocab_size'] - 1 # + hyps_lens = torch.randint(low=15, high=21, size=[10]) + + print("\tStage-3.2: torch.onnx.export") + dynamic_axes = { + 'hyps': {0: 'NBEST', 1: 'L'}, 'hyps_lens': {0: 'NBEST'}, + 'encoder_out': {1: 'T'}, + 'score': {0: 'NBEST', 1: 'L'}, 'r_score': {0: 'NBEST', 1: 'L'} + } + inputs = (hyps, hyps_lens, encoder_out, args['reverse_weight']) + torch.onnx.export( + decoder, inputs, decoder_outpath, opset_version=13, + export_params=True, do_constant_folding=True, + input_names=['hyps', 'hyps_lens', 'encoder_out', 'reverse_weight'], + output_names=['score', 'r_score'], + dynamic_axes=dynamic_axes, verbose=False) + onnx_decoder = onnx.load(decoder_outpath) + for (k, v) in args.items(): + meta = onnx_decoder.metadata_props.add() + meta.key, meta.value = str(k), str(v) + onnx.checker.check_model(onnx_decoder) + onnx.helper.printable_graph(onnx_decoder.graph) + onnx.save(onnx_decoder, decoder_outpath) + print_input_output_info(onnx_decoder, "onnx_decoder") + print('\t\tExport onnx_decoder, done! see {}'.format( + decoder_outpath)) + + print("\tStage-3.3: check onnx_decoder and torch_decoder") + torch_score, torch_r_score = decoder( + hyps, hyps_lens, encoder_out, args['reverse_weight']) + ort_session = onnxruntime.InferenceSession(decoder_outpath) + input_names = [node.name for node in onnx_decoder.graph.input] + ort_inputs = { + 'hyps': to_numpy(hyps), + 'hyps_lens': to_numpy(hyps_lens), + 'encoder_out': to_numpy(encoder_out), + 'reverse_weight': np.array((args['reverse_weight'])), + } + for k in list(ort_inputs): + if k not in input_names: + ort_inputs.pop(k) + onnx_output = ort_session.run(None, ort_inputs) + + np.testing.assert_allclose(to_numpy(torch_score), onnx_output[0], + rtol=1e-03, atol=1e-05) + if args['is_bidirectional_decoder'] and args['reverse_weight'] > 0.0: + np.testing.assert_allclose(to_numpy(torch_r_score), onnx_output[1], + rtol=1e-03, atol=1e-05) + print("\t\tCheck onnx_decoder, pass!") + + +def main(): + torch.manual_seed(777) + args = get_args() + output_dir = args.output_dir + os.system("mkdir -p " + output_dir) + os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + model = init_asr_model(configs) + load_checkpoint(model, args.checkpoint) + model.eval() + print(model) + + arguments = {} + arguments['output_dir'] = output_dir + arguments['batch'] = 1 + arguments['chunk_size'] = args.chunk_size + arguments['left_chunks'] = args.num_decoding_left_chunks + arguments['reverse_weight'] = args.reverse_weight + arguments['output_size'] = configs['encoder_conf']['output_size'] + arguments['num_blocks'] = configs['encoder_conf']['num_blocks'] + arguments['cnn_module_kernel'] = configs['encoder_conf']['cnn_module_kernel'] + arguments['head'] = configs['encoder_conf']['attention_heads'] + arguments['feature_size'] = configs['input_dim'] + arguments['vocab_size'] = configs['output_dim'] + # NOTE(xcsong): if chunk_size == -1, hardcode to 67 + arguments['decoding_window'] = (args.chunk_size - 1) * \ + model.encoder.embed.subsampling_rate + \ + model.encoder.embed.right_context + 1 if args.chunk_size > 0 else 67 + arguments['encoder'] = configs['encoder'] + arguments['decoder'] = configs['decoder'] + arguments['subsampling_rate'] = model.subsampling_rate() + arguments['right_context'] = model.right_context() + arguments['sos_symbol'] = 
model.sos_symbol() + arguments['eos_symbol'] = model.eos_symbol() + arguments['is_bidirectional_decoder'] = 1 \ + if model.is_bidirectional_decoder() else 0 + + # NOTE(xcsong): Please note that -1/-1 means non-streaming model! It is + # not a [16/4 16/-1 16/0] all-in-one model and it should not be used in + # streaming mode (i.e., setting chunk_size=16 in `decoder_main`). If you + # want to use 16/-1 or any other streaming mode in `decoder_main`, + # please export onnx in the same config. + if arguments['left_chunks'] > 0: + assert arguments['chunk_size'] > 0 # -1/4 not supported + + export_encoder(model, arguments) + export_ctc(model, arguments) + export_decoder(model, arguments) + + +if __name__ == '__main__': + main() diff --git a/speech/speech_recognition/transformer/pytorch/wenet/bin/export_onnx_gpu.py b/speech/speech_recognition/transformer/pytorch/wenet/bin/export_onnx_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b8e94d4baaa9112f8adabfd97606e23564c3ce28 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/bin/export_onnx_gpu.py @@ -0,0 +1,351 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import os +import sys + +import torch +import yaml +import logging + +from wenet.transformer.asr_model import init_asr_model +from wenet.utils.checkpoint import load_checkpoint +from wenet.transformer.ctc import CTC +from wenet.transformer.decoder import TransformerDecoder +from wenet.transformer.encoder import BaseEncoder +from wenet.utils.mask import make_pad_mask + +try: + import onnxruntime +except ImportError: + print('Please install onnxruntime-gpu!') + sys.exit(1) + +logger = logging.getLogger(__file__) +logger.setLevel(logging.INFO) + +class Encoder(torch.nn.Module): + def __init__(self, + encoder: BaseEncoder, + ctc: CTC, + beam_size: int = 10): + super().__init__() + self.encoder = encoder + self.ctc = ctc + self.beam_size = beam_size + + def forward(self, speech: torch.Tensor, + speech_lengths: torch.Tensor,): + """Encoder + Args: + speech: (Batch, Length, ...) 
+ speech_lengths: (Batch, ) + Returns: + encoder_out: B x T x F + encoder_out_lens: B + ctc_log_probs: B x T x V + beam_log_probs: B x T x beam_size + beam_log_probs_idx: B x T x beam_size + """ + encoder_out, encoder_mask = self.encoder(speech, + speech_lengths, + -1, -1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + ctc_log_probs = self.ctc.log_softmax(encoder_out) + encoder_out_lens = encoder_out_lens.int() + beam_log_probs, beam_log_probs_idx = torch.topk( + ctc_log_probs, self.beam_size, dim=2) + return encoder_out, encoder_out_lens, ctc_log_probs, \ + beam_log_probs, beam_log_probs_idx + + +class Decoder(torch.nn.Module): + def __init__(self, + decoder: TransformerDecoder, + ctc_weight: float = 0.5, + reverse_weight: float = 0.0, + beam_size: int = 10): + super().__init__() + self.decoder = decoder + self.ctc_weight = ctc_weight + self.reverse_weight = reverse_weight + self.beam_size = beam_size + + def forward(self, + encoder_out: torch.Tensor, + encoder_lens: torch.Tensor, + hyps_pad_sos_eos: torch.Tensor, + hyps_lens_sos: torch.Tensor, + r_hyps_pad_sos_eos: torch.Tensor, + ctc_score: torch.Tensor): + """Encoder + Args: + encoder_out: B x T x F + encoder_lens: B + hyps_pad_sos_eos: B x beam x (T2+1), + hyps with sos & eos and padded by ignore id + hyps_lens_sos: B x beam, length for each hyp with sos + r_hyps_pad_sos_eos: B x beam x (T2+1), + reversed hyps with sos & eos and padded by ignore id + ctc_score: B x beam, ctc score for each hyp + Returns: + decoder_out: B x beam x T2 x V + r_decoder_out: B x beam x T2 x V + best_index: B + """ + B, T, F = encoder_out.shape + bz = self.beam_size + B2 = B * bz + encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F) + encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1) + encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T) + T2 = hyps_pad_sos_eos.shape[2] - 1 + hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1) + hyps_lens = hyps_lens_sos.view(B2,) + hyps_pad_sos = hyps_pad[:, :-1].contiguous() + hyps_pad_eos = hyps_pad[:, 1:].contiguous() + + r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1) + r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous() + r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous() + + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps_pad_sos, hyps_lens, r_hyps_pad_sos, + self.reverse_weight) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + V = decoder_out.shape[-1] + decoder_out = decoder_out.view(B2, T2, V) + mask = ~make_pad_mask(hyps_lens, T2) # B2 x T2 + # mask index, remove ignore id + index = torch.unsqueeze(hyps_pad_eos * mask, 2) + score = decoder_out.gather(2, index).squeeze(2) # B2 X T2 + # mask padded part + score = score * mask + decoder_out = decoder_out.view(B, bz, T2, V) + if self.reverse_weight > 0: + r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + r_decoder_out = r_decoder_out.view(B2, T2, V) + index = torch.unsqueeze(r_hyps_pad_eos * mask, 2) + r_score = r_decoder_out.gather(2, index).squeeze(2) + r_score = r_score * mask + score = score * (1 - self.reverse_weight) + self.reverse_weight * r_score + r_decoder_out = r_decoder_out.view(B, bz, T2, V) + score = torch.sum(score, axis=1) # B2 + score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score + best_index = torch.argmax(score, dim=1) + return best_index + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='export your script model') + parser.add_argument('--config', required=True, help='config file') + 
parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--cmvn_file', required=False, default='', type=str, + help='global_cmvn file, default path is in config file') + parser.add_argument('--reverse_weight', default=-1.0, type=float, + required=False, + help='reverse weight for bitransformer,' + + 'default value is in config file') + parser.add_argument('--ctc_weight', default=-1.0, type=float, + required=False, + help='ctc weight, default value is in config file') + parser.add_argument('--beam_size', default=10, type=int, required=False, + help="beam size would be ctc output size") + parser.add_argument('--output_onnx_dir', + default="onnx_model", + help='output onnx encoder and decoder directory') + parser.add_argument('--fp16', + action='store_true', + help='whether to export fp16 model, default false') + args = parser.parse_args() + + torch.manual_seed(0) + torch.set_printoptions(precision=10) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + if args.cmvn_file and os.path.exists(args.cmvn_file): + configs['cmvn_file'] = args.cmvn_file + if args.reverse_weight != -1.0 and 'reverse_weight' in configs['model_conf']: + configs['model_conf']['reverse_weight'] = args.reverse_weight + print("Update reverse weight to", args.reverse_weight) + if args.ctc_weight != -1: + print("Update ctc weight to ", args.ctc_weight) + configs['model_conf']['ctc_weight'] = args.ctc_weight + configs["encoder_conf"]["use_dynamic_chunk"] = False + model = init_asr_model(configs) + load_checkpoint(model, args.checkpoint) + model.eval() + bz = 32 + seq_len = 100 + beam_size = args.beam_size + feature_size = configs["input_dim"] + + speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32) + speech_lens = torch.randint(low=10, high=seq_len, size=(bz,), dtype=torch.int32) + encoder = Encoder(model.encoder, model.ctc, beam_size) + encoder.eval() + if not os.path.exists(args.output_onnx_dir): + os.mkdir(args.output_onnx_dir) + encoder_onnx_path = os.path.join(args.output_onnx_dir, 'encoder.onnx') + + torch.onnx.export(encoder, + (speech, speech_lens), + encoder_onnx_path, + export_params=True, + opset_version=13, + do_constant_folding=True, + input_names=['speech', 'speech_lengths'], + output_names=['encoder_out', 'encoder_out_lens', + 'ctc_log_probs', + 'beam_log_probs', 'beam_log_probs_idx'], + dynamic_axes={ + 'speech': {0: 'B', 1: 'T'}, + 'speech_lengths': {0: 'B'}, + 'encoder_out': {0: 'B', 1: 'T_OUT'}, + 'encoder_out_lens': {0: 'B'}, + 'ctc_log_probs': {0: 'B', 1: 'T_OUT'}, + 'beam_log_probs': {0: 'B', 1: 'T_OUT'}, + 'beam_log_probs_idx': {0: 'B', 1: 'T_OUT'}, + }, + verbose=False + ) + + def to_numpy(tensor): + if tensor.requires_grad: + return tensor.detach().cpu().numpy() + else: + return tensor.cpu().numpy() + + with torch.no_grad(): + o0, o1, o2, o3, o4 = encoder(speech, speech_lens) + + providers = ["CUDAExecutionProvider"] + ort_session = onnxruntime.InferenceSession(encoder_onnx_path, + providers=providers) + ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(speech), + ort_session.get_inputs()[1].name: to_numpy(speech_lens)} + ort_outs = ort_session.run(None, ort_inputs) + + def test(a, b, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True): + try: + torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + except AssertionError as error: + if tolerate_small_mismatch: + print(error) + else: + raise + + # check encoder output + test(to_numpy(o0), ort_outs[0], rtol=1e-03, atol=1e-5) + test(to_numpy(o1), 
ort_outs[1], rtol=1e-03, atol=1e-05) + test(to_numpy(o2), ort_outs[2], rtol=1e-03, atol=1e-05) + test(to_numpy(o3), ort_outs[3], rtol=1e-03, atol=1e-05) + test(to_numpy(o4), ort_outs[4], rtol=1e-03, atol=1e-05) + logger.info("export to onnx encoder succeed!") + + decoder = Decoder( + model.decoder, + model.ctc_weight, + model.reverse_weight, + beam_size) + decoder.eval() + decoder_onnx_path = os.path.join(args.output_onnx_dir, 'decoder.onnx') + + hyps_pad_sos_eos = torch.randint(low=3, high=1000, size=(bz, beam_size, seq_len)) + hyps_lens_sos = torch.randint(low=3, high=seq_len, size=(bz, beam_size), + dtype=torch.int32) + r_hyps_pad_sos_eos = torch.randint(low=3, high=1000, size=(bz, beam_size, seq_len)) + + output_size = configs["encoder_conf"]["output_size"] + encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32) + encoder_out_lens = torch.randint(low=3, high=seq_len, size=(bz,), dtype=torch.int32) + ctc_score = torch.randn(bz, beam_size, dtype=torch.float32) + torch.onnx.export(decoder, + (encoder_out, encoder_out_lens, + hyps_pad_sos_eos, hyps_lens_sos, + r_hyps_pad_sos_eos, ctc_score), + decoder_onnx_path, + export_params=True, + opset_version=13, + do_constant_folding=True, + input_names=['encoder_out', 'encoder_out_lens', + 'hyps_pad_sos_eos', 'hyps_lens_sos', + 'r_hyps_pad_sos_eos', 'ctc_score'], + output_names=['best_index'], + dynamic_axes={'encoder_out': {0: 'B', 1: 'T'}, + 'encoder_out_lens': {0: 'B'}, + 'hyps_pad_sos_eos': {0: 'B', 2: 'T2'}, + 'hyps_lens_sos': {0: 'B'}, + 'r_hyps_pad_sos_eos': {0: 'B', 2: 'T2'}, + 'ctc_score': {0: 'B'}, + 'best_index': {0: 'B'}, + }, + verbose=False + ) + with torch.no_grad(): + o0 = decoder( + encoder_out, + encoder_out_lens, + hyps_pad_sos_eos, + hyps_lens_sos, + r_hyps_pad_sos_eos, + ctc_score) + + ort_session = onnxruntime.InferenceSession(decoder_onnx_path, + providers=providers) + ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(encoder_out), + ort_session.get_inputs()[1].name: to_numpy(encoder_out_lens), + ort_session.get_inputs()[2].name: to_numpy(hyps_pad_sos_eos), + ort_session.get_inputs()[3].name: to_numpy(hyps_lens_sos), + ort_session.get_inputs()[-1].name: to_numpy(ctc_score) + } + # if model.reverse weight == 0, + # the r_hyps_pad will be removed + # from the onnx decoder since it doen't play any role + if model.reverse_weight > 0: + ort_inputs[ort_session.get_inputs()[4].name] = to_numpy(r_hyps_pad_sos_eos) + ort_outs = ort_session.run(None, ort_inputs) + + # check encoder output + test(to_numpy(o0), ort_outs[0], rtol=1e-03, atol=1e-05) + logger.info("export to onnx decoder succeed!") + + if args.fp16: + try: + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + except ImportError: + print('Please install onnxmltools!') + sys.exit(1) + encoder_onnx_model = onnxmltools.utils.load_model(encoder_onnx_path) + encoder_onnx_model = convert_float_to_float16(encoder_onnx_model) + encoder_onnx_path = os.path.join(args.output_onnx_dir, 'encoder_fp16.onnx') + onnxmltools.utils.save_model(encoder_onnx_model, encoder_onnx_path) + decoder_onnx_model = onnxmltools.utils.load_model(decoder_onnx_path) + decoder_onnx_model = convert_float_to_float16(decoder_onnx_model) + decoder_onnx_path = os.path.join(args.output_onnx_dir, 'decoder_fp16.onnx') + onnxmltools.utils.save_model(decoder_onnx_model, decoder_onnx_path) + # dump configurations + onnx_config = {"beam_size": args.beam_size, + "reverse_weight": args.reverse_weight, + "ctc_weight": args.ctc_weight, + "fp16": 
args.fp16} + + config_dir = os.path.join(args.output_onnx_dir, "config.yaml") + with open(config_dir, "w") as out: + yaml.dump(onnx_config, out) diff --git a/speech/speech_recognition/transformer/pytorch/wenet/bin/recognize.py b/speech/speech_recognition/transformer/pytorch/wenet/bin/recognize.py new file mode 100644 index 0000000000000000000000000000000000000000..86acc4503fd1d9e39b23a67115e4805738fb7032 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/bin/recognize.py @@ -0,0 +1,237 @@ +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import copy +import logging +import os +import sys + +import torch +import yaml +from torch.utils.data import DataLoader +import sys +CUR_PATH = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(os.path.dirname(CUR_PATH))) + +from wenet.dataset.dataset import Dataset +from wenet.transformer.asr_model import init_asr_model +from wenet.utils.checkpoint import load_checkpoint +from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols +from wenet.utils.config import override_config + +def get_args(): + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--test_data', required=True, help='test data file') + parser.add_argument('--data_type', + default='raw', + choices=['raw', 'shard'], + help='train and cv data type') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--dict', required=True, help='dict file') + parser.add_argument("--non_lang_syms", + help="non-linguistic symbol file. One symbol per line.") + parser.add_argument('--beam_size', + type=int, + default=10, + help='beam size for search') + parser.add_argument('--penalty', + type=float, + default=0.0, + help='length penalty') + parser.add_argument('--result_file', required=True, help='asr result file') + parser.add_argument('--batch_size', + type=int, + default=16, + help='asr result file') + parser.add_argument('--mode', + choices=[ + 'attention', 'ctc_greedy_search', + 'ctc_prefix_beam_search', 'attention_rescoring' + ], + default='attention', + help='decoding mode') + parser.add_argument('--ctc_weight', + type=float, + default=0.0, + help='ctc weight for attention rescoring decode mode') + parser.add_argument('--decoding_chunk_size', + type=int, + default=-1, + help='''decoding chunk size, + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
+ 0: used for training, it's prohibited here''') + parser.add_argument('--num_decoding_left_chunks', + type=int, + default=-1, + help='number of left chunks for decoding') + parser.add_argument('--simulate_streaming', + action='store_true', + help='simulate streaming inference') + parser.add_argument('--reverse_weight', + type=float, + default=0.0, + help='''right to left weight for attention rescoring + decode mode''') + parser.add_argument('--bpe_model', + default=None, + type=str, + help='bpe model for english part') + parser.add_argument('--override_config', + action='append', + default=[], + help="override yaml config") + parser.add_argument('--connect_symbol', + default='', + type=str, + help='used to connect the output characters') + + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + if args.mode in ['ctc_prefix_beam_search', 'attention_rescoring' + ] and args.batch_size > 1: + logging.fatal( + 'decoding mode {} must be running with batch_size == 1'.format( + args.mode)) + sys.exit(1) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + if len(args.override_config) > 0: + configs = override_config(configs, args.override_config) + + symbol_table = read_symbol_table(args.dict) + test_conf = copy.deepcopy(configs['dataset_conf']) + + test_conf['filter_conf']['max_length'] = 102400 + test_conf['filter_conf']['min_length'] = 0 + test_conf['filter_conf']['token_max_length'] = 102400 + test_conf['filter_conf']['token_min_length'] = 0 + test_conf['filter_conf']['max_output_input_ratio'] = 102400 + test_conf['filter_conf']['min_output_input_ratio'] = 0 + test_conf['speed_perturb'] = False + test_conf['spec_aug'] = False + test_conf['spec_sub'] = False + test_conf['shuffle'] = False + test_conf['sort'] = False + if 'fbank_conf' in test_conf: + test_conf['fbank_conf']['dither'] = 0.0 + elif 'mfcc_conf' in test_conf: + test_conf['mfcc_conf']['dither'] = 0.0 + test_conf['batch_conf']['batch_type'] = "static" + test_conf['batch_conf']['batch_size'] = args.batch_size + non_lang_syms = read_non_lang_symbols(args.non_lang_syms) + + test_dataset = Dataset(args.data_type, + args.test_data, + symbol_table, + test_conf, + args.bpe_model, + non_lang_syms, + partition=False) + + test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + + # Init asr model from configs + model = init_asr_model(configs) + + # Load dict + char_dict = {v: k for k, v in symbol_table.items()} + eos = len(char_dict) - 1 + + load_checkpoint(model, args.checkpoint) + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + + model.eval() + with torch.no_grad(), open(args.result_file, 'w', encoding="utf-8") as fout: + for batch_idx, batch in enumerate(test_data_loader): + keys, feats, target, feats_lengths, target_lengths = batch + feats = feats.to(device) + target = target.to(device) + feats_lengths = feats_lengths.to(device) + target_lengths = target_lengths.to(device) + if args.mode == 'attention': + hyps, _ = model.recognize( + feats, + feats_lengths, + beam_size=args.beam_size, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + simulate_streaming=args.simulate_streaming) + hyps = [hyp.tolist() for hyp in hyps] + elif args.mode == 
'ctc_greedy_search': + hyps, _ = model.ctc_greedy_search( + feats, + feats_lengths, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + simulate_streaming=args.simulate_streaming) + # ctc_prefix_beam_search and attention_rescoring only return one + # result in List[int], change it to List[List[int]] for compatible + # with other batch decoding mode + elif args.mode == 'ctc_prefix_beam_search': + assert (feats.size(0) == 1) + hyp, _ = model.ctc_prefix_beam_search( + feats, + feats_lengths, + args.beam_size, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + simulate_streaming=args.simulate_streaming) + hyps = [hyp] + elif args.mode == 'attention_rescoring': + assert (feats.size(0) == 1) + hyp, _ = model.attention_rescoring( + feats, + feats_lengths, + args.beam_size, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + ctc_weight=args.ctc_weight, + simulate_streaming=args.simulate_streaming, + reverse_weight=args.reverse_weight) + hyps = [hyp] + for i, key in enumerate(keys): + content = [] + for w in hyps[i]: + if w == eos: + break + content.append(char_dict[w]) + logging.info('{} {}'.format(key, args.connect_symbol.join(content))) + fout.write('{} {}\n'.format(key, args.connect_symbol.join(content))) + + +if __name__ == '__main__': + main() diff --git a/speech/speech_recognition/transformer/pytorch/wenet/bin/recognize_deprecated.py b/speech/speech_recognition/transformer/pytorch/wenet/bin/recognize_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..26765bd74743b0a404d98011269925c034aedae6 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/bin/recognize_deprecated.py @@ -0,0 +1,202 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import copy +import logging +import os +import sys + +import torch +import yaml +from torch.utils.data import DataLoader + +from wenet.dataset.dataset_deprecated import AudioDataset, CollateFunc +from wenet.transformer.asr_model import init_asr_model +from wenet.utils.checkpoint import load_checkpoint + +if __name__ == '__main__': + print(""" +!!! This file is deprecated, and we are planning to remove it in +the future, please move to the new IO !!! 
+ """) + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--test_data', required=True, help='test data file') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--dict', required=True, help='dict file') + parser.add_argument('--beam_size', + type=int, + default=10, + help='beam size for search') + parser.add_argument('--penalty', + type=float, + default=0.0, + help='length penalty') + parser.add_argument('--result_file', required=True, help='asr result file') + parser.add_argument('--batch_size', + type=int, + default=16, + help='asr result file') + parser.add_argument('--mode', + choices=[ + 'attention', 'ctc_greedy_search', + 'ctc_prefix_beam_search', 'attention_rescoring' + ], + default='attention', + help='decoding mode') + parser.add_argument('--ctc_weight', + type=float, + default=0.0, + help='ctc weight for attention rescoring decode mode') + parser.add_argument('--decoding_chunk_size', + type=int, + default=-1, + help='''decoding chunk size, + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here''') + parser.add_argument('--num_decoding_left_chunks', + type=int, + default=-1, + help='number of left chunks for decoding') + parser.add_argument('--simulate_streaming', + action='store_true', + help='simulate streaming inference') + parser.add_argument('--reverse_weight', + type=float, + default=0.0, + help='''right to left weight for attention rescoring + decode mode''') + args = parser.parse_args() + print(args) + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + if args.mode in ['ctc_prefix_beam_search', 'attention_rescoring' + ] and args.batch_size > 1: + logging.fatal( + 'decoding mode {} must be running with batch_size == 1'.format( + args.mode)) + sys.exit(1) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + raw_wav = configs['raw_wav'] + # Init dataset and data loader + # Init dataset and data loader + test_collate_conf = copy.deepcopy(configs['collate_conf']) + test_collate_conf['spec_aug'] = False + test_collate_conf['spec_sub'] = False + test_collate_conf['feature_dither'] = False + test_collate_conf['speed_perturb'] = False + if raw_wav: + test_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0 + test_collate_conf['wav_distortion_conf']['wav_dither'] = 0.0 + test_collate_func = CollateFunc(**test_collate_conf, raw_wav=raw_wav) + dataset_conf = configs.get('dataset_conf', {}) + dataset_conf['batch_size'] = args.batch_size + dataset_conf['batch_type'] = 'static' + dataset_conf['sort'] = False + test_dataset = AudioDataset(args.test_data, + **dataset_conf, + raw_wav=raw_wav) + test_data_loader = DataLoader(test_dataset, + collate_fn=test_collate_func, + shuffle=False, + batch_size=1, + num_workers=0) + + # Init asr model from configs + model = init_asr_model(configs) + + # Load dict + char_dict = {} + with open(args.dict, 'r') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + char_dict[int(arr[1])] = arr[0] + eos = len(char_dict) - 1 + + load_checkpoint(model, args.checkpoint) + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = 
torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + + model.eval() + with torch.no_grad(), open(args.result_file, 'w') as fout: + for batch_idx, batch in enumerate(test_data_loader): + keys, feats, target, feats_lengths, target_lengths = batch + feats = feats.to(device) + target = target.to(device) + feats_lengths = feats_lengths.to(device) + target_lengths = target_lengths.to(device) + if args.mode == 'attention': + hyps, _ = model.recognize( + feats, + feats_lengths, + beam_size=args.beam_size, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + simulate_streaming=args.simulate_streaming) + hyps = [hyp.tolist() for hyp in hyps] + elif args.mode == 'ctc_greedy_search': + hyps, _ = model.ctc_greedy_search( + feats, + feats_lengths, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + simulate_streaming=args.simulate_streaming) + # ctc_prefix_beam_search and attention_rescoring only return one + # result in List[int], change it to List[List[int]] for compatible + # with other batch decoding mode + elif args.mode == 'ctc_prefix_beam_search': + assert (feats.size(0) == 1) + hyp, _ = model.ctc_prefix_beam_search( + feats, + feats_lengths, + args.beam_size, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + simulate_streaming=args.simulate_streaming) + hyps = [hyp] + elif args.mode == 'attention_rescoring': + assert (feats.size(0) == 1) + hyp, _ = model.attention_rescoring( + feats, + feats_lengths, + args.beam_size, + decoding_chunk_size=args.decoding_chunk_size, + num_decoding_left_chunks=args.num_decoding_left_chunks, + ctc_weight=args.ctc_weight, + simulate_streaming=args.simulate_streaming, + reverse_weight=args.reverse_weight) + hyps = [hyp] + for i, key in enumerate(keys): + content = '' + for w in hyps[i]: + if w == eos: + break + content += char_dict[w] + logging.info('{} {}'.format(key, content)) + fout.write('{} {}\n'.format(key, content)) diff --git a/speech/speech_recognition/transformer/pytorch/wenet/bin/recognize_onnx.py b/speech/speech_recognition/transformer/pytorch/wenet/bin/recognize_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..b9756c40c346d01504ff98255749ae9da79f4285 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/bin/recognize_onnx.py @@ -0,0 +1,277 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script is for testing exported onnx encoder and decoder. +The exported onnx models only support batch offline ASR inference. +It requires a python wrapped c++ ctc decoder. +Please install it by following: +https://github.com/Slyne/ctc_decoder.git +""" +from __future__ import print_function + +import argparse +import copy +import logging +import os +import sys + +import torch +import yaml +from torch.utils.data import DataLoader + +from wenet.dataset.dataset import Dataset +from wenet.utils.common import IGNORE_ID +from wenet.utils.file_utils import read_symbol_table +from wenet.utils.config import override_config + +import onnxruntime as rt +import multiprocessing +import numpy as np + +try: + from swig_decoders import map_batch, \ + ctc_beam_search_decoder_batch, \ + TrieVector, PathTrie +except ImportError: + print('Please install ctc decoders first by refering to\n' + + 'https://github.com/Slyne/ctc_decoder.git') + sys.exit(1) + + +def get_args(): + parser = argparse.ArgumentParser(description='recognize with your model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--test_data', required=True, help='test data file') + parser.add_argument('--data_type', + default='raw', + choices=['raw', 'shard'], + help='train and cv data type') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this rank, -1 for cpu') + parser.add_argument('--dict', required=True, help='dict file') + parser.add_argument('--encoder_onnx', required=True, help='encoder onnx file') + parser.add_argument('--decoder_onnx', required=True, help='decoder onnx file') + parser.add_argument('--result_file', required=True, help='asr result file') + parser.add_argument('--batch_size', + type=int, + default=32, + help='asr result file') + parser.add_argument('--mode', + choices=[ + 'ctc_greedy_search', 'ctc_prefix_beam_search', + 'attention_rescoring'], + default='attention_rescoring', + help='decoding mode') + parser.add_argument('--bpe_model', + default=None, + type=str, + help='bpe model for english part') + parser.add_argument('--override_config', + action='append', + default=[], + help="override yaml config") + parser.add_argument('--fp16', + action='store_true', + help='whether to export fp16 model, default false') + args = parser.parse_args() + print(args) + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + if len(args.override_config) > 0: + configs = override_config(configs, args.override_config) + + reverse_weight = configs["model_conf"].get("reverse_weight", 0.0) + symbol_table = read_symbol_table(args.dict) + test_conf = copy.deepcopy(configs['dataset_conf']) + test_conf['filter_conf']['max_length'] = 102400 + test_conf['filter_conf']['min_length'] = 0 + test_conf['filter_conf']['token_max_length'] = 102400 + test_conf['filter_conf']['token_min_length'] = 0 + 
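+    # These overrides (together with the ones that follow) effectively turn off
+    # length filtering and data augmentation, so every test utterance is decoded
+    # once in its original form.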
test_conf['filter_conf']['max_output_input_ratio'] = 102400 + test_conf['filter_conf']['min_output_input_ratio'] = 0 + test_conf['speed_perturb'] = False + test_conf['spec_aug'] = False + test_conf['shuffle'] = False + test_conf['sort'] = False + test_conf['fbank_conf']['dither'] = 0.0 + test_conf['batch_conf']['batch_type'] = "static" + test_conf['batch_conf']['batch_size'] = args.batch_size + + test_dataset = Dataset(args.data_type, + args.test_data, + symbol_table, + test_conf, + args.bpe_model, + partition=False) + + test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) + + # Init asr model from configs + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + if use_cuda: + EP_list = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + EP_list = ['CPUExecutionProvider'] + + encoder_ort_session = rt.InferenceSession(args.encoder_onnx, providers=EP_list) + decoder_ort_session = None + if args.mode == "attention_rescoring": + decoder_ort_session = rt.InferenceSession(args.decoder_onnx, providers=EP_list) + + # Load dict + vocabulary = [] + char_dict = {} + with open(args.dict, 'r') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + char_dict[int(arr[1])] = arr[0] + vocabulary.append(arr[0]) + eos = sos = len(char_dict) - 1 + with torch.no_grad(), open(args.result_file, 'w') as fout: + for _, batch in enumerate(test_data_loader): + keys, feats, _, feats_lengths, _ = batch + feats, feats_lengths = feats.numpy(), feats_lengths.numpy() + if args.fp16: + feats = feats.astype(np.float16) + ort_inputs = { + encoder_ort_session.get_inputs()[0].name: feats, + encoder_ort_session.get_inputs()[1].name: feats_lengths} + ort_outs = encoder_ort_session.run(None, ort_inputs) + encoder_out, encoder_out_lens, ctc_log_probs, \ + beam_log_probs, beam_log_probs_idx = ort_outs + beam_size = beam_log_probs.shape[-1] + batch_size = beam_log_probs.shape[0] + num_processes = min(multiprocessing.cpu_count(), batch_size) + if args.mode == 'ctc_greedy_search': + if beam_size != 1: + log_probs_idx = beam_log_probs_idx[:, :, 0] + batch_sents = [] + for idx, seq in enumerate(log_probs_idx): + batch_sents.append(seq[0:encoder_out_lens[idx]].tolist()) + hyps = map_batch(batch_sents, vocabulary, num_processes, + True, 0) + elif args.mode in ('ctc_prefix_beam_search', "attention_rescoring"): + batch_log_probs_seq_list = beam_log_probs.tolist() + batch_log_probs_idx_list = beam_log_probs_idx.tolist() + batch_len_list = encoder_out_lens.tolist() + batch_log_probs_seq = [] + batch_log_probs_ids = [] + batch_start = [] # only effective in streaming deployment + batch_root = TrieVector() + root_dict = {} + for i in range(len(batch_len_list)): + num_sent = batch_len_list[i] + batch_log_probs_seq.append( + batch_log_probs_seq_list[i][0:num_sent]) + batch_log_probs_ids.append( + batch_log_probs_idx_list[i][0:num_sent]) + root_dict[i] = PathTrie() + batch_root.append(root_dict[i]) + batch_start.append(True) + score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq, + batch_log_probs_ids, + batch_root, + batch_start, + beam_size, + num_processes, + 0, -2, 0.99999) + if args.mode == 'ctc_prefix_beam_search': + hyps = [] + for cand_hyps in score_hyps: + hyps.append(cand_hyps[0][1]) + hyps = map_batch(hyps, vocabulary, num_processes, False, 0) + if args.mode == 'attention_rescoring': + ctc_score, all_hyps = [], [] + max_len = 0 + for hyps in score_hyps: + cur_len = len(hyps) + if len(hyps) < beam_size: + hyps += (beam_size - cur_len) * [(-float("INF"), (0,))] + 
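+                        # Beams that returned fewer than beam_size prefixes were padded
+                        # with (-inf, (0,)) dummies above, so every utterance contributes
+                        # exactly beam_size candidates (and CTC scores) to the rescoring
+                        # decoder below.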
cur_ctc_score = [] + for hyp in hyps: + cur_ctc_score.append(hyp[0]) + all_hyps.append(list(hyp[1])) + if len(hyp[1]) > max_len: + max_len = len(hyp[1]) + ctc_score.append(cur_ctc_score) + if args.fp16: + ctc_score = np.array(ctc_score, dtype=np.float16) + else: + ctc_score = np.array(ctc_score, dtype=np.float32) + hyps_pad_sos_eos = np.ones( + (batch_size, beam_size, max_len + 2), dtype=np.int64) * IGNORE_ID + r_hyps_pad_sos_eos = np.ones( + (batch_size, beam_size, max_len + 2), dtype=np.int64) * IGNORE_ID + hyps_lens_sos = np.ones((batch_size, beam_size), dtype=np.int32) + k = 0 + for i in range(batch_size): + for j in range(beam_size): + cand = all_hyps[k] + l = len(cand) + 2 + hyps_pad_sos_eos[i][j][0:l] = [sos] + cand + [eos] + r_hyps_pad_sos_eos[i][j][0:l] = [sos] + cand[::-1] + [eos] + hyps_lens_sos[i][j] = len(cand) + 1 + k += 1 + decoder_ort_inputs = { + decoder_ort_session.get_inputs()[0].name: encoder_out, + decoder_ort_session.get_inputs()[1].name: encoder_out_lens, + decoder_ort_session.get_inputs()[2].name: hyps_pad_sos_eos, + decoder_ort_session.get_inputs()[3].name: hyps_lens_sos, + decoder_ort_session.get_inputs()[-1].name: ctc_score} + if reverse_weight > 0: + r_hyps_pad_sos_eos_name = decoder_ort_session.get_inputs()[4].name + decoder_ort_inputs[r_hyps_pad_sos_eos_name] = r_hyps_pad_sos_eos + best_index = decoder_ort_session.run(None, decoder_ort_inputs)[0] + best_sents = [] + k = 0 + for idx in best_index: + cur_best_sent = all_hyps[k: k + beam_size][idx] + best_sents.append(cur_best_sent) + k += beam_size + hyps = map_batch(best_sents, vocabulary, num_processes) + + for i, key in enumerate(keys): + content = hyps[i] + logging.info('{} {}'.format(key, content)) + fout.write('{} {}\n'.format(key, content)) + +if __name__ == '__main__': + main() diff --git a/speech/speech_recognition/transformer/pytorch/wenet/bin/train.py b/speech/speech_recognition/transformer/pytorch/wenet/bin/train.py new file mode 100644 index 0000000000000000000000000000000000000000..931e01207233773c00a62390dae06d6753b4cb54 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/bin/train.py @@ -0,0 +1,309 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
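The training driver that follows builds an Adam optimizer from `optim_conf` and wraps it in `WarmupLR` constructed from `scheduler_conf`. As a rough sketch of how such a warmup schedule typically behaves (this is the common inverse-square-root warmup used by Transformer recipes, not necessarily the exact formula inside `wenet.utils.scheduler.WarmupLR`; `base_lr` and `warmup_steps` below are illustrative values only):
```
def warmup_lr(step, base_lr=0.002, warmup_steps=25000):
    # Ramp up for the first warmup_steps updates, then decay as 1/sqrt(step).
    step = max(step, 1)
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5,
                                               step * warmup_steps ** -1.5)

# Illustrative values:
#   warmup_lr(1)      -> ~8e-08  (tiny right after the start)
#   warmup_lr(25000)  -> 0.002   (peak at the end of warmup)
#   warmup_lr(100000) -> 0.001   (keeps decaying as training continues)
```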
+ +from __future__ import print_function + +import argparse +import copy +import logging +import os + +import torch +import torch.distributed as dist +import torch.optim as optim +import yaml +from tensorboardX import SummaryWriter +from torch.utils.data import DataLoader + +import sys +CUR_PATH = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.dirname(os.path.dirname(CUR_PATH))) + +from wenet.dataset.dataset import Dataset +from wenet.transformer.asr_model import init_asr_model +from wenet.utils.checkpoint import (load_checkpoint, save_checkpoint, + load_trained_modules) +from wenet.utils.executor import Executor +from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols +from wenet.utils.scheduler import WarmupLR +from wenet.utils.config import override_config +import time + +def get_args(): + parser = argparse.ArgumentParser(description='training your network') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--data_type', + default='raw', + choices=['raw', 'shard'], + help='train and cv data type') + parser.add_argument('--train_data', required=True, help='train data file') + parser.add_argument('--cv_data', required=True, help='cv data file') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this local rank, -1 for cpu') + parser.add_argument('--model_dir', required=True, help='save model dir') + parser.add_argument('--checkpoint', help='checkpoint model') + parser.add_argument('--tensorboard_dir', + default='tensorboard', + help='tensorboard log dir') + parser.add_argument('--ddp.rank', + dest='rank', + default=0, + type=int, + help='global rank for distributed training') + parser.add_argument('--ddp.world_size', + dest='world_size', + default=-1, + type=int, + help='''number of total processes/gpus for + distributed training''') + parser.add_argument('--ddp.dist_backend', + dest='dist_backend', + default='nccl', + choices=['nccl', 'gloo'], + help='distributed backend') + parser.add_argument('--ddp.init_method', + dest='init_method', + default=None, + help='ddp init method') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--pin_memory', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--use_amp', + action='store_true', + default=False, + help='Use automatic mixed precision training') + parser.add_argument('--fp16_grad_sync', + action='store_true', + default=False, + help='Use fp16 gradient sync for ddp') + parser.add_argument('--cmvn', default=None, help='global cmvn file') + parser.add_argument('--symbol_table', + required=True, + help='model unit symbol table for training') + parser.add_argument("--non_lang_syms", + help="non-linguistic symbol file. 
One symbol per line.") + parser.add_argument('--prefetch', + default=100, + type=int, + help='prefetch number') + parser.add_argument('--bpe_model', + default=None, + type=str, + help='bpe model for english part') + parser.add_argument('--override_config', + action='append', + default=[], + help="override yaml config") + parser.add_argument("--enc_init", + default=None, + type=str, + help="Pre-trained model to initialize encoder") + parser.add_argument("--enc_init_mods", + default="encoder.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of encoder modules \ + to initialize ,separated by a comma") + + + args = parser.parse_args() + return args + + +def main(): + args = get_args() + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + + # Set random seed + torch.manual_seed(777) + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + if len(args.override_config) > 0: + configs = override_config(configs, args.override_config) + + distributed = args.world_size > 1 + if distributed: + host_addr_full = 'tcp://' + os.environ["MASTER_ADDR"] + ':' + os.environ["MASTER_PORT"] + + logging.info('training on multiple gpus, this gpu {}'.format(args.gpu)) + dist.init_process_group(args.dist_backend, + init_method=host_addr_full, + world_size=args.world_size, + rank=args.rank) + + symbol_table = read_symbol_table(args.symbol_table) + + train_conf = configs['dataset_conf'] + cv_conf = copy.deepcopy(train_conf) + cv_conf['speed_perturb'] = False + cv_conf['spec_aug'] = False + cv_conf['spec_sub'] = False + cv_conf['shuffle'] = False + non_lang_syms = read_non_lang_symbols(args.non_lang_syms) + + train_dataset = Dataset(args.data_type, args.train_data, symbol_table, + train_conf, args.bpe_model, non_lang_syms, True) + cv_dataset = Dataset(args.data_type, + args.cv_data, + symbol_table, + cv_conf, + args.bpe_model, + non_lang_syms, + partition=False) + + train_data_loader = DataLoader(train_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + cv_data_loader = DataLoader(cv_dataset, + batch_size=None, + pin_memory=args.pin_memory, + num_workers=args.num_workers, + prefetch_factor=args.prefetch) + + if 'fbank_conf' in configs['dataset_conf']: + input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins'] + else: + input_dim = configs['dataset_conf']['mfcc_conf']['num_mel_bins'] + vocab_size = len(symbol_table) + + # Save configs to model_dir/train.yaml for inference and export + configs['input_dim'] = input_dim + configs['output_dim'] = vocab_size + configs['cmvn_file'] = args.cmvn + configs['is_json_cmvn'] = True + if args.rank == 0: + saved_config_path = os.path.join(args.model_dir, 'train.yaml') + with open(saved_config_path, 'w') as fout: + data = yaml.dump(configs) + fout.write(data) + + # Init asr model from configs + model = init_asr_model(configs) + print(model) + num_params = sum(p.numel() for p in model.parameters()) + print('the number of model params: {}'.format(num_params)) + + # !!!IMPORTANT!!! 
+ # Try to export the model by script, if fails, we should refine + # the code to satisfy the script export requirements + # if args.rank == 0: + # script_model = torch.jit.script(model) + # script_model.save(os.path.join(args.model_dir, 'init.zip')) + executor = Executor() + # If specify checkpoint, load some info from checkpoint + if args.checkpoint is not None: + infos = load_checkpoint(model, args.checkpoint) + elif args.enc_init is not None: + logging.info('load pretrained encoders: {}'.format(args.enc_init)) + infos = load_trained_modules(model, args) + else: + infos = {} + start_epoch = infos.get('epoch', -1) + 1 + cv_loss = infos.get('cv_loss', 0.0) + step = infos.get('step', -1) + + num_epochs = configs.get('max_epoch', 100) + model_dir = args.model_dir + writer = None + if args.rank == 0: + os.makedirs(model_dir, exist_ok=True) + exp_id = os.path.basename(model_dir) + writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id)) + + if distributed: + assert (torch.cuda.is_available()) + # cuda model is required for nn.parallel.DistributedDataParallel + model.cuda() + model = torch.nn.parallel.DistributedDataParallel( + model, find_unused_parameters=True) + device = torch.device("cuda") + if args.fp16_grad_sync: + from torch.distributed.algorithms.ddp_comm_hooks import ( + default as comm_hooks, + ) + model.register_comm_hook( + state=None, hook=comm_hooks.fp16_compress_hook + ) + else: + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + + optimizer = optim.Adam(model.parameters(), **configs['optim_conf']) + scheduler = WarmupLR(optimizer, **configs['scheduler_conf']) + final_epoch = None + configs['rank'] = args.rank + configs['is_distributed'] = distributed + configs['use_amp'] = args.use_amp + if start_epoch == 0 and args.rank == 0: + save_model_path = os.path.join(model_dir, 'init.pt') + save_checkpoint(model, save_model_path) + + # Start training loop + executor.step = step + scheduler.set_step(step) + # used for pytorch amp mixed precision training + scaler = None + if args.use_amp: + scaler = torch.cuda.amp.GradScaler() + print("use amp \n ") + + for epoch in range(start_epoch, num_epochs): + train_dataset.set_epoch(epoch) + configs['epoch'] = epoch + lr = optimizer.param_groups[0]['lr'] + logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) + start_time = time.time() + executor.train(model, optimizer, scheduler, train_data_loader, device, + writer, configs, scaler) + train_time = time.time() - start_time + print("train time: ", train_time) + print("qps:", 120098 / train_time) + + total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device, + configs) + cv_loss = total_loss / num_seen_utts + + logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss)) + if args.rank == 0: + save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch)) + save_checkpoint( + model, save_model_path, { + 'epoch': epoch, + 'lr': lr, + 'cv_loss': cv_loss, + 'step': executor.step + }) + writer.add_scalar('epoch/cv_loss', cv_loss, epoch) + writer.add_scalar('epoch/lr', lr, epoch) + final_epoch = epoch + + if final_epoch is not None and args.rank == 0: + final_model_path = os.path.join(model_dir, 'final.pt') + os.symlink('{}.pt'.format(final_epoch), final_model_path) + writer.close() + + +if __name__ == '__main__': + main() diff --git a/speech/speech_recognition/transformer/pytorch/wenet/bin/train_deprecated.py 
b/speech/speech_recognition/transformer/pytorch/wenet/bin/train_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..236d819e3ab063307e90c390f423cc95e1521519 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/bin/train_deprecated.py @@ -0,0 +1,278 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import copy +import logging +import os + +import torch +import torch.distributed as dist +import torch.optim as optim +import yaml +from tensorboardX import SummaryWriter +from torch.utils.data import DataLoader + +from wenet.dataset.dataset_deprecated import AudioDataset, CollateFunc +from wenet.transformer.asr_model import init_asr_model +from wenet.utils.checkpoint import (load_checkpoint, save_checkpoint, + load_trained_modules) +from wenet.utils.executor import Executor +from wenet.utils.scheduler import WarmupLR + +if __name__ == '__main__': + print(""" +!!! This file is deprecated, and we are planning to remove it in +the future, please move to the new IO !!! + """) + parser = argparse.ArgumentParser(description='training your network') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--train_data', required=True, help='train data file') + parser.add_argument('--cv_data', required=True, help='cv data file') + parser.add_argument('--gpu', + type=int, + default=-1, + help='gpu id for this local rank, -1 for cpu') + parser.add_argument('--model_dir', required=True, help='save model dir') + parser.add_argument('--checkpoint', help='checkpoint model') + parser.add_argument('--tensorboard_dir', + default='tensorboard', + help='tensorboard log dir') + parser.add_argument('--ddp.rank', + dest='rank', + default=0, + type=int, + help='global rank for distributed training') + parser.add_argument('--ddp.world_size', + dest='world_size', + default=-1, + type=int, + help='''number of total processes/gpus for + distributed training''') + parser.add_argument('--ddp.dist_backend', + dest='dist_backend', + default='nccl', + choices=['nccl', 'gloo'], + help='distributed backend') + parser.add_argument('--ddp.init_method', + dest='init_method', + default=None, + help='ddp init method') + parser.add_argument('--num_workers', + default=0, + type=int, + help='num of subprocess workers for reading') + parser.add_argument('--pin_memory', + action='store_true', + default=False, + help='Use pinned memory buffers used for reading') + parser.add_argument('--use_amp', + action='store_true', + default=False, + help='Use automatic mixed precision training') + parser.add_argument('--cmvn', default=None, help='global cmvn file') + parser.add_argument("--enc_init", + default=None, + type=str, + help="Pre-trained model to initialize encoder") + parser.add_argument("--enc_init_mods", + default="encoder.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of encoder modules \ + to 
initialize ,separated by a comma") + + args = parser.parse_args() + + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') + os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) + # Set random seed + torch.manual_seed(777) + print(args) + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + distributed = args.world_size > 1 + + raw_wav = configs['raw_wav'] + + train_collate_func = CollateFunc(**configs['collate_conf'], + raw_wav=raw_wav) + + cv_collate_conf = copy.deepcopy(configs['collate_conf']) + # no augmenation on cv set + cv_collate_conf['spec_aug'] = False + cv_collate_conf['spec_sub'] = False + if raw_wav: + cv_collate_conf['feature_dither'] = 0.0 + cv_collate_conf['speed_perturb'] = False + cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0 + cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav) + + dataset_conf = configs.get('dataset_conf', {}) + train_dataset = AudioDataset(args.train_data, + **dataset_conf, + raw_wav=raw_wav) + cv_dataset = AudioDataset(args.cv_data, **dataset_conf, raw_wav=raw_wav) + + if distributed: + logging.info('training on multiple gpus, this gpu {}'.format(args.gpu)) + dist.init_process_group(args.dist_backend, + init_method=args.init_method, + world_size=args.world_size, + rank=args.rank) + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, shuffle=True) + cv_sampler = torch.utils.data.distributed.DistributedSampler( + cv_dataset, shuffle=False) + else: + train_sampler = None + cv_sampler = None + + train_data_loader = DataLoader(train_dataset, + collate_fn=train_collate_func, + sampler=train_sampler, + shuffle=(train_sampler is None), + pin_memory=args.pin_memory, + batch_size=1, + num_workers=args.num_workers) + cv_data_loader = DataLoader(cv_dataset, + collate_fn=cv_collate_func, + sampler=cv_sampler, + shuffle=False, + batch_size=1, + pin_memory=args.pin_memory, + num_workers=args.num_workers) + + if raw_wav: + input_dim = configs['collate_conf']['feature_extraction_conf'][ + 'mel_bins'] + else: + input_dim = train_dataset.input_dim + vocab_size = train_dataset.output_dim + + # Save configs to model_dir/train.yaml for inference and export + configs['input_dim'] = input_dim + configs['output_dim'] = vocab_size + configs['cmvn_file'] = args.cmvn + configs['is_json_cmvn'] = raw_wav + if args.rank == 0: + saved_config_path = os.path.join(args.model_dir, 'train.yaml') + with open(saved_config_path, 'w') as fout: + data = yaml.dump(configs) + fout.write(data) + + # Init asr model from configs + model = init_asr_model(configs) + print(model) + num_params = sum(p.numel() for p in model.parameters()) + print('the number of model params: {}'.format(num_params)) + + # !!!IMPORTANT!!! 
+ # Try to export the model by script, if fails, we should refine + # the code to satisfy the script export requirements + if args.rank == 0: + script_model = torch.jit.script(model) + script_model.save(os.path.join(args.model_dir, 'init.zip')) + executor = Executor() + # If specify checkpoint, load some info from checkpoint + if args.checkpoint is not None: + infos = load_checkpoint(model, args.checkpoint) + elif args.enc_init is not None: + logging.debug('load pretrained encoders: {}'.format(args.enc_init)) + infos = load_trained_modules(model, args) + else: + infos = {} + start_epoch = infos.get('epoch', -1) + 1 + cv_loss = infos.get('cv_loss', 0.0) + step = infos.get('step', -1) + + num_epochs = configs.get('max_epoch', 100) + model_dir = args.model_dir + writer = None + if args.rank == 0: + os.makedirs(model_dir, exist_ok=True) + exp_id = os.path.basename(model_dir) + writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id)) + + if distributed: + assert (torch.cuda.is_available()) + # cuda model is required for nn.parallel.DistributedDataParallel + model.cuda() + model = torch.nn.parallel.DistributedDataParallel( + model, find_unused_parameters=True) + device = torch.device("cuda") + else: + use_cuda = args.gpu >= 0 and torch.cuda.is_available() + device = torch.device('cuda' if use_cuda else 'cpu') + model = model.to(device) + + optimizer = optim.Adam(model.parameters(), **configs['optim_conf']) + scheduler = WarmupLR(optimizer, **configs['scheduler_conf']) + final_epoch = None + configs['rank'] = args.rank + configs['is_distributed'] = distributed + configs['use_amp'] = args.use_amp + if start_epoch == 0 and args.rank == 0: + save_model_path = os.path.join(model_dir, 'init.pt') + save_checkpoint(model, save_model_path) + + # Start training loop + executor.step = step + scheduler.set_step(step) + # used for pytorch amp mixed precision training + scaler = None + if args.use_amp: + scaler = torch.cuda.amp.GradScaler() + for epoch in range(start_epoch, num_epochs): + if distributed: + train_sampler.set_epoch(epoch) + lr = optimizer.param_groups[0]['lr'] + logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr)) + executor.train(model, optimizer, scheduler, train_data_loader, device, + writer, configs, scaler) + total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device, + configs) + if args.world_size > 1: + # all_reduce expected a sequence parameter, so we use [num_seen_utts]. + num_seen_utts = torch.Tensor([num_seen_utts]).to(device) + # the default operator in all_reduce function is sum. 
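+                # Summing [num_seen_utts] and [total_loss] across all ranks gives the
+                # global utterance count and loss, so cv_loss below is averaged over
+                # the whole CV set rather than a single GPU's shard.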
+ dist.all_reduce(num_seen_utts) + total_loss = torch.Tensor([total_loss]).to(device) + dist.all_reduce(total_loss) + cv_loss = total_loss[0] / num_seen_utts[0] + cv_loss = cv_loss.item() + else: + cv_loss = total_loss / num_seen_utts + + logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss)) + if args.rank == 0: + save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch)) + save_checkpoint( + model, save_model_path, { + 'epoch': epoch, + 'lr': lr, + 'cv_loss': cv_loss, + 'step': executor.step + }) + writer.add_scalar('epoch/cv_loss', cv_loss, epoch) + writer.add_scalar('epoch/lr', lr, epoch) + final_epoch = epoch + + if final_epoch is not None and args.rank == 0: + final_model_path = os.path.join(model_dir, 'final.pt') + os.symlink('{}.pt'.format(final_epoch), final_model_path) + writer.close() diff --git a/speech/speech_recognition/transformer/pytorch/wenet/dataset/dataset.py b/speech/speech_recognition/transformer/pytorch/wenet/dataset/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fc50c508aed3432494561c1ecf6b0b6b74bb30e8 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/dataset/dataset.py @@ -0,0 +1,189 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset + +import wenet.dataset.processor as processor +from wenet.utils.file_utils import read_lists + + +class Processor(IterableDataset): + def __init__(self, source, f, *args, **kw): + assert callable(f) + self.source = source + self.f = f + self.args = args + self.kw = kw + + def set_epoch(self, epoch): + self.source.set_epoch(epoch) + + def __iter__(self): + """ Return an iterator over the source dataset processed by the + given processor. 
+ """ + assert self.source is not None + assert callable(self.f) + return self.f(iter(self.source), *self.args, **self.kw) + + def apply(self, f): + assert callable(f) + return Processor(self, f, *self.args, **self.kw) + + +class DistributedSampler: + def __init__(self, shuffle=True, partition=True): + self.epoch = -1 + self.update() + self.shuffle = shuffle + self.partition = partition + + def update(self): + assert dist.is_available() + if dist.is_initialized(): + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + else: + self.rank = 0 + self.world_size = 1 + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + self.worker_id = 0 + self.num_workers = 1 + else: + self.worker_id = worker_info.id + self.num_workers = worker_info.num_workers + return dict(rank=self.rank, + world_size=self.world_size, + worker_id=self.worker_id, + num_workers=self.num_workers) + + def set_epoch(self, epoch): + self.epoch = epoch + + def sample(self, data): + """ Sample data according to rank/world_size/num_workers + + Args: + data(List): input data list + + Returns: + List: data list after sample + """ + data = list(range(len(data))) + # TODO(Binbin Zhang): fix this + # We can not handle uneven data for CV on DDP, so we don't + # sample data by rank, that means every GPU gets the same + # and all the CV data + if self.partition: + if self.shuffle: + random.Random(self.epoch).shuffle(data) + data = data[self.rank::self.world_size] + data = data[self.worker_id::self.num_workers] + return data + + +class DataList(IterableDataset): + def __init__(self, lists, shuffle=True, partition=True): + self.lists = lists + self.sampler = DistributedSampler(shuffle, partition) + + def set_epoch(self, epoch): + self.sampler.set_epoch(epoch) + + def __iter__(self): + sampler_info = self.sampler.update() + indexes = self.sampler.sample(self.lists) + for index in indexes: + # yield dict(src=src) + data = dict(src=self.lists[index]) + data.update(sampler_info) + yield data + + +def Dataset(data_type, + data_list_file, + symbol_table, + conf, + bpe_model=None, + non_lang_syms=None, + partition=True): + """ Construct dataset from arguments + + We have two shuffle stage in the Dataset. The first is global + shuffle at shards tar/raw file level. The second is global shuffle + at training samples level. 
+ + Args: + data_type(str): raw/shard + bpe_model(str): model for english bpe part + partition(bool): whether to do data partition in terms of rank + """ + assert data_type in ['raw', 'shard'] + lists = read_lists(data_list_file) + shuffle = conf.get('shuffle', True) + dataset = DataList(lists, shuffle=shuffle, partition=partition) + if data_type == 'shard': + dataset = Processor(dataset, processor.url_opener) + dataset = Processor(dataset, processor.tar_file_and_group) + else: + dataset = Processor(dataset, processor.parse_raw) + + dataset = Processor(dataset, processor.tokenize, symbol_table, bpe_model, + non_lang_syms, conf.get('split_with_space', False)) + filter_conf = conf.get('filter_conf', {}) + dataset = Processor(dataset, processor.filter, **filter_conf) + + resample_conf = conf.get('resample_conf', {}) + dataset = Processor(dataset, processor.resample, **resample_conf) + + speed_perturb = conf.get('speed_perturb', False) + if speed_perturb: + dataset = Processor(dataset, processor.speed_perturb) + + feats_type = conf.get('feats_type', 'fbank') + assert feats_type in ['fbank', 'mfcc'] + if feats_type == 'fbank': + fbank_conf = conf.get('fbank_conf', {}) + dataset = Processor(dataset, processor.compute_fbank, **fbank_conf) + elif feats_type == 'mfcc': + mfcc_conf = conf.get('mfcc_conf', {}) + dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf) + + spec_aug = conf.get('spec_aug', True) + spec_sub = conf.get('spec_sub', False) + if spec_aug: + spec_aug_conf = conf.get('spec_aug_conf', {}) + dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf) + if spec_sub: + spec_sub_conf = conf.get('spec_sub_conf', {}) + dataset = Processor(dataset, processor.spec_sub, **spec_sub_conf) + + if shuffle: + shuffle_conf = conf.get('shuffle_conf', {}) + dataset = Processor(dataset, processor.shuffle, **shuffle_conf) + + sort = conf.get('sort', True) + if sort: + sort_conf = conf.get('sort_conf', {}) + dataset = Processor(dataset, processor.sort, **sort_conf) + + batch_conf = conf.get('batch_conf', {}) + dataset = Processor(dataset, processor.batch, **batch_conf) + dataset = Processor(dataset, processor.padding) + return dataset diff --git a/speech/speech_recognition/transformer/pytorch/wenet/dataset/dataset_deprecated.py b/speech/speech_recognition/transformer/pytorch/wenet/dataset/dataset_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..10c5065ce6554ab2cd0b705f052d5127b546fe99 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/dataset/dataset_deprecated.py @@ -0,0 +1,533 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Chao Yang) +# Copyright (c) 2021 Jinsong Pan +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
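The `Dataset` factory above builds its pipeline by repeatedly wrapping the previous stage in a `Processor`, so each stage is a generator applied lazily to the iterator of the stage before it. A minimal sketch of the same pattern with toy stages (`keep_even` and `double` are made-up illustrations, not WeNet processors):
```
from torch.utils.data import IterableDataset

class Processor(IterableDataset):
    # Same idea as wenet.dataset.dataset.Processor: lazily apply f to the
    # iterator produced by the wrapped source.
    def __init__(self, source, f, *args, **kw):
        self.source, self.f, self.args, self.kw = source, f, args, kw

    def __iter__(self):
        return self.f(iter(self.source), *self.args, **self.kw)

class ListSource(IterableDataset):
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)

def keep_even(data):            # toy filter stage
    for x in data:
        if x % 2 == 0:
            yield x

def double(data, factor=2):     # toy mapping stage with an extra argument
    for x in data:
        yield factor * x

dataset = ListSource([1, 2, 3, 4])
dataset = Processor(dataset, keep_even)
dataset = Processor(dataset, double, factor=2)
print(list(dataset))            # [4, 8]
```
Because every stage is lazy, filtering, feature extraction, SpecAugment, shuffling and batching all stream sample by sample without materialising the corpus in memory.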
+ +import argparse +import codecs +import copy +import logging +import random + +import numpy as np +import torch +import torchaudio +import torchaudio.compliance.kaldi as kaldi +import torchaudio.sox_effects as sox_effects +import yaml +from PIL import Image +from PIL.Image import BICUBIC +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import Dataset, DataLoader + +import wenet.dataset.kaldi_io as kaldi_io +from wenet.dataset.wav_distortion import distort_wav_conf +from wenet.utils.common import IGNORE_ID + +torchaudio.set_audio_backend("sox_io") + + +def _spec_augmentation(x, + warp_for_time=False, + num_t_mask=2, + num_f_mask=2, + max_t=50, + max_f=10, + max_w=80): + """ Deep copy x and do spec augmentation then return it + + Args: + x: input feature, T * F 2D + num_t_mask: number of time mask to apply + num_f_mask: number of freq mask to apply + max_t: max width of time mask + max_f: max width of freq mask + max_w: max width of time warp + + Returns: + augmented feature + """ + y = np.copy(x) + max_frames = y.shape[0] + max_freq = y.shape[1] + + # time warp + if warp_for_time and max_frames > max_w * 2: + center = random.randrange(max_w, max_frames - max_w) + warped = random.randrange(center - max_w, center + max_w) + 1 + + left = Image.fromarray(x[:center]).resize((max_freq, warped), BICUBIC) + right = Image.fromarray(x[center:]).resize( + (max_freq, max_frames - warped), BICUBIC) + y = np.concatenate((left, right), 0) + # time mask + for i in range(num_t_mask): + start = random.randint(0, max_frames - 1) + length = random.randint(1, max_t) + end = min(max_frames, start + length) + y[start:end, :] = 0 + # freq mask + for i in range(num_f_mask): + start = random.randint(0, max_freq - 1) + length = random.randint(1, max_f) + end = min(max_freq, start + length) + y[:, start:end] = 0 + return y + + +def _spec_substitute(x, max_t=20, num_t_sub=3): + """ Deep copy x and do spec substitute then return it + + Args: + x: input feature, T * F 2D + max_t: max width of time substitute + num_t_sub: number of time substitute to apply + + Returns: + augmented feature + """ + y = np.copy(x) + max_frames = y.shape[0] + for i in range(num_t_sub): + start = random.randint(0, max_frames - 1) + length = random.randint(1, max_t) + end = min(max_frames, start + length) + # only substitute the earlier time chosen randomly for current time + pos = random.randint(0, start) + y[start:end, :] = y[start - pos:end - pos, :] + return y + + +def _waveform_distortion(waveform, distortion_methods_conf): + """ Apply distortion on waveform + + This distortion will not change the length of the waveform. + + Args: + waveform: numpy float tensor, (length,) + distortion_methods_conf: a list of config for ditortion method. + a method will be randomly selected by 'method_rate' and + apply on the waveform. + + Returns: + distorted waveform. 
+ """ + r = random.uniform(0, 1) + acc = 0.0 + for distortion_method in distortion_methods_conf: + method_rate = distortion_method['method_rate'] + acc += method_rate + if r < acc: + distortion_type = distortion_method['name'] + distortion_conf = distortion_method['params'] + point_rate = distortion_method['point_rate'] + return distort_wav_conf(waveform, distortion_type, distortion_conf, + point_rate) + return waveform + + +# add speed perturb when loading wav +# return augmented, sr +def _load_wav_with_speed(wav_file, speed): + """ Load the wave from file and apply speed perpturbation + + Args: + wav_file: input feature, T * F 2D + + Returns: + augmented feature + """ + if speed == 1.0: + wav, sr = torchaudio.load(wav_file) + else: + sample_rate = torchaudio.backend.sox_io_backend.info( + wav_file).sample_rate + # get torchaudio version + ta_no = torchaudio.__version__.split(".") + ta_version = 100 * int(ta_no[0]) + 10 * int(ta_no[1]) + + if ta_version < 80: + # Note: deprecated in torchaudio>=0.8.0 + E = sox_effects.SoxEffectsChain() + E.append_effect_to_chain('speed', speed) + E.append_effect_to_chain("rate", sample_rate) + E.set_input_file(wav_file) + wav, sr = E.sox_build_flow_effects() + else: + # Note: enable in torchaudio>=0.8.0 + wav, sr = sox_effects.apply_effects_file( + wav_file, + [['speed', str(speed)], ['rate', str(sample_rate)]]) + + return wav, sr + + +def _extract_feature(batch, speed_perturb, wav_distortion_conf, + feature_extraction_conf): + """ Extract acoustic fbank feature from origin waveform. + + Speed perturbation and wave amplitude distortion is optional. + + Args: + batch: a list of tuple (wav id , wave path). + speed_perturb: bool, whether or not to use speed pertubation. + wav_distortion_conf: a dict , the config of wave amplitude distortion. + feature_extraction_conf:a dict , the config of fbank extraction. 
+ + Returns: + (keys, feats, labels) + """ + keys = [] + feats = [] + lengths = [] + wav_dither = wav_distortion_conf['wav_dither'] + wav_distortion_rate = wav_distortion_conf['wav_distortion_rate'] + distortion_methods_conf = wav_distortion_conf['distortion_methods'] + if speed_perturb: + speeds = [1.0, 1.1, 0.9] + weights = [1, 1, 1] + speed = random.choices(speeds, weights, k=1)[0] + # speed = random.choice(speeds) + for i, x in enumerate(batch): + try: + wav = x[1] + value = wav.strip().split(",") + # 1 for general wav.scp, 3 for segmented wav.scp + assert len(value) == 1 or len(value) == 3 + wav_path = value[0] + sample_rate = torchaudio.backend.sox_io_backend.info( + wav_path).sample_rate + if 'resample' in feature_extraction_conf: + resample_rate = feature_extraction_conf['resample'] + else: + resample_rate = sample_rate + if speed_perturb: + if len(value) == 3: + logging.error( + "speed perturb does not support segmented wav.scp now") + assert len(value) == 1 + waveform, sample_rate = _load_wav_with_speed(wav_path, speed) + else: + # value length 3 means using segmented wav.scp + # incluede .wav, start time, end time + if len(value) == 3: + start_frame = int(float(value[1]) * sample_rate) + end_frame = int(float(value[2]) * sample_rate) + waveform, sample_rate = torchaudio.backend.sox_io_backend.load( + filepath=wav_path, + num_frames=end_frame - start_frame, + frame_offset=start_frame) + else: + waveform, sample_rate = torchaudio.load(wav_path) + waveform = waveform * (1 << 15) + if resample_rate != sample_rate: + waveform = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + + if wav_distortion_rate > 0.0: + r = random.uniform(0, 1) + if r < wav_distortion_rate: + waveform = waveform.detach().numpy() + waveform = _waveform_distortion(waveform, + distortion_methods_conf) + waveform = torch.from_numpy(waveform) + mat = kaldi.fbank( + waveform, + num_mel_bins=feature_extraction_conf['mel_bins'], + frame_length=feature_extraction_conf['frame_length'], + frame_shift=feature_extraction_conf['frame_shift'], + dither=wav_dither, + energy_floor=0.0, + sample_frequency=resample_rate) + mat = mat.detach().numpy() + feats.append(mat) + keys.append(x[0]) + lengths.append(mat.shape[0]) + except (Exception) as e: + print(e) + logging.warn('read utterance {} error'.format(x[0])) + pass + # Sort it because sorting is required in pack/pad operation + order = np.argsort(lengths)[::-1] + sorted_keys = [keys[i] for i in order] + sorted_feats = [feats[i] for i in order] + labels = [x[2].split() for x in batch] + labels = [np.fromiter(map(int, x), dtype=np.int32) for x in labels] + sorted_labels = [labels[i] for i in order] + return sorted_keys, sorted_feats, sorted_labels + + +def _load_feature(batch): + """ Load acoustic feature from files. + + The features have been prepared in previous step, usualy by Kaldi. + + Args: + batch: a list of tuple (wav id , feature ark path). 
+ + Returns: + (keys, feats, labels) + """ + keys = [] + feats = [] + lengths = [] + for i, x in enumerate(batch): + try: + mat = kaldi_io.read_mat(x[1]) + feats.append(mat) + keys.append(x[0]) + lengths.append(mat.shape[0]) + except (Exception): + # logging.warn('read utterance {} error'.format(x[0])) + pass + # Sort it because sorting is required in pack/pad operation + order = np.argsort(lengths)[::-1] + sorted_keys = [keys[i] for i in order] + sorted_feats = [feats[i] for i in order] + labels = [x[2].split() for x in batch] + labels = [np.fromiter(map(int, x), dtype=np.int32) for x in labels] + sorted_labels = [labels[i] for i in order] + return sorted_keys, sorted_feats, sorted_labels + + +class CollateFunc(object): + """ Collate function for AudioDataset + """ + def __init__( + self, + feature_dither=0.0, + speed_perturb=False, + spec_aug=False, + spec_aug_conf=None, + spec_sub=False, + spec_sub_conf=None, + raw_wav=True, + feature_extraction_conf=None, + wav_distortion_conf=None, + ): + """ + Args: + raw_wav: + True if input is raw wav and feature extraction is needed. + False if input is extracted feature + """ + self.wav_distortion_conf = wav_distortion_conf + self.feature_extraction_conf = feature_extraction_conf + self.spec_aug = spec_aug + self.feature_dither = feature_dither + self.speed_perturb = speed_perturb + self.raw_wav = raw_wav + self.spec_aug_conf = spec_aug_conf + self.spec_sub = spec_sub + self.spec_sub_conf = spec_sub_conf + + def __call__(self, batch): + assert (len(batch) == 1) + if self.raw_wav: + keys, xs, ys = _extract_feature(batch[0], self.speed_perturb, + self.wav_distortion_conf, + self.feature_extraction_conf) + + else: + keys, xs, ys = _load_feature(batch[0]) + + train_flag = True + if ys is None: + train_flag = False + + # optional feature dither d ~ (-a, a) on fbank feature + # a ~ (0, 0.5) + if self.feature_dither != 0.0: + a = random.uniform(0, self.feature_dither) + xs = [x + (np.random.random_sample(x.shape) - 0.5) * a for x in xs] + + # optinoal spec substitute + if self.spec_sub: + xs = [_spec_substitute(x, **self.spec_sub_conf) for x in xs] + + # optinoal spec augmentation + if self.spec_aug: + xs = [_spec_augmentation(x, **self.spec_aug_conf) for x in xs] + + # padding + xs_lengths = torch.from_numpy( + np.array([x.shape[0] for x in xs], dtype=np.int32)) + + # pad_sequence will FAIL in case xs is empty + if len(xs) > 0: + xs_pad = pad_sequence([torch.from_numpy(x).float() for x in xs], + True, 0) + else: + xs_pad = torch.Tensor(xs) + if train_flag: + ys_lengths = torch.from_numpy( + np.array([y.shape[0] for y in ys], dtype=np.int32)) + if len(ys) > 0: + ys_pad = pad_sequence([torch.from_numpy(y).int() for y in ys], + True, IGNORE_ID) + else: + ys_pad = torch.Tensor(ys) + else: + ys_pad = None + ys_lengths = None + return keys, xs_pad, ys_pad, xs_lengths, ys_lengths + + +class AudioDataset(Dataset): + def __init__(self, + data_file, + max_length=10240, + min_length=0, + token_max_length=200, + token_min_length=1, + batch_type='static', + batch_size=1, + max_frames_in_batch=0, + sort=True, + raw_wav=True): + """Dataset for loading audio data. 
+ + Attributes:: + data_file: input data file + Plain text data file, each line contains following 7 fields, + which is split by '\t': + utt:utt1 + feat:tmp/data/file1.wav or feat:tmp/data/fbank.ark:30 + feat_shape: 4.95(in seconds) or feat_shape:495,80(495 is in frames) + text:i love you + token: i l o v e y o u + tokenid: int id of this token + token_shape: M,N # M is the number of token, N is vocab size + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than token_max_length, + especially when use char unit for english modeling + token_min_length: drop utterance which is less than token_max_length + batch_type: static or dynamic, see max_frames_in_batch(dynamic) + batch_size: number of utterances in a batch, + it's for static batch size. + max_frames_in_batch: max feature frames in a batch, + when batch_type is dynamic, it's for dynamic batch size. + Then batch_size is ignored, we will keep filling the + batch until the total frames in batch up to max_frames_in_batch. + sort: whether to sort all data, so the utterance with the same + length could be filled in a same batch. + raw_wav: use raw wave or extracted featute. + if raw wave is used, dynamic waveform-level augmentation could be used + and the feature is extracted by torchaudio. + if extracted featute(e.g. by kaldi) is used, only feature-level + augmentation such as specaug could be used. + """ + assert batch_type in ['static', 'dynamic'] + data = [] + + # Open in utf8 mode since meet encoding problem + with codecs.open(data_file, 'r', encoding='utf-8') as f: + for line in f: + arr = line.strip().split('\t') + if len(arr) != 7: + continue + key = arr[0].split(':')[1] + tokenid = arr[5].split(':')[1] + output_dim = int(arr[6].split(':')[1].split(',')[1]) + if raw_wav: + wav_path = ':'.join(arr[1].split(':')[1:]) + duration = int(float(arr[2].split(':')[1]) * 1000 / 10) + data.append((key, wav_path, duration, tokenid)) + else: + feat_ark = ':'.join(arr[1].split(':')[1:]) + feat_info = arr[2].split(':')[1].split(',') + feat_dim = int(feat_info[1].strip()) + num_frames = int(feat_info[0].strip()) + data.append((key, feat_ark, num_frames, tokenid)) + self.input_dim = feat_dim + self.output_dim = output_dim + if sort: + data = sorted(data, key=lambda x: x[2]) + valid_data = [] + for i in range(len(data)): + length = data[i][2] + token_length = len(data[i][3].split()) + # remove too lang or too short utt for both input and output + # to prevent from out of memory + if length > max_length or length < min_length: + # logging.warn('ignore utterance {} feature {}'.format( + # data[i][0], length)) + pass + elif token_length > token_max_length or token_length < token_min_length: + pass + else: + valid_data.append(data[i]) + data = valid_data + self.minibatch = [] + num_data = len(data) + # Dynamic batch size + if batch_type == 'dynamic': + assert (max_frames_in_batch > 0) + self.minibatch.append([]) + num_frames_in_batch = 0 + for i in range(num_data): + length = data[i][2] + num_frames_in_batch += length + if num_frames_in_batch > max_frames_in_batch: + self.minibatch.append([]) + num_frames_in_batch = length + self.minibatch[-1].append((data[i][0], data[i][1], data[i][3])) + # Static batch size + else: + cur = 0 + while cur < num_data: + end = min(cur + batch_size, num_data) + item = [] + for i in range(cur, end): + item.append((data[i][0], data[i][1], data[i][3])) + self.minibatch.append(item) + cur = end + + 
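+        # Batching recap: with batch_type 'static' every minibatch holds
+        # batch_size utterances regardless of their length, while 'dynamic'
+        # keeps appending utterances until the summed frame count would exceed
+        # max_frames_in_batch. For example, with max_frames_in_batch=1000 and
+        # sorted lengths [300, 400, 450, 500], the dynamic minibatches are
+        # [300, 400] and [450, 500].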
def __len__(self): + return len(self.minibatch) + + def __getitem__(self, idx): + return self.minibatch[idx] + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('type', help='config file') + parser.add_argument('config_file', help='config file') + parser.add_argument('data_file', help='input data file') + args = parser.parse_args() + + with open(args.config_file, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + + # Init dataset and data loader + collate_conf = copy.copy(configs['collate_conf']) + if args.type == 'raw_wav': + raw_wav = True + else: + raw_wav = False + collate_func = CollateFunc(**collate_conf, raw_wav=raw_wav) + dataset_conf = configs.get('dataset_conf', {}) + dataset = AudioDataset(args.data_file, **dataset_conf, raw_wav=raw_wav) + + data_loader = DataLoader(dataset, + batch_size=1, + shuffle=True, + sampler=None, + num_workers=0, + collate_fn=collate_func) + + for i, batch in enumerate(data_loader): + print(i) + # print(batch[1].shape) diff --git a/speech/speech_recognition/transformer/pytorch/wenet/dataset/kaldi_io.py b/speech/speech_recognition/transformer/pytorch/wenet/dataset/kaldi_io.py new file mode 100644 index 0000000000000000000000000000000000000000..c9bef293c93d882147bb5b738e1fc49a7a19a484 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/dataset/kaldi_io.py @@ -0,0 +1,666 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright 2014-2016 Brno University of Technology (author: Karel Vesely) +# Licensed under the Apache License, Version 2.0 (the "License") + +import numpy as np +import sys, os, re, gzip, struct + +################################################# +# Adding kaldi tools to shell path, + +# Select kaldi, +if not 'KALDI_ROOT' in os.environ: + # Default! To change run python with 'export KALDI_ROOT=/some_dir python' + os.environ['KALDI_ROOT']='/mnt/matylda5/iveselyk/Tools/kaldi-trunk' + +# Add kaldi tools to path, +os.environ['PATH'] = os.popen('echo $KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin:$KALDI_ROOT/src/nnet3bin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/').readline().strip() + ':' + os.environ['PATH'] + + +################################################# +# Define all custom exceptions, +class UnsupportedDataType(Exception): pass +class UnknownVectorHeader(Exception): pass +class UnknownMatrixHeader(Exception): pass + +class BadSampleSize(Exception): pass +class BadInputFormat(Exception): pass + +class SubprocessFailed(Exception): pass + +################################################# +# Data-type independent helper functions, + +def open_or_fd(file, mode='rb'): + """ fd = open_or_fd(file) + Open file, gzipped file, pipe, or forward the file-descriptor. + Eventually seeks in the 'file' argument contains ':offset' suffix. + """ + offset = None + try: + # strip 'ark:' prefix from r{x,w}filename (optional), + if re.search('^(ark|scp)(,scp|,b|,t|,n?f|,n?p|,b?o|,n?s|,n?cs)*:', file): + (prefix,file) = file.split(':',1) + # separate offset from filename (optional), + if re.search(':[0-9]+$', file): + (file,offset) = file.rsplit(':',1) + # input pipe? + if file[-1] == '|': + fd = popen(file[:-1], 'rb') # custom, + # output pipe? 
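+        # (a leading '|' marks a Kaldi-style output pipe: the rest of the string is
+        # a shell command whose stdin we write into)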
+ elif file[0] == '|': + fd = popen(file[1:], 'wb') # custom, + # is it gzipped? + elif file.split('.')[-1] == 'gz': + fd = gzip.open(file, mode) + # a normal file... + else: + fd = open(file, mode) + except TypeError: + # 'file' is opened file descriptor, + fd = file + # Eventually seek to offset, + if offset != None: fd.seek(int(offset)) + return fd + +# based on '/usr/local/lib/python3.4/os.py' +def popen(cmd, mode="rb"): + if not isinstance(cmd, str): + raise TypeError("invalid cmd type (%s, expected string)" % type(cmd)) + + import subprocess, io, threading + + # cleanup function for subprocesses, + def cleanup(proc, cmd): + ret = proc.wait() + if ret > 0: + raise SubprocessFailed('cmd %s returned %d !' % (cmd,ret)) + return + + # text-mode, + if mode == "r": + proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) + threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, + return io.TextIOWrapper(proc.stdout) + elif mode == "w": + proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE) + threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, + return io.TextIOWrapper(proc.stdin) + # binary, + elif mode == "rb": + proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) + threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, + return proc.stdout + elif mode == "wb": + proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE) + threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, + return proc.stdin + # sanity, + else: + raise ValueError("invalid mode %s" % mode) + + +def read_key(fd): + """ [key] = read_key(fd) + Read the utterance-key from the opened ark/stream descriptor 'fd'. + """ + key = '' + while 1: + char = fd.read(1).decode("latin1") + if char == '' : break + if char == ' ' : break + key += char + key = key.strip() + if key == '': return None # end of file, + assert(re.match('^\S+$',key) != None) # check format (no whitespace!) + return key + + +################################################# +# Integer vectors (alignments, ...), + +def read_ali_ark(file_or_fd): + """ Alias to 'read_vec_int_ark()' """ + return read_vec_int_ark(file_or_fd) + +def read_vec_int_ark(file_or_fd): + """ generator(key,vec) = read_vec_int_ark(file_or_fd) + Create generator of (key,vector) tuples, which reads from the ark file/stream. + file_or_fd : ark, gzipped ark, pipe or opened file descriptor. + + Read ark to a 'dictionary': + d = { u:d for u,d in kaldi_io.read_vec_int_ark(file) } + """ + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + ali = read_vec_int(fd) + yield key, ali + key = read_key(fd) + finally: + if fd is not file_or_fd: fd.close() + +def read_vec_int_scp(file_or_fd): + """ generator(key,vec) = read_vec_int_scp(file_or_fd) + Returns generator of (key,vector) tuples, read according to kaldi scp. + file_or_fd : scp, gzipped scp, pipe or opened file descriptor. + + Iterate the scp: + for key,vec in kaldi_io.read_vec_int_scp(file): + ... 
+ + Read scp to a 'dictionary': + d = { key:vec for key,mat in kaldi_io.read_vec_int_scp(file) } + """ + fd = open_or_fd(file_or_fd) + try: + for line in fd: + (key,rxfile) = line.decode().split(' ') + vec = read_vec_int(rxfile) + yield key, vec + finally: + if fd is not file_or_fd : fd.close() + +def read_vec_int(file_or_fd): + """ [int-vec] = read_vec_int(file_or_fd) + Read kaldi integer vector, ascii or binary input, + """ + fd = open_or_fd(file_or_fd) + binary = fd.read(2).decode() + if binary == '\0B': # binary flag + assert(fd.read(1).decode() == '\4'); # int-size + vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim + # Elements from int32 vector are sored in tuples: (sizeof(int32), value), + vec = np.frombuffer(fd.read(vec_size*5), dtype=[('size','int8'),('value','int32')], count=vec_size) + assert(vec[0]['size'] == 4) # int32 size, + ans = vec[:]['value'] # values are in 2nd column, + else: # ascii, + arr = (binary + fd.readline().decode()).strip().split() + try: + arr.remove('['); arr.remove(']') # optionally + except ValueError: + pass + ans = np.array(arr, dtype=int) + if fd is not file_or_fd : fd.close() # cleanup + return ans + +# Writing, +def write_vec_int(file_or_fd, v, key=''): + """ write_vec_int(f, v, key='') + Write a binary kaldi integer vector to filename or stream. + Arguments: + file_or_fd : filename or opened file descriptor for writing, + v : the vector to be stored, + key (optional) : used for writing ark-file, the utterance-id gets written before the vector. + + Example of writing single vector: + kaldi_io.write_vec_int(filename, vec) + + Example of writing arkfile: + with open(ark_file,'w') as f: + for key,vec in dict.iteritems(): + kaldi_io.write_vec_flt(f, vec, key=key) + """ + fd = open_or_fd(file_or_fd, mode='wb') + if sys.version_info[0] == 3: assert(fd.mode == 'wb') + try: + if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), + fd.write('\0B'.encode()) # we write binary! + # dim, + fd.write('\4'.encode()) # int32 type, + fd.write(struct.pack(np.dtype('int32').char, v.shape[0])) + # data, + for i in range(len(v)): + fd.write('\4'.encode()) # int32 type, + fd.write(struct.pack(np.dtype('int32').char, v[i])) # binary, + finally: + if fd is not file_or_fd : fd.close() + + +################################################# +# Float vectors (confidences, ivectors, ...), + +# Reading, +def read_vec_flt_scp(file_or_fd): + """ generator(key,mat) = read_vec_flt_scp(file_or_fd) + Returns generator of (key,vector) tuples, read according to kaldi scp. + file_or_fd : scp, gzipped scp, pipe or opened file descriptor. + + Iterate the scp: + for key,vec in kaldi_io.read_vec_flt_scp(file): + ... + + Read scp to a 'dictionary': + d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) } + """ + fd = open_or_fd(file_or_fd) + try: + for line in fd: + (key,rxfile) = line.decode().split(' ') + vec = read_vec_flt(rxfile) + yield key, vec + finally: + if fd is not file_or_fd : fd.close() + +def read_vec_flt_ark(file_or_fd): + """ generator(key,vec) = read_vec_flt_ark(file_or_fd) + Create generator of (key,vector) tuples, reading from an ark file/stream. + file_or_fd : ark, gzipped ark, pipe or opened file descriptor. 
+ + Read ark to a 'dictionary': + d = { u:d for u,d in kaldi_io.read_vec_flt_ark(file) } + """ + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + ali = read_vec_flt(fd) + yield key, ali + key = read_key(fd) + finally: + if fd is not file_or_fd: fd.close() + +def read_vec_flt(file_or_fd): + """ [flt-vec] = read_vec_flt(file_or_fd) + Read kaldi float vector, ascii or binary input, + """ + fd = open_or_fd(file_or_fd) + binary = fd.read(2).decode() + if binary == '\0B': # binary flag + # Data type, + header = fd.read(3).decode() + if header == 'FV ': sample_size = 4 # floats + elif header == 'DV ': sample_size = 8 # doubles + else: raise UnknownVectorHeader("The header contained '%s'" % header) + assert(sample_size > 0) + # Dimension, + assert(fd.read(1).decode() == '\4'); # int-size + vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim + # Read whole vector, + buf = fd.read(vec_size * sample_size) + if sample_size == 4 : ans = np.frombuffer(buf, dtype='float32') + elif sample_size == 8 : ans = np.frombuffer(buf, dtype='float64') + else : raise BadSampleSize + return ans + else: # ascii, + arr = (binary + fd.readline().decode()).strip().split() + try: + arr.remove('['); arr.remove(']') # optionally + except ValueError: + pass + ans = np.array(arr, dtype=float) + if fd is not file_or_fd : fd.close() # cleanup + return ans + +# Writing, +def write_vec_flt(file_or_fd, v, key=''): + """ write_vec_flt(f, v, key='') + Write a binary kaldi vector to filename or stream. Supports 32bit and 64bit floats. + Arguments: + file_or_fd : filename or opened file descriptor for writing, + v : the vector to be stored, + key (optional) : used for writing ark-file, the utterance-id gets written before the vector. + + Example of writing single vector: + kaldi_io.write_vec_flt(filename, vec) + + Example of writing arkfile: + with open(ark_file,'w') as f: + for key,vec in dict.iteritems(): + kaldi_io.write_vec_flt(f, vec, key=key) + """ + fd = open_or_fd(file_or_fd, mode='wb') + if sys.version_info[0] == 3: assert(fd.mode == 'wb') + try: + if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), + fd.write('\0B'.encode()) # we write binary! + # Data-type, + if v.dtype == 'float32': fd.write('FV '.encode()) + elif v.dtype == 'float64': fd.write('DV '.encode()) + else: raise UnsupportedDataType("'%s', please use 'float32' or 'float64'" % v.dtype) + # Dim, + fd.write('\04'.encode()) + fd.write(struct.pack(np.dtype('uint32').char, v.shape[0])) # dim + # Data, + fd.write(v.tobytes()) + finally: + if fd is not file_or_fd : fd.close() + + +################################################# +# Float matrices (features, transformations, ...), + +# Reading, +def read_mat_scp(file_or_fd): + """ generator(key,mat) = read_mat_scp(file_or_fd) + Returns generator of (key,matrix) tuples, read according to kaldi scp. + file_or_fd : scp, gzipped scp, pipe or opened file descriptor. + + Iterate the scp: + for key,mat in kaldi_io.read_mat_scp(file): + ... + + Read scp to a 'dictionary': + d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) } + """ + fd = open_or_fd(file_or_fd) + try: + for line in fd: + (key,rxfile) = line.decode().split(' ') + mat = read_mat(rxfile) + yield key, mat + finally: + if fd is not file_or_fd : fd.close() + +def read_mat_ark(file_or_fd): + """ generator(key,mat) = read_mat_ark(file_or_fd) + Returns generator of (key,matrix) tuples, read from ark file/stream. + file_or_fd : scp, gzipped scp, pipe or opened file descriptor. 
+ + Iterate the ark: + for key,mat in kaldi_io.read_mat_ark(file): + ... + + Read ark to a 'dictionary': + d = { key:mat for key,mat in kaldi_io.read_mat_ark(file) } + """ + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + mat = read_mat(fd) + yield key, mat + key = read_key(fd) + finally: + if fd is not file_or_fd : fd.close() + +def read_mat(file_or_fd): + """ [mat] = read_mat(file_or_fd) + Reads single kaldi matrix, supports ascii and binary. + file_or_fd : file, gzipped file, pipe or opened file descriptor. + """ + fd = open_or_fd(file_or_fd) + try: + binary = fd.read(2).decode() + if binary == '\0B' : + mat = _read_mat_binary(fd) + else: + assert(binary == ' [') + mat = _read_mat_ascii(fd) + finally: + if fd is not file_or_fd: fd.close() + return mat + +def _read_mat_binary(fd): + # Data type + header = fd.read(3).decode() + # 'CM', 'CM2', 'CM3' are possible values, + if header.startswith('CM'): return _read_compressed_mat(fd, header) + elif header == 'FM ': sample_size = 4 # floats + elif header == 'DM ': sample_size = 8 # doubles + else: raise UnknownMatrixHeader("The header contained '%s'" % header) + assert(sample_size > 0) + # Dimensions + s1, rows, s2, cols = np.frombuffer(fd.read(10), dtype='int8,int32,int8,int32', count=1)[0] + # Read whole matrix + buf = fd.read(rows * cols * sample_size) + if sample_size == 4 : vec = np.frombuffer(buf, dtype='float32') + elif sample_size == 8 : vec = np.frombuffer(buf, dtype='float64') + else : raise BadSampleSize + mat = np.reshape(vec,(rows,cols)) + return mat + +def _read_mat_ascii(fd): + rows = [] + while 1: + line = fd.readline().decode() + if (len(line) == 0) : raise BadInputFormat # eof, should not happen! + if len(line.strip()) == 0 : continue # skip empty line + arr = line.strip().split() + if arr[-1] != ']': + rows.append(np.array(arr,dtype='float32')) # not last line + else: + rows.append(np.array(arr[:-1],dtype='float32')) # last line + mat = np.vstack(rows) + return mat + + +def _read_compressed_mat(fd, format): + """ Read a compressed matrix, + see: https://github.com/kaldi-asr/kaldi/blob/master/src/matrix/compressed-matrix.h + methods: CompressedMatrix::Read(...), CompressedMatrix::CopyToMat(...), + """ + assert(format == 'CM ') # The formats CM2, CM3 are not supported... + + # Format of header 'struct', + global_header = np.dtype([('minvalue','float32'),('range','float32'),('num_rows','int32'),('num_cols','int32')]) # member '.format' is not written, + per_col_header = np.dtype([('percentile_0','uint16'),('percentile_25','uint16'),('percentile_75','uint16'),('percentile_100','uint16')]) + + # Mapping for percentiles in col-headers, + def uint16_to_float(value, min, range): + return np.float32(min + range * 1.52590218966964e-05 * value) + + # Mapping for matrix elements, + def uint8_to_float_v2(vec, p0, p25, p75, p100): + # Split the vector by masks, + mask_0_64 = (vec <= 64); + mask_193_255 = (vec > 192); + mask_65_192 = (~(mask_0_64 | mask_193_255)); + # Sanity check (useful but slow...), + # assert(len(vec) == np.sum(np.hstack([mask_0_64,mask_65_192,mask_193_255]))) + # assert(len(vec) == np.sum(np.any([mask_0_64,mask_65_192,mask_193_255], axis=0))) + # Build the float vector, + ans = np.empty(len(vec), dtype='float32') + ans[mask_0_64] = p0 + (p25 - p0) / 64. * vec[mask_0_64] + ans[mask_65_192] = p25 + (p75 - p25) / 128. * (vec[mask_65_192] - 64) + ans[mask_193_255] = p75 + (p100 - p75) / 63. 
* (vec[mask_193_255] - 192) + return ans + + # Read global header, + globmin, globrange, rows, cols = np.frombuffer(fd.read(16), dtype=global_header, count=1)[0] + + # The data is structed as [Colheader, ... , Colheader, Data, Data , .... ] + # { cols }{ size } + col_headers = np.frombuffer(fd.read(cols*8), dtype=per_col_header, count=cols) + data = np.reshape(np.frombuffer(fd.read(cols*rows), dtype='uint8', count=cols*rows), newshape=(cols,rows)) # stored as col-major, + + mat = np.empty((cols,rows), dtype='float32') + for i, col_header in enumerate(col_headers): + col_header_flt = [ uint16_to_float(percentile, globmin, globrange) for percentile in col_header ] + mat[i] = uint8_to_float_v2(data[i], *col_header_flt) + + return mat.T # transpose! col-major -> row-major, + +def write_ark_scp(key, mat, ark_fout, scp_out): + mat_offset = write_mat(ark_fout, mat, key) + scp_line = '{}\t{}:{}'.format(key, ark_fout.name, mat_offset) + scp_out.write(scp_line) + scp_out.write('\n') + +# Writing, +def write_mat(file_or_fd, m, key=''): + """ write_mat(f, m, key='') + Write a binary kaldi matrix to filename or stream. Supports 32bit and 64bit floats. + Arguments: + file_or_fd : filename of opened file descriptor for writing, + m : the matrix to be stored, + key (optional) : used for writing ark-file, the utterance-id gets written before the matrix. + + Example of writing single matrix: + kaldi_io.write_mat(filename, mat) + + Example of writing arkfile: + with open(ark_file,'w') as f: + for key,mat in dict.iteritems(): + kaldi_io.write_mat(f, mat, key=key) + """ + mat_offset = 0 + fd = open_or_fd(file_or_fd, mode='wb') + if sys.version_info[0] == 3: assert(fd.mode == 'wb') + try: + if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), + mat_offset = fd.tell() + fd.write('\0B'.encode()) # we write binary! + # Data-type, + if m.dtype == 'float32': fd.write('FM '.encode()) + elif m.dtype == 'float64': fd.write('DM '.encode()) + else: raise UnsupportedDataType("'%s', please use 'float32' or 'float64'" % m.dtype) + # Dims, + fd.write('\04'.encode()) + fd.write(struct.pack(np.dtype('uint32').char, m.shape[0])) # rows + fd.write('\04'.encode()) + fd.write(struct.pack(np.dtype('uint32').char, m.shape[1])) # cols + # Data, + fd.write(m.tobytes()) + finally: + if fd is not file_or_fd : fd.close() + return mat_offset + +################################################# +# 'Posterior' kaldi type (posteriors, confusion network, nnet1 training targets, ...) +# Corresponds to: vector > > +# - outer vector: time axis +# - inner vector: records at the time +# - tuple: int = index, float = value +# + +def read_cnet_ark(file_or_fd): + """ Alias of function 'read_post_ark()', 'cnet' = confusion network """ + return read_post_ark(file_or_fd) + +def read_post_ark(file_or_fd): + """ generator(key,vec>) = read_post_ark(file) + Returns generator of (key,posterior) tuples, read from ark file. + file_or_fd : ark, gzipped ark, pipe or opened file descriptor. + + Iterate the ark: + for key,post in kaldi_io.read_post_ark(file): + ... + + Read ark to a 'dictionary': + d = { key:post for key,post in kaldi_io.read_post_ark(file) } + """ + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + post = read_post(fd) + yield key, post + key = read_key(fd) + finally: + if fd is not file_or_fd: fd.close() + +def read_post(file_or_fd): + """ [post] = read_post(file_or_fd) + Reads single kaldi 'Posterior' in binary format. 
+ + The 'Posterior' is C++ type 'vector > >', + the outer-vector is usually time axis, inner-vector are the records + at given time, and the tuple is composed of an 'index' (integer) + and a 'float-value'. The 'float-value' can represent a probability + or any other numeric value. + + Returns vector of vectors of tuples. + """ + fd = open_or_fd(file_or_fd) + ans=[] + binary = fd.read(2).decode(); assert(binary == '\0B'); # binary flag + assert(fd.read(1).decode() == '\4'); # int-size + outer_vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of frames (or bins) + + # Loop over 'outer-vector', + for i in range(outer_vec_size): + assert(fd.read(1).decode() == '\4'); # int-size + inner_vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of records for frame (or bin) + data = np.frombuffer(fd.read(inner_vec_size*10), dtype=[('size_idx','int8'),('idx','int32'),('size_post','int8'),('post','float32')], count=inner_vec_size) + assert(data[0]['size_idx'] == 4) + assert(data[0]['size_post'] == 4) + ans.append(data[['idx','post']].tolist()) + + if fd is not file_or_fd: fd.close() + return ans + + +################################################# +# Kaldi Confusion Network bin begin/end times, +# (kaldi stores CNs time info separately from the Posterior). +# + +def read_cntime_ark(file_or_fd): + """ generator(key,vec>) = read_cntime_ark(file_or_fd) + Returns generator of (key,cntime) tuples, read from ark file. + file_or_fd : file, gzipped file, pipe or opened file descriptor. + + Iterate the ark: + for key,time in kaldi_io.read_cntime_ark(file): + ... + + Read ark to a 'dictionary': + d = { key:time for key,time in kaldi_io.read_post_ark(file) } + """ + fd = open_or_fd(file_or_fd) + try: + key = read_key(fd) + while key: + cntime = read_cntime(fd) + yield key, cntime + key = read_key(fd) + finally: + if fd is not file_or_fd : fd.close() + +def read_cntime(file_or_fd): + """ [cntime] = read_cntime(file_or_fd) + Reads single kaldi 'Confusion Network time info', in binary format: + C++ type: vector >. + (begin/end times of bins at the confusion network). + + Binary layout is ' ...' + + file_or_fd : file, gzipped file, pipe or opened file descriptor. + + Returns vector of tuples. 
+ """ + fd = open_or_fd(file_or_fd) + binary = fd.read(2).decode(); assert(binary == '\0B'); # assuming it's binary + + assert(fd.read(1).decode() == '\4'); # int-size + vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of frames (or bins) + + data = np.frombuffer(fd.read(vec_size*10), dtype=[('size_beg','int8'),('t_beg','float32'),('size_end','int8'),('t_end','float32')], count=vec_size) + assert(data[0]['size_beg'] == 4) + assert(data[0]['size_end'] == 4) + ans = data[['t_beg','t_end']].tolist() # Return vector of tuples (t_beg,t_end), + + if fd is not file_or_fd : fd.close() + return ans + + +################################################# +# Segments related, +# + +# Segments as 'Bool vectors' can be handy, +# - for 'superposing' the segmentations, +# - for frame-selection in Speaker-ID experiments, +def read_segments_as_bool_vec(segments_file): + """ [ bool_vec ] = read_segments_as_bool_vec(segments_file) + using kaldi 'segments' file for 1 wav, format : ' ' + - t-beg, t-end is in seconds, + - assumed 100 frames/second, + """ + segs = np.loadtxt(segments_file, dtype='object,object,f,f', ndmin=1) + # Sanity checks, + assert(len(segs) > 0) # empty segmentation is an error, + assert(len(np.unique([rec[1] for rec in segs ])) == 1) # segments with only 1 wav-file, + # Convert time to frame-indexes, + start = np.rint([100 * rec[2] for rec in segs]).astype(int) + end = np.rint([100 * rec[3] for rec in segs]).astype(int) + # Taken from 'read_lab_to_bool_vec', htk.py, + frms = np.repeat(np.r_[np.tile([False,True], len(end)), False], + np.r_[np.c_[start - np.r_[0, end[:-1]], end-start].flat, 0]) + assert np.sum(end-start) == np.sum(frms) + return frms + diff --git a/speech/speech_recognition/transformer/pytorch/wenet/dataset/processor.py b/speech/speech_recognition/transformer/pytorch/wenet/dataset/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..00a3488fb9a39fec51a770279a50432e3d9c2b98 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/dataset/processor.py @@ -0,0 +1,619 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import json +import random +import re +import tarfile +from subprocess import PIPE, Popen +from urllib.parse import urlparse + +import torch +import torchaudio +import torchaudio.compliance.kaldi as kaldi +from torch.nn.utils.rnn import pad_sequence + +AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) + + +def url_opener(data): + """ Give url or local file, return file descriptor + Inplace operation. 
+ + Args: + data(Iterable[str]): url or local file list + + Returns: + Iterable[{src, stream}] + """ + for sample in data: + assert 'src' in sample + # TODO(Binbin Zhang): support HTTP + url = sample['src'] + try: + pr = urlparse(url) + # local file + if pr.scheme == '' or pr.scheme == 'file': + stream = open(url, 'rb') + # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP + else: + cmd = f'curl -s -L {url}' + process = Popen(cmd, shell=True, stdout=PIPE) + sample.update(process=process) + stream = process.stdout + sample.update(stream=stream) + yield sample + except Exception as ex: + logging.warning('Failed to open {}'.format(url)) + + +def tar_file_and_group(data): + """ Expand a stream of open tar files into a stream of tar file contents. + And groups the file with same prefix + + Args: + data: Iterable[{src, stream}] + + Returns: + Iterable[{key, wav, txt, sample_rate}] + """ + for sample in data: + assert 'stream' in sample + stream = tarfile.open(fileobj=sample['stream'], mode="r|*") + prev_prefix = None + example = {} + valid = True + for tarinfo in stream: + name = tarinfo.name + pos = name.rfind('.') + assert pos > 0 + prefix, postfix = name[:pos], name[pos + 1:] + if prev_prefix is not None and prefix != prev_prefix: + example['key'] = prev_prefix + if valid: + yield example + example = {} + valid = True + with stream.extractfile(tarinfo) as file_obj: + try: + if postfix == 'txt': + example['txt'] = file_obj.read().decode('utf8').strip() + elif postfix in AUDIO_FORMAT_SETS: + waveform, sample_rate = torchaudio.load(file_obj) + example['wav'] = waveform + example['sample_rate'] = sample_rate + else: + example[postfix] = file_obj.read() + except Exception as ex: + valid = False + logging.warning('error to parse {}'.format(name)) + prev_prefix = prefix + if prev_prefix is not None: + example['key'] = prev_prefix + yield example + stream.close() + if 'process' in sample: + sample['process'].communicate() + sample['stream'].close() + + +def parse_raw(data): + """ Parse key/wav/txt from json line + + Args: + data: Iterable[str], str is a json line has key/wav/txt + + Returns: + Iterable[{key, wav, txt, sample_rate}] + """ + for sample in data: + assert 'src' in sample + json_line = sample['src'] + obj = json.loads(json_line) + assert 'key' in obj + assert 'wav' in obj + assert 'txt' in obj + key = obj['key'] + wav_file = obj['wav'] + txt = obj['txt'] + try: + if 'start' in obj: + assert 'end' in obj + sample_rate = torchaudio.backend.sox_io_backend.info( + wav_file).sample_rate + start_frame = int(obj['start'] * sample_rate) + end_frame = int(obj['end'] * sample_rate) + waveform, _ = torchaudio.backend.sox_io_backend.load( + filepath=wav_file, + num_frames=end_frame - start_frame, + frame_offset=start_frame) + else: + waveform, sample_rate = torchaudio.load(wav_file) + example = dict(key=key, + txt=txt, + wav=waveform, + sample_rate=sample_rate) + yield example + except Exception as ex: + logging.warning('Failed to read {}'.format(wav_file)) + + +def filter(data, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1): + """ Filter sample according to feature and label length + Inplace operation. 
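+        For example, with the defaults an utterance longer than 10240 frames
+        (about 102 seconds at 10 ms per frame) or shorter than 10 frames is
+        dropped, as is any label longer than 200 tokens or with no tokens.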
+ + Args:: + data: Iterable[{key, wav, label, sample_rate}] + max_length: drop utterance which is greater than max_length(10ms) + min_length: drop utterance which is less than min_length(10ms) + token_max_length: drop utterance which is greater than + token_max_length, especially when use char unit for + english modeling + token_min_length: drop utterance which is + less than token_max_length + min_output_input_ratio: minimal ration of + token_length / feats_length(10ms) + max_output_input_ratio: maximum ration of + token_length / feats_length(10ms) + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'label' in sample + # sample['wav'] is torch.Tensor, we have 100 frames every second + num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100 + if num_frames < min_length: + continue + if num_frames > max_length: + continue + if len(sample['label']) < token_min_length: + continue + if len(sample['label']) > token_max_length: + continue + if num_frames != 0: + if len(sample['label']) / num_frames < min_output_input_ratio: + continue + if len(sample['label']) / num_frames > max_output_input_ratio: + continue + yield sample + + +def resample(data, resample_rate=16000): + """ Resample data. + Inplace operation. + + Args: + data: Iterable[{key, wav, label, sample_rate}] + resample_rate: target resample rate + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + if sample_rate != resample_rate: + sample['sample_rate'] = resample_rate + sample['wav'] = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=resample_rate)(waveform) + yield sample + + +def speed_perturb(data, speeds=None): + """ Apply speed perturb to the data. + Inplace operation. 
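+        One speed factor is drawn per utterance, so with the default
+        speeds [0.9, 1.0, 1.1] roughly one third of the utterances keep
+        their original rate.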
+ + Args: + data: Iterable[{key, wav, label, sample_rate}] + speeds(List[float]): optional speed + + Returns: + Iterable[{key, wav, label, sample_rate}] + """ + if speeds is None: + speeds = [0.9, 1.0, 1.1] + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + speed = random.choice(speeds) + if speed != 1.0: + wav, _ = torchaudio.sox_effects.apply_effects_tensor( + waveform, sample_rate, + [['speed', str(speed)], ['rate', str(sample_rate)]]) + sample['wav'] = wav + + yield sample + + +def compute_fbank(data, + num_mel_bins=23, + frame_length=25, + frame_shift=10, + dither=0.0): + """ Extract fbank + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'key' in sample + assert 'label' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + waveform = waveform * (1 << 15) + # Only keep key, feat, label + mat = kaldi.fbank(waveform, + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + sample_frequency=sample_rate) + yield dict(key=sample['key'], label=sample['label'], feat=mat) + + +def compute_mfcc(data, + num_mel_bins=23, + frame_length=25, + frame_shift=10, + dither=0.0, + num_ceps=40, + high_freq=0.0, + low_freq=20.0): + """ Extract mfcc + + Args: + data: Iterable[{key, wav, label, sample_rate}] + + Returns: + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'sample_rate' in sample + assert 'wav' in sample + assert 'key' in sample + assert 'label' in sample + sample_rate = sample['sample_rate'] + waveform = sample['wav'] + waveform = waveform * (1 << 15) + # Only keep key, feat, label + mat = kaldi.mfcc(waveform, + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + num_ceps=num_ceps, + high_freq=high_freq, + low_freq=low_freq, + sample_frequency=sample_rate) + yield dict(key=sample['key'], label=sample['label'], feat=mat) + + +def __tokenize_by_bpe_model(sp, txt): + tokens = [] + # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + pattern = re.compile(r'([\u4e00-\u9fff])') + # Example: + # txt = "你好 ITS'S OKAY 的" + # chars = ["你", "好", " ITS'S OKAY ", "的"] + chars = pattern.split(txt.upper()) + mix_chars = [w for w in chars if len(w.strip()) > 0] + for ch_or_w in mix_chars: + # ch_or_w is a single CJK charater(i.e., "你"), do nothing. + if pattern.fullmatch(ch_or_w) is not None: + tokens.append(ch_or_w) + # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "), + # encode ch_or_w using bpe_model. 
+ else: + for p in sp.encode_as_pieces(ch_or_w): + tokens.append(p) + + return tokens + + +def tokenize(data, + symbol_table, + bpe_model=None, + non_lang_syms=None, + split_with_space=False): + """ Decode text to chars or BPE + Inplace operation + + Args: + data: Iterable[{key, wav, txt, sample_rate}] + + Returns: + Iterable[{key, wav, txt, tokens, label, sample_rate}] + """ + if non_lang_syms is not None: + non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") + else: + non_lang_syms = {} + non_lang_syms_pattern = None + + if bpe_model is not None: + import sentencepiece as spm + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + else: + sp = None + + for sample in data: + assert 'txt' in sample + txt = sample['txt'].strip() + if non_lang_syms_pattern is not None: + parts = non_lang_syms_pattern.split(txt.upper()) + parts = [w for w in parts if len(w.strip()) > 0] + else: + parts = [txt] + + label = [] + tokens = [] + for part in parts: + if part in non_lang_syms: + tokens.append(part) + else: + if bpe_model is not None: + tokens.extend(__tokenize_by_bpe_model(sp, part)) + else: + if split_with_space: + part = part.split(" ") + for ch in part: + if ch == ' ': + ch = "▁" + tokens.append(ch) + + for ch in tokens: + if ch in symbol_table: + label.append(symbol_table[ch]) + elif '' in symbol_table: + label.append(symbol_table['']) + + sample['tokens'] = tokens + sample['label'] = label + yield sample + + +def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80): + """ Do spec augmentation + Inplace operation + + Args: + data: Iterable[{key, feat, label}] + num_t_mask: number of time mask to apply + num_f_mask: number of freq mask to apply + max_t: max width of time mask + max_f: max width of freq mask + max_w: max width of time warp + + Returns + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'feat' in sample + x = sample['feat'] + assert isinstance(x, torch.Tensor) + y = x.clone().detach() + max_frames = y.size(0) + max_freq = y.size(1) + # time mask + for i in range(num_t_mask): + start = random.randint(0, max_frames - 1) + length = random.randint(1, max_t) + end = min(max_frames, start + length) + y[start:end, :] = 0 + # freq mask + for i in range(num_f_mask): + start = random.randint(0, max_freq - 1) + length = random.randint(1, max_f) + end = min(max_freq, start + length) + y[:, start:end] = 0 + sample['feat'] = y + yield sample + + +def spec_sub(data, max_t=20, num_t_sub=3): + """ Do spec substitute + Inplace operation + + Args: + data: Iterable[{key, feat, label}] + max_t: max width of time substitute + num_t_sub: number of time substitute to apply + + Returns + Iterable[{key, feat, label}] + """ + for sample in data: + assert 'feat' in sample + x = sample['feat'] + assert isinstance(x, torch.Tensor) + y = x.clone().detach() + max_frames = y.size(0) + for i in range(num_t_sub): + start = random.randint(0, max_frames - 1) + length = random.randint(1, max_t) + end = min(max_frames, start + length) + # only substitute the earlier time chosen randomly for current time + pos = random.randint(0, start) + y[start:end, :] = x[start - pos:end - pos, :] + sample['feat'] = y + yield sample + + +def shuffle(data, shuffle_size=10000): + """ Local shuffle the data + + Args: + data: Iterable[{key, feat, label}] + shuffle_size: buffer size for shuffle + + Returns: + Iterable[{key, feat, label}] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= shuffle_size: + random.shuffle(buf) + for x in buf: + yield x + buf = 
[] + # The sample left over + random.shuffle(buf) + for x in buf: + yield x + + +def sort(data, sort_size=500): + """ Sort the data by feature length. + Sort is used after shuffle and before batch, so we can group + utts with similar lengths into a batch, and `sort_size` should + be less than `shuffle_size` + + Args: + data: Iterable[{key, feat, label}] + sort_size: buffer size for sort + + Returns: + Iterable[{key, feat, label}] + """ + + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= sort_size: + buf.sort(key=lambda x: x['feat'].size(0)) + for x in buf: + yield x + buf = [] + # The sample left over + buf.sort(key=lambda x: x['feat'].size(0)) + for x in buf: + yield x + + +def static_batch(data, batch_size=16): + """ Static batch the data by `batch_size` + + Args: + data: Iterable[{key, feat, label}] + batch_size: batch size + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + for sample in data: + buf.append(sample) + if len(buf) >= batch_size: + yield buf + buf = [] + if len(buf) > 0: + yield buf + + +def dynamic_batch(data, max_frames_in_batch=12000): + """ Dynamic batch the data until the total frames in batch + reach `max_frames_in_batch` + + Args: + data: Iterable[{key, feat, label}] + max_frames_in_batch: max_frames in one batch + + Returns: + Iterable[List[{key, feat, label}]] + """ + buf = [] + longest_frames = 0 + for sample in data: + assert 'feat' in sample + assert isinstance(sample['feat'], torch.Tensor) + new_sample_frames = sample['feat'].size(0) + longest_frames = max(longest_frames, new_sample_frames) + frames_after_padding = longest_frames * (len(buf) + 1) + if frames_after_padding > max_frames_in_batch: + yield buf + buf = [sample] + longest_frames = new_sample_frames + else: + buf.append(sample) + if len(buf) > 0: + yield buf + + +def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000): + """ Wrapper for static/dynamic batch + """ + if batch_type == 'static': + return static_batch(data, batch_size) + elif batch_type == 'dynamic': + return dynamic_batch(data, max_frames_in_batch) + else: + logging.fatal('Unsupported batch type {}'.format(batch_type)) + + +def padding(data): + """ Padding the data into training data + + Args: + data: Iterable[List[{key, feat, label}]] + + Returns: + Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)] + """ + for sample in data: + assert isinstance(sample, list) + feats_length = torch.tensor([x['feat'].size(0) for x in sample], + dtype=torch.int32) + order = torch.argsort(feats_length, descending=True) + feats_lengths = torch.tensor( + [sample[i]['feat'].size(0) for i in order], dtype=torch.int32) + sorted_feats = [sample[i]['feat'] for i in order] + sorted_keys = [sample[i]['key'] for i in order] + sorted_labels = [ + torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order + ] + label_lengths = torch.tensor([x.size(0) for x in sorted_labels], + dtype=torch.int32) + + padded_feats = pad_sequence(sorted_feats, + batch_first=True, + padding_value=0) + padding_labels = pad_sequence(sorted_labels, + batch_first=True, + padding_value=-1) + + yield (sorted_keys, padded_feats, padding_labels, feats_lengths, + label_lengths) diff --git a/speech/speech_recognition/transformer/pytorch/wenet/dataset/wav_distortion.py b/speech/speech_recognition/transformer/pytorch/wenet/dataset/wav_distortion.py new file mode 100644 index 0000000000000000000000000000000000000000..35b983635764efeaf6f5295941aa9bb2555a8d1b --- /dev/null +++ 
b/speech/speech_recognition/transformer/pytorch/wenet/dataset/wav_distortion.py @@ -0,0 +1,310 @@ +import sys +import random +import math + +import torchaudio +import torch +torchaudio.set_audio_backend("sox_io") + + +def db2amp(db): + return pow(10, db / 20) + +def amp2db(amp): + return 20 * math.log10(amp) + +def make_poly_distortion(conf): + """Generate a db-domain ploynomial distortion function + + f(x) = a * x^m * (1-x)^n + x + + Args: + conf: a dict {'a': #int, 'm': #int, 'n': #int} + + Returns: + The ploynomial function, which could be applied on + a float amplitude value + """ + a = conf['a'] + m = conf['m'] + n = conf['n'] + + def poly_distortion(x): + abs_x = abs(x) + if abs_x < 0.000001: + x = x + else: + db_norm = amp2db(abs_x) / 100 + 1 + if db_norm < 0: + db_norm = 0 + db_norm = a * pow(db_norm, m) * pow((1 - db_norm), n) + db_norm + if db_norm > 1: + db_norm = 1 + db = (db_norm - 1) * 100 + amp = db2amp(db) + if amp >= 0.9997: + amp = 0.9997 + if x > 0: + x = amp + else: + x = -amp + return x + return poly_distortion + +def make_quad_distortion(): + return make_poly_distortion({'a' : 1, 'm' : 1, 'n' : 1}) + +# the amplitude are set to max for all non-zero point +def make_max_distortion(conf): + """Generate a max distortion function + + Args: + conf: a dict {'max_db': float } + 'max_db': the maxium value. + + Returns: + The max function, which could be applied on + a float amplitude value + """ + max_db = conf['max_db'] + if max_db: + max_amp = db2amp(max_db) # < 0.997 + else: + max_amp = 0.997 + + def max_distortion(x): + if x > 0: + x = max_amp + elif x < 0: + x = -max_amp + else: + x = 0.0 + return x + return max_distortion + + + +def make_amp_mask(db_mask=None): + """Get a amplitude domain mask from db domain mask + + Args: + db_mask: Optional. A list of tuple. if None, using default value. + + Returns: + A list of tuple. The amplitude domain mask + """ + if db_mask is None: + db_mask = [(-110, -95), (-90, -80), (-65, -60), (-50, -30), (-15, 0)] + amp_mask = [(db2amp(db[0]), db2amp(db[1])) for db in db_mask] + return amp_mask + +default_mask = make_amp_mask() + + +def generate_amp_mask(mask_num): + """Generate amplitude domain mask randomly in [-100db, 0db] + + Args: + mask_num: the slot number of the mask + + Returns: + A list of tuple. each tuple defines a slot. + e.g. [(-100, -80), (-65, -60), (-50, -30), (-15, 0)] + for #mask_num = 4 + """ + a = [0] * 2 * mask_num + a[0] = 0 + m = [] + for i in range(1, 2 * mask_num): + a[i] = a[i - 1] + random.uniform(0.5, 1) + max_val = a[2 * mask_num - 1] + for i in range(0, mask_num): + l = ((a[2 * i] - max_val) / max_val) * 100 + r = ((a[2 * i + 1] - max_val) / max_val) * 100 + m.append((l, r)) + return make_amp_mask(m) + + +def make_fence_distortion(conf): + """Generate a fence distortion function + + In this fence-like shape function, the values in mask slots are + set to maxium, while the values not in mask slots are set to 0. + Use seperated masks for Positive and negetive amplitude. + + Args: + conf: a dict {'mask_number': int,'max_db': float } + 'mask_number': the slot number in mask. + 'max_db': the maxium value. 
+ + Returns: + The fence function, which could be applied on + a float amplitude value + """ + mask_number = conf['mask_number'] + max_db = conf['max_db'] + max_amp = db2amp(max_db) # 0.997 + if mask_number <= 0 : + positive_mask = default_mask + negative_mask = make_amp_mask([(-50, 0)]) + else: + positive_mask = generate_amp_mask(mask_number) + negative_mask = generate_amp_mask(mask_number) + + def fence_distortion(x): + is_in_mask = False + if x > 0: + for mask in positive_mask: + if x >= mask[0] and x <= mask[1]: + is_in_mask = True + return max_amp + if not is_in_mask: + return 0.0 + elif x < 0: + abs_x = abs(x) + for mask in negative_mask: + if abs_x >= mask[0] and abs_x <= mask[1]: + is_in_mask = True + return max_amp + if not is_in_mask: + return 0.0 + return x + + return fence_distortion + +# +def make_jag_distortion(conf): + """Generate a jag distortion function + + In this jag-like shape function, the values in mask slots are + not changed, while the values not in mask slots are set to 0. + Use seperated masks for Positive and negetive amplitude. + + Args: + conf: a dict {'mask_number': #int} + 'mask_number': the slot number in mask. + + Returns: + The jag function,which could be applied on + a float amplitude value + """ + mask_number = conf['mask_number'] + if mask_number <= 0 : + positive_mask = default_mask + negative_mask = make_amp_mask([(-50, 0)]) + else: + positive_mask = generate_amp_mask(mask_number) + negative_mask = generate_amp_mask(mask_number) + + def jag_distortion(x): + is_in_mask = False + if x > 0: + for mask in positive_mask: + if x >= mask[0] and x <= mask[1]: + is_in_mask = True + return x + if not is_in_mask: + return 0.0 + elif x < 0: + abs_x = abs(x) + for mask in negative_mask: + if abs_x >= mask[0] and abs_x <= mask[1]: + is_in_mask = True + return x + if not is_in_mask: + return 0.0 + return x + + return jag_distortion + +# gaining 20db means amp = amp * 10 +# gaining -20db means amp = amp / 10 +def make_gain_db(conf): + """Generate a db domain gain function + + Args: + conf: a dict {'db': #float} + 'db': the gaining value + + Returns: + The db gain function, which could be applied on + a float amplitude value + """ + db = conf['db'] + + def gain_db(x): + return min(0.997, x * pow(10, db / 20)) + + return gain_db + + +def distort(x, func, rate=0.8): + """Distort a waveform in sample point level + + Args: + x: the origin wavefrom + func: the distort function + rate: sample point-level distort probability + + Returns: + the distorted waveform + """ + for i in range(0, x.shape[1]): + a = random.uniform(0, 1) + if a < rate: + x[0][i] = func(float(x[0][i])) + return x + +def distort_chain(x, funcs, rate=0.8): + for i in range(0, x.shape[1]): + a = random.uniform(0, 1) + if a < rate: + for func in funcs: + x[0][i] = func(float(x[0][i])) + return x + +# x is numpy +def distort_wav_conf(x, distort_type, distort_conf, rate=0.1): + if distort_type == 'gain_db': + gain_db = make_gain_db(distort_conf) + x = distort(x, gain_db) + elif distort_type == 'max_distortion': + max_distortion = make_max_distortion(distort_conf) + x = distort(x, max_distortion, rate=rate) + elif distort_type == 'fence_distortion': + fence_distortion = make_fence_distortion(distort_conf) + x = distort(x, fence_distortion, rate=rate) + elif distort_type == 'jag_distortion': + jag_distortion = make_jag_distortion(distort_conf) + x = distort(x, jag_distortion, rate=rate) + elif distort_type == 'poly_distortion': + poly_distortion = make_poly_distortion(distort_conf) + x = distort(x, 
poly_distortion, rate=rate) + elif distort_type == 'quad_distortion': + quad_distortion = make_quad_distortion() + x = distort(x, quad_distortion, rate=rate) + elif distort_type == 'none_distortion': + pass + else: + print('unsupport type') + return x + +def distort_wav_conf_and_save(distort_type, distort_conf, rate, wav_in, wav_out): + x, sr = torchaudio.load(wav_in) + x = x.detach().numpy() + out = distort_wav_conf(x, distort_type, distort_conf, rate) + torchaudio.save(wav_out, torch.from_numpy(out), sr) + +if __name__ == "__main__": + distort_type = sys.argv[1] + wav_in = sys.argv[2] + wav_out = sys.argv[3] + conf = None + rate = 0.1 + if distort_type == 'new_jag_distortion': + conf = {'mask_number' : 4} + elif distort_type == 'new_fence_distortion': + conf = {'mask_number' : 1, 'max_db' : -30} + elif distort_type == 'poly_distortion': + conf = {'a' : 4, 'm' : 2, "n" : 2} + distort_wav_conf_and_save(distort_type, conf, rate, wav_in, wav_out) diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/asr_model.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/asr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..bb3be7f88ed8c4dd034563183c1f5cca79bcaec5 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/asr_model.py @@ -0,0 +1,774 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
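Before moving on to `asr_model.py`: the `__main__` driver at the end of `wav_distortion.py` above takes a distortion type followed by input and output wav paths. A hypothetical invocation from the model directory (the wav file names are placeholders) could be:

```
python3 wenet/dataset/wav_distortion.py poly_distortion input.wav distorted.wav
```

Here `poly_distortion` uses the hard-coded configuration `{'a': 4, 'm': 2, 'n': 2}` and the default sample-level distortion rate of 0.1 from the driver above.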
+ +from collections import defaultdict +from typing import List, Optional, Tuple + +import torch + +from torch.nn.utils.rnn import pad_sequence + +from wenet.transformer.cmvn import GlobalCMVN +from wenet.transformer.ctc import CTC +from wenet.transformer.decoder import (TransformerDecoder, + BiTransformerDecoder) +from wenet.transformer.encoder import ConformerEncoder +from wenet.transformer.encoder import TransformerEncoder +from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss +from wenet.utils.cmvn import load_cmvn +from wenet.utils.common import (IGNORE_ID, add_sos_eos, log_add, + remove_duplicates_and_blank, th_accuracy, + reverse_pad_list) +from wenet.utils.mask import (make_pad_mask, mask_finished_preds, + mask_finished_scores, subsequent_mask) + + +class ASRModel(torch.nn.Module): + """CTC-attention hybrid Encoder-Decoder model""" + def __init__( + self, + vocab_size: int, + encoder: TransformerEncoder, + decoder: TransformerDecoder, + ctc: CTC, + ctc_weight: float = 0.5, + ignore_id: int = IGNORE_ID, + reverse_weight: float = 0.0, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + ): + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + + super().__init__() + # note that eos is the same as sos (equivalent ID) + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + self.reverse_weight = reverse_weight + + self.encoder = encoder + self.decoder = decoder + self.ctc = ctc + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor]]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == + text_lengths.shape[0]), (speech.shape, speech_lengths.shape, + text.shape, text_lengths.shape) + # 1. Encoder + encoder_out, encoder_mask = self.encoder(speech, speech_lengths) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + + # 2a. Attention-decoder branch + if self.ctc_weight != 1.0: + loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask, + text, text_lengths) + else: + loss_att = None + + # 2b. CTC branch + if self.ctc_weight != 0.0: + loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, + text_lengths) + else: + loss_ctc = None + + if loss_ctc is None: + loss = loss_att + elif loss_att is None: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + (1 - + self.ctc_weight) * loss_att + return loss, loss_att, loss_ctc + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_mask: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, float]: + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # reverse the seq, used for right to left decoder + r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id)) + r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos, + self.ignore_id) + # 1. 
Forward decoder + decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask, + ys_in_pad, ys_in_lens, + r_ys_in_pad, + self.reverse_weight) + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + r_loss_att = torch.tensor(0.0) + if self.reverse_weight > 0.0: + r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad) + loss_att = loss_att * ( + 1 - self.reverse_weight) + r_loss_att * self.reverse_weight + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + return loss_att, acc_att + + def _forward_encoder( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Let's assume B = batch_size + # 1. Encoder + if simulate_streaming and decoding_chunk_size > 0: + encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk( + speech, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + else: + encoder_out, encoder_mask = self.encoder( + speech, + speech_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + return encoder_out, encoder_mask + + def recognize( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int = 10, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> torch.Tensor: + """ Apply beam search on attention decoder + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + + Returns: + torch.Tensor: decoding result, (batch, max_result_len) + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + device = speech.device + batch_size = speech.shape[0] + + # Let's assume B = batch_size and N = beam_size + # 1. Encoder + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + encoder_dim = encoder_out.size(2) + running_size = batch_size * beam_size + encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( + running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) + encoder_mask = encoder_mask.unsqueeze(1).repeat( + 1, beam_size, 1, 1).view(running_size, 1, + maxlen) # (B*N, 1, max_len) + + hyps = torch.ones([running_size, 1], dtype=torch.long, + device=device).fill_(self.sos) # (B*N, 1) + scores = torch.tensor([0.0] + [-float('inf')] * (beam_size - 1), + dtype=torch.float) + scores = scores.to(device).repeat([batch_size]).unsqueeze(1).to( + device) # (B*N, 1) + end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device) + cache: Optional[List[torch.Tensor]] = None + # 2. 
Decoder forward step by step + for i in range(1, maxlen + 1): + # Stop if all batch and all beam produce eos + if end_flag.sum() == running_size: + break + # 2.1 Forward decoder step + hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( + running_size, 1, 1).to(device) # (B*N, i, i) + # logp: (B*N, vocab) + logp, cache = self.decoder.forward_one_step( + encoder_out, encoder_mask, hyps, hyps_mask, cache) + # 2.2 First beam prune: select topk best prob at current time + top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) + top_k_logp = mask_finished_scores(top_k_logp, end_flag) + top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos) + # 2.3 Second beam prune: select topk score with history + scores = scores + top_k_logp # (B*N, N), broadcast add + scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N) + scores, offset_k_index = scores.topk(k=beam_size) # (B, N) + scores = scores.view(-1, 1) # (B*N, 1) + # 2.4. Compute base index in top_k_index, + # regard top_k_index as (B*N*N),regard offset_k_index as (B*N), + # then find offset_k_index in top_k_index + base_k_index = torch.arange(batch_size, device=device).view( + -1, 1).repeat([1, beam_size]) # (B, N) + base_k_index = base_k_index * beam_size * beam_size + best_k_index = base_k_index.view(-1) + offset_k_index.view( + -1) # (B*N) + + # 2.5 Update best hyps + best_k_pred = torch.index_select(top_k_index.view(-1), + dim=-1, + index=best_k_index) # (B*N) + best_hyps_index = best_k_index // beam_size + last_best_k_hyps = torch.index_select( + hyps, dim=0, index=best_hyps_index) # (B*N, i) + hyps = torch.cat((last_best_k_hyps, best_k_pred.view(-1, 1)), + dim=1) # (B*N, i+1) + + # 2.6 Update end flag + end_flag = torch.eq(hyps[:, -1], self.eos).view(-1, 1) + + # 3. Select best of best + scores = scores.view(batch_size, beam_size) + # TODO: length normalization + best_scores, best_index = scores.max(dim=-1) + best_hyps_index = best_index + torch.arange( + batch_size, dtype=torch.long, device=device) * beam_size + best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index) + best_hyps = best_hyps[:, 1:] + return best_hyps, best_scores + + def ctc_greedy_search( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> List[List[int]]: + """ Apply CTC greedy search + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
+ 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + List[List[int]]: best path result + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + batch_size = speech.shape[0] + # Let's assume B = batch_size + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + ctc_probs = self.ctc.log_softmax( + encoder_out) # (B, maxlen, vocab_size) + topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) + mask = make_pad_mask(encoder_out_lens, maxlen) # (B, maxlen) + topk_index = topk_index.masked_fill_(mask, self.eos) # (B, maxlen) + hyps = [hyp.tolist() for hyp in topk_index] + scores = topk_prob.max(1) + hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps] + return hyps, scores + + def _ctc_prefix_beam_search( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> Tuple[List[List[int]], torch.Tensor]: + """ CTC prefix beam search inner implementation + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + + Returns: + List[List[int]]: nbest results + torch.Tensor: encoder output, (1, max_len, encoder_dim), + it will be used for rescoring in attention rescoring mode + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + batch_size = speech.shape[0] + # For CTC prefix beam search, we only support batch_size=1 + assert batch_size == 1 + # Let's assume B = batch_size and N = beam_size + # 1. Encoder forward and get CTC score + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + ctc_probs = self.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + cur_hyps = [(tuple(), (0.0, -float('inf')))] + # 2. 
CTC beam search step by step + for t in range(0, maxlen): + logp = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) + # 2.1 First beam prune: select topk best + top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + for s in top_k_index: + s = s.item() + ps = logp[s].item() + for prefix, (pb, pnb) in cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == 0: # blank + n_pb, n_pnb = next_hyps[prefix] + n_pb = log_add([n_pb, pb + ps, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + elif s == last: + # Update *ss -> *s; + n_pb, n_pnb = next_hyps[prefix] + n_pnb = log_add([n_pnb, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + + # 2.2 Second beam prune + next_hyps = sorted(next_hyps.items(), + key=lambda x: log_add(list(x[1])), + reverse=True) + cur_hyps = next_hyps[:beam_size] + hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps] + return hyps, encoder_out + + def ctc_prefix_beam_search( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + simulate_streaming: bool = False, + ) -> List[int]: + """ Apply CTC prefix beam search + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + + Returns: + List[int]: CTC prefix beam search nbest results + """ + hyps, _ = self._ctc_prefix_beam_search(speech, speech_lengths, + beam_size, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) + return hyps[0] + + def attention_rescoring( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + beam_size: int, + decoding_chunk_size: int = -1, + num_decoding_left_chunks: int = -1, + ctc_weight: float = 0.0, + simulate_streaming: bool = False, + reverse_weight: float = 0.0, + ) -> List[int]: + """ Apply attention rescoring decoding, CTC prefix beam search + is applied first to get nbest, then we resoring the nbest on + attention decoder with corresponding encoder out + + Args: + speech (torch.Tensor): (batch, max_len, feat_dim) + speech_length (torch.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. 
+ 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + reverse_weight (float): right to left decoder weight + ctc_weight (float): ctc score weight + + Returns: + List[int]: Attention rescoring result + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + if reverse_weight > 0.0: + # decoder should be a bitransformer decoder if reverse_weight > 0.0 + assert hasattr(self.decoder, 'right_decoder') + device = speech.device + batch_size = speech.shape[0] + # For attention rescoring we only support batch_size=1 + assert batch_size == 1 + # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size + hyps, encoder_out = self._ctc_prefix_beam_search( + speech, speech_lengths, beam_size, decoding_chunk_size, + num_decoding_left_chunks, simulate_streaming) + + assert len(hyps) == beam_size + hyps_pad = pad_sequence([ + torch.tensor(hyp[0], device=device, dtype=torch.long) + for hyp in hyps + ], True, self.ignore_id) # (beam_size, max_hyps_len) + ori_hyps_pad = hyps_pad + hyps_lens = torch.tensor([len(hyp[0]) for hyp in hyps], + device=device, + dtype=torch.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) + hyps_lens = hyps_lens + 1 # Add at begining + encoder_out = encoder_out.repeat(beam_size, 1, 1) + encoder_mask = torch.ones(beam_size, + 1, + encoder_out.size(1), + dtype=torch.bool, + device=device) + # used for right to left decoder + r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id) + r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, + self.ignore_id) + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, + reverse_weight) # (beam_size, max_hyps_len, vocab_size) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + decoder_out = decoder_out.cpu().numpy() + # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a + # conventional transformer decoder. 
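+        # NOTE: summary of the rescoring loop below. Each hyp from CTC
+        # prefix beam search is (token_ids, ctc_score); for hyp i the final
+        # score is
+        #   attn  = sum_j decoder_out[i][j][w_j] + decoder_out[i][len][eos]
+        #   attn  = (1 - reverse_weight) * attn + reverse_weight * r_attn
+        #   final = attn + ctc_weight * ctc_score
+        # and the hyp with the largest final score is returned.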
+ r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + r_decoder_out = r_decoder_out.cpu().numpy() + # Only use decoder score for rescoring + best_score = -float('inf') + best_index = 0 + for i, hyp in enumerate(hyps): + score = 0.0 + for j, w in enumerate(hyp[0]): + score += decoder_out[i][j][w] + score += decoder_out[i][len(hyp[0])][self.eos] + # add right to left decoder score + if reverse_weight > 0: + r_score = 0.0 + for j, w in enumerate(hyp[0]): + r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w] + r_score += r_decoder_out[i][len(hyp[0])][self.eos] + score = score * (1 - reverse_weight) + r_score * reverse_weight + # add ctc score + score += hyp[1] * ctc_weight + if score > best_score: + best_score = score + best_index = i + return hyps[best_index][0], best_score + + @torch.jit.export + def subsampling_rate(self) -> int: + """ Export interface for c++ call, return subsampling_rate of the + model + """ + return self.encoder.embed.subsampling_rate + + @torch.jit.export + def right_context(self) -> int: + """ Export interface for c++ call, return right_context of the model + """ + return self.encoder.embed.right_context + + @torch.jit.export + def sos_symbol(self) -> int: + """ Export interface for c++ call, return sos symbol id of the model + """ + return self.sos + + @torch.jit.export + def eos_symbol(self) -> int: + """ Export interface for c++ call, return eos symbol id of the model + """ + return self.eos + + @torch.jit.export + def forward_encoder_chunk( + self, + xs: torch.Tensor, + offset: int, + required_cache_size: int, + att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ Export interface for c++ call, give input chunk xs, and return + output from time 0 to current chunk. + + Args: + xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + compuation + >=0: actual cache size + <0: means all history cache is required + att_cache (torch.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, + (elayers, b=1, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + + Returns: + torch.Tensor: output of current input xs, + with shape (b=1, chunk_size, hidden-dim). + torch.Tensor: new attention cache required for next chunk, with + dynamic shape (elayers, head, ?, d_k * 2) + depending on required_cache_size. + torch.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. 
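+
+        Example (illustrative sketch only; `feature_chunks` and the
+        required_cache_size of 64 are assumptions, not part of this code):
+
+            offset = 0
+            att_cache = torch.zeros(0, 0, 0, 0)
+            cnn_cache = torch.zeros(0, 0, 0, 0)
+            for chunk in feature_chunks:  # each chunk: (1, time, mel-dim)
+                ys, att_cache, cnn_cache = model.forward_encoder_chunk(
+                    chunk, offset, 64, att_cache, cnn_cache)
+                offset += ys.size(1)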
+ + """ + return self.encoder.forward_chunk(xs, offset, required_cache_size, + att_cache, cnn_cache) + + @torch.jit.export + def ctc_activation(self, xs: torch.Tensor) -> torch.Tensor: + """ Export interface for c++ call, apply linear transform and log + softmax before ctc + Args: + xs (torch.Tensor): encoder output + + Returns: + torch.Tensor: activation before ctc + + """ + return self.ctc.log_softmax(xs) + + @torch.jit.export + def is_bidirectional_decoder(self) -> bool: + """ + Returns: + torch.Tensor: decoder output + """ + if hasattr(self.decoder, 'right_decoder'): + return True + else: + return False + + @torch.jit.export + def forward_attention_decoder( + self, + hyps: torch.Tensor, + hyps_lens: torch.Tensor, + encoder_out: torch.Tensor, + reverse_weight: float = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Export interface for c++ call, forward decoder with multiple + hypothesis from ctc prefix beam search and one encoder output + Args: + hyps (torch.Tensor): hyps from ctc prefix beam search, already + pad sos at the begining + hyps_lens (torch.Tensor): length of each hyp in hyps + encoder_out (torch.Tensor): corresponding encoder output + r_hyps (torch.Tensor): hyps from ctc prefix beam search, already + pad eos at the begining which is used fo right to left decoder + reverse_weight: used for verfing whether used right to left decoder, + > 0 will use. + + Returns: + torch.Tensor: decoder output + """ + assert encoder_out.size(0) == 1 + num_hyps = hyps.size(0) + assert hyps_lens.size(0) == num_hyps + encoder_out = encoder_out.repeat(num_hyps, 1, 1) + encoder_mask = torch.ones(num_hyps, + 1, + encoder_out.size(1), + dtype=torch.bool, + device=encoder_out.device) + + # input for right to left decoder + # this hyps_lens has count token, we need minus it. + r_hyps_lens = hyps_lens - 1 + # this hyps has included token, so it should be + # convert the original hyps. + r_hyps = hyps[:, 1:] + # >>> r_hyps + # >>> tensor([[ 1, 2, 3], + # >>> [ 9, 8, 4], + # >>> [ 2, -1, -1]]) + # >>> r_hyps_lens + # >>> tensor([3, 3, 1]) + + # NOTE(Mddct): `pad_sequence` is not supported by ONNX, it is used + # in `reverse_pad_list` thus we have to refine the below code. 
+ # Issue: https://github.com/wenet-e2e/wenet/issues/1113 + # Equal to: + # >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id)) + # >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id) + max_len = torch.max(r_hyps_lens) + index_range = torch.arange(0, max_len, 1).to(encoder_out.device) + seq_len_expand = r_hyps_lens.unsqueeze(1) + seq_mask = seq_len_expand > index_range # (beam, max_len) + # >>> seq_mask + # >>> tensor([[ True, True, True], + # >>> [ True, True, True], + # >>> [ True, False, False]]) + index = (seq_len_expand - 1) - index_range # (beam, max_len) + # >>> index + # >>> tensor([[ 2, 1, 0], + # >>> [ 2, 1, 0], + # >>> [ 0, -1, -2]]) + index = index * seq_mask + # >>> index + # >>> tensor([[2, 1, 0], + # >>> [2, 1, 0], + # >>> [0, 0, 0]]) + r_hyps = torch.gather(r_hyps, 1, index) + # >>> r_hyps + # >>> tensor([[3, 2, 1], + # >>> [4, 8, 9], + # >>> [2, 2, 2]]) + r_hyps = torch.where(seq_mask, r_hyps, self.eos) + # >>> r_hyps + # >>> tensor([[3, 2, 1], + # >>> [4, 8, 9], + # >>> [2, eos, eos]]) + r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1) + # >>> r_hyps + # >>> tensor([[sos, 3, 2, 1], + # >>> [sos, 4, 8, 9], + # >>> [sos, 2, eos, eos]]) + + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, + reverse_weight) # (num_hyps, max_hyps_len, vocab_size) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + + # right to left decoder may be not used during decoding process, + # which depends on reverse_weight param. + # r_dccoder_out will be 0.0, if reverse_weight is 0.0 + r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + return decoder_out, r_decoder_out + + +def init_asr_model(configs): + if configs['cmvn_file'] is not None: + mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn']) + global_cmvn = GlobalCMVN( + torch.from_numpy(mean).float(), + torch.from_numpy(istd).float()) + else: + global_cmvn = None + + input_dim = configs['input_dim'] + vocab_size = configs['output_dim'] + + encoder_type = configs.get('encoder', 'conformer') + decoder_type = configs.get('decoder', 'bitransformer') + + if encoder_type == 'conformer': + encoder = ConformerEncoder(input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf']) + else: + encoder = TransformerEncoder(input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf']) + if decoder_type == 'transformer': + decoder = TransformerDecoder(vocab_size, encoder.output_size(), + **configs['decoder_conf']) + else: + assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0 + assert configs['decoder_conf']['r_num_blocks'] > 0 + decoder = BiTransformerDecoder(vocab_size, encoder.output_size(), + **configs['decoder_conf']) + ctc = CTC(vocab_size, encoder.output_size()) + model = ASRModel( + vocab_size=vocab_size, + encoder=encoder, + decoder=decoder, + ctc=ctc, + **configs['model_conf'], + ) + return model diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/attention.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..908a710826a47e9a4c5b9ae79fe714b45f258432 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/attention.py @@ -0,0 +1,301 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +"""Multi-Head Attention layer definition.""" + +import math +from 
typing import Tuple + +import torch +from torch import nn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + def __init__(self, n_head: int, n_feat: int, dropout_rate: float): + """Construct an MultiHeadedAttention object.""" + super().__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor, size + (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor, size + (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor, size + (#batch, n_head, time2, d_k). + + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention( + self, value: torch.Tensor, scores: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) + ) -> torch.Tensor: + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value, size + (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score, size + (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + # NOTE(xcsong): When will `if mask.size(2) > 0` be True? + # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the + # 1st chunk to ease the onnx export.] + # 2. pytorch training + if mask.size(2) > 0 : # time2 > 0 + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + # For last chunk, time2 might be larger than scores.size(-1) + mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) + scores = scores.masked_fill(mask, -float('inf')) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0) # (batch, head, time1, time2) + # NOTE(xcsong): When will `if mask.size(2) > 0` be False? + # 1. onnx(16/-1, -1/-1, 16/0) + # 2. 
jit (16/-1, -1/-1, 16/0, 16/4) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + 1.When applying cross attention between decoder and encoder, + the batch padding mask for input is in (#batch, 1, T) shape. + 2.When applying self attention of encoder, + the mask is in (#batch, T, T) shape. + 3.When applying self attention of decoder, + the mask is in (#batch, L, L) shape. + 4.If the different position in decoder see different block + of the encoder, such as Mocha, the passed in mask could be + in (#batch, L, T) shape. But there is no such case in current + Wenet. + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + + """ + q, k, v = self.forward_qkv(query, key, value) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split( + cache, cache.size(-1) // 2, dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = torch.cat((k, v), dim=-1) + + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask), new_cache + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. 
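+
+    Note (summary of the forward pass below): the attention score is
+    computed as (matrix_ac + matrix_bd) / sqrt(d_k), where matrix_ac is
+    the content term built with pos_bias_u and matrix_bd is the positional
+    term built with pos_bias_v (terms (a)+(c) and (b)+(d) in Section 3.3
+    of the paper). rel_shift is intentionally not applied here.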
+ """ + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x, zero_triu: bool = False): + """Compute relative positinal encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, size). + zero_triu (bool): If true, return the lower triangular part of + the matrix. + Returns: + torch.Tensor: Output tensor. + """ + + zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(x.size()[0], + x.size()[1], + x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) + + if zero_triu: + ones = torch.ones((x.size(2), x.size(3))) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. 
+ # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split( + cache, cache.size(-1) // 2, dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. + new_cache = torch.cat((k, v), dim=-1) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + # Remove rel_shift since it is useless in speech recognition, + # and it requires special attention for streaming. + # matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask), new_cache diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/cmvn.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/cmvn.py new file mode 100644 index 0000000000000000000000000000000000000000..9c28e908013f853be63ec562dddba65092909eb0 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/cmvn.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
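+
+# NOTE: minimal usage sketch (illustrative; the stats file name below is an
+# assumption). The CMVN stats are loaded elsewhere via load_cmvn (see
+# init_asr_model), and this module simply applies x = (x - mean) * istd,
+# where istd = 1.0 / std:
+#
+#   mean, istd = load_cmvn('global_cmvn', is_json_cmvn=True)
+#   cmvn = GlobalCMVN(torch.from_numpy(mean).float(),
+#                     torch.from_numpy(istd).float())
+#   feats = cmvn(feats)  # feats: (batch, max_len, feat_dim)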
+ +import torch + + +class GlobalCMVN(torch.nn.Module): + def __init__(self, + mean: torch.Tensor, + istd: torch.Tensor, + norm_var: bool = True): + """ + Args: + mean (torch.Tensor): mean stats + istd (torch.Tensor): inverse std, std which is 1.0 / std + """ + super().__init__() + assert mean.shape == istd.shape + self.norm_var = norm_var + # The buffer can be accessed from this module using self.mean + self.register_buffer("mean", mean) + self.register_buffer("istd", istd) + + def forward(self, x: torch.Tensor): + """ + Args: + x (torch.Tensor): (batch, max_len, feat_dim) + + Returns: + (torch.Tensor): normalized feature + """ + x = x - self.mean + if self.norm_var: + x = x * self.istd + return x diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/convolution.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..d25b2d36f701ea18b0b477bf3a6628c107a6e39c --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/convolution.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# Copyright 2021 Mobvoi Inc. All Rights Reserved. +# Author: di.wu@mobvoi.com (DI WU) +"""ConvolutionModule definition.""" + +from typing import Tuple + +import torch +from torch import nn +from typeguard import check_argument_types + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model.""" + def __init__(self, + channels: int, + kernel_size: int = 15, + activation: nn.Module = nn.ReLU(), + norm: str = "batch_norm", + causal: bool = False, + bias: bool = True): + """Construct an ConvolutionModule object. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. + causal (int): Whether use causal convolution or not + """ + assert check_argument_types() + super().__init__() + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # self.lorder is used to distinguish if it's a causal convolution, + # if self.lorder > 0: it's a causal convolution, the input will be + # padded with self.lorder frames on the left in forward. + # else: it's a symmetrical convolution + if causal: + padding = 0 + self.lorder = kernel_size - 1 + else: + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + + assert norm in ['batch_norm', 'layer_norm'] + if norm == "batch_norm": + self.use_layer_norm = False + self.norm = nn.BatchNorm1d(channels) + else: + self.use_layer_norm = True + self.norm = nn.LayerNorm(channels) + + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = activation + + def forward( + self, + x: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + cache: torch.Tensor = torch.zeros((0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), + (0, 0, 0) means fake mask. 
+ cache (torch.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, cache_t), + (0, 0, 0) meas fake cache. + Returns: + torch.Tensor: Output tensor (#batch, time, channels). + """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) # (#batch, channels, time) + + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + if self.lorder > 0: + if cache.size(2) == 0: # cache_t == 0 + x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) + else: + assert cache.size(0) == x.size(0) # equal batch + assert cache.size(1) == x.size(1) # equal channel + x = torch.cat((cache, x), dim=2) + assert (x.size(2) > self.lorder) + new_cache = x[:, :, -self.lorder:] + else: + # It's better we just return None if no cache is requried, + # However, for JIT export, here we just fake one tensor instead of + # None. + new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + # x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + x = glu_torch_imply(x, dim=1) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.activation(self.norm(x)) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.pointwise_conv2(x) + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + return x.transpose(1, 2), new_cache + +def glu_torch_imply(inputs,dim:int=-1): + x,y = torch.chunk(inputs, 2, dim=dim) + outputs = x * torch.nn.functional.sigmoid(y) + return outputs \ No newline at end of file diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/ctc.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..e2b129ca28bbc1df99902d690a6067ada2cbf34c --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/ctc.py @@ -0,0 +1,69 @@ +import torch +import torch.nn.functional as F +from typeguard import check_argument_types + + +class CTC(torch.nn.Module): + """CTC module""" + def __init__( + self, + odim: int, + encoder_output_size: int, + dropout_rate: float = 0.0, + reduce: bool = True, + ): + """ Construct CTC module + Args: + odim: dimension of outputs + encoder_output_size: number of encoder projection units + dropout_rate: dropout rate (0.0 ~ 1.0) + reduce: reduce the CTC loss into a scalar + """ + assert check_argument_types() + super().__init__() + eprojs = encoder_output_size + self.dropout_rate = dropout_rate + self.ctc_lo = torch.nn.Linear(eprojs, odim) + + reduction_type = "sum" if reduce else "none" + self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type) + + def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor, + ys_pad: torch.Tensor, ys_lens: torch.Tensor) -> torch.Tensor: + """Calculate CTC loss. 
+ + Args: + hs_pad: batch of padded hidden state sequences (B, Tmax, D) + hlens: batch of lengths of hidden state sequences (B) + ys_pad: batch of padded character id sequence tensor (B, Lmax) + ys_lens: batch of lengths of character sequence (B) + """ + # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab) + ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate)) + # ys_hat: (B, L, D) -> (L, B, D) + ys_hat = ys_hat.transpose(0, 1) + ys_hat = ys_hat.log_softmax(2) + loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens) + # Batch-size average + loss = loss / ys_hat.size(1) + return loss + + def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor: + """log_softmax of frame activations + + Args: + Tensor hs_pad: 3d tensor (B, Tmax, eprojs) + Returns: + torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim) + """ + return F.log_softmax(self.ctc_lo(hs_pad), dim=2) + + def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor: + """argmax of frame activations + + Args: + torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) + Returns: + torch.Tensor: argmax applied 2d tensor (B, Tmax) + """ + return torch.argmax(self.ctc_lo(hs_pad), dim=2) diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/decoder.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..00dda3036581e2370b14fb993bdfea25c5b0d676 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/decoder.py @@ -0,0 +1,287 @@ +# Copyright 2021 Mobvoi Inc. All Rights Reserved. +# Author: di.wu@mobvoi.com (DI WU) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +"""Decoder definition.""" +from typing import Tuple, List, Optional + +import torch +from typeguard import check_argument_types + +from wenet.transformer.attention import MultiHeadedAttention +from wenet.transformer.decoder_layer import DecoderLayer +from wenet.transformer.embedding import PositionalEncoding +from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward +from wenet.utils.mask import (subsequent_mask, make_pad_mask) + + +class TransformerDecoder(torch.nn.Module): + """Base class of Transfomer decoder module. + Args: + vocab_size: output dim + encoder_output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the hidden units number of position-wise feedforward + num_blocks: the number of decoder blocks + dropout_rate: dropout rate + self_attention_dropout_rate: dropout rate for attention + input_layer: input layer type + use_output_layer: whether to use output layer + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. 
+ concat_after: whether to concat attention layer's input and output + True: x -> x + linear(concat(x, att(x))) + False: x -> x + att(x) + """ + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + normalize_before: bool = True, + concat_after: bool = False, + ): + assert check_argument_types() + super().__init__() + attention_dim = encoder_output_size + + if input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(vocab_size, attention_dim), + PositionalEncoding(attention_dim, positional_dropout_rate), + ) + else: + raise ValueError(f"only 'embed' is supported: {input_layer}") + + self.normalize_before = normalize_before + self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5) + self.use_output_layer = use_output_layer + self.output_layer = torch.nn.Linear(attention_dim, vocab_size) + self.num_blocks = num_blocks + self.decoders = torch.nn.ModuleList([ + DecoderLayer( + attention_dim, + MultiHeadedAttention(attention_heads, attention_dim, + self_attention_dropout_rate), + MultiHeadedAttention(attention_heads, attention_dim, + src_attention_dropout_rate), + PositionwiseFeedForward(attention_dim, linear_units, + dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ) for _ in range(self.num_blocks) + ]) + + def forward( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + r_ys_in_pad: torch.Tensor = torch.empty(0), + reverse_weight: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward decoder. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoder memory mask, (batch, 1, maxlen_in) + ys_in_pad: padded input token ids, int64 (batch, maxlen_out) + ys_in_lens: input lengths of this batch (batch) + r_ys_in_pad: not used in transformer decoder, in order to unify api + with bidirectional decoder + reverse_weight: not used in transformer decoder, in order to unify + api with bidirectional decode + Returns: + (tuple): tuple containing: + x: decoded token score before softmax (batch, maxlen_out, + vocab_size) if use_output_layer is True, + torch.tensor(0.0), in order to unify api with bidirectional decoder + olens: (batch, ) + """ + tgt = ys_in_pad + maxlen = tgt.size(1) + # tgt_mask: (B, 1, L) + tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1) + tgt_mask = tgt_mask.to(tgt.device) + # m: (1, L, L) + m = subsequent_mask(tgt_mask.size(-1), + device=tgt_mask.device).unsqueeze(0) + # tgt_mask: (B, L, L) + tgt_mask = tgt_mask & m + x, _ = self.embed(tgt) + for layer in self.decoders: + x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory, + memory_mask) + if self.normalize_before: + x = self.after_norm(x) + if self.use_output_layer: + x = self.output_layer(x) + olens = tgt_mask.sum(1) + return x, torch.tensor(0.0), olens + + def forward_one_step( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + cache: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward one step. + This is only used for decoding. 
+ Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoded memory mask, (batch, 1, maxlen_in) + tgt: input token ids, int64 (batch, maxlen_out) + tgt_mask: input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + cache: cached output list of (batch, max_time_out-1, size) + Returns: + y, cache: NN output value and cache per `self.decoders`. + y.shape` is (batch, maxlen_out, token) + """ + x, _ = self.embed(tgt) + new_cache = [] + for i, decoder in enumerate(self.decoders): + if cache is None: + c = None + else: + c = cache[i] + x, tgt_mask, memory, memory_mask = decoder(x, + tgt_mask, + memory, + memory_mask, + cache=c) + new_cache.append(x) + if self.normalize_before: + y = self.after_norm(x[:, -1]) + else: + y = x[:, -1] + if self.use_output_layer: + y = torch.log_softmax(self.output_layer(y), dim=-1) + return y, new_cache + + +class BiTransformerDecoder(torch.nn.Module): + """Base class of Transfomer decoder module. + Args: + vocab_size: output dim + encoder_output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the hidden units number of position-wise feedforward + num_blocks: the number of decoder blocks + r_num_blocks: the number of right to left decoder blocks + dropout_rate: dropout rate + self_attention_dropout_rate: dropout rate for attention + input_layer: input layer type + use_output_layer: whether to use output layer + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. + concat_after: whether to concat attention layer's input and output + True: x -> x + linear(concat(x, att(x))) + False: x -> x + att(x) + """ + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + r_num_blocks: int = 0, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + normalize_before: bool = True, + concat_after: bool = False, + ): + + assert check_argument_types() + super().__init__() + self.left_decoder = TransformerDecoder( + vocab_size, encoder_output_size, attention_heads, linear_units, + num_blocks, dropout_rate, positional_dropout_rate, + self_attention_dropout_rate, src_attention_dropout_rate, + input_layer, use_output_layer, normalize_before, concat_after) + + self.right_decoder = TransformerDecoder( + vocab_size, encoder_output_size, attention_heads, linear_units, + r_num_blocks, dropout_rate, positional_dropout_rate, + self_attention_dropout_rate, src_attention_dropout_rate, + input_layer, use_output_layer, normalize_before, concat_after) + + def forward( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + r_ys_in_pad: torch.Tensor, + reverse_weight: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward decoder. 
+ Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoder memory mask, (batch, 1, maxlen_in) + ys_in_pad: padded input token ids, int64 (batch, maxlen_out) + ys_in_lens: input lengths of this batch (batch) + r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out), + used for right to left decoder + reverse_weight: used for right to left decoder + Returns: + (tuple): tuple containing: + x: decoded token score before softmax (batch, maxlen_out, + vocab_size) if use_output_layer is True, + r_x: x: decoded token score (right to left decoder) + before softmax (batch, maxlen_out, vocab_size) + if use_output_layer is True, + olens: (batch, ) + """ + l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, + ys_in_lens) + r_x = torch.tensor(0.0) + if reverse_weight > 0.0: + r_x, _, olens = self.right_decoder(memory, memory_mask, r_ys_in_pad, + ys_in_lens) + return l_x, r_x, olens + + def forward_one_step( + self, + memory: torch.Tensor, + memory_mask: torch.Tensor, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + cache: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward one step. + This is only used for decoding. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoded memory mask, (batch, 1, maxlen_in) + tgt: input token ids, int64 (batch, maxlen_out) + tgt_mask: input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + cache: cached output list of (batch, max_time_out-1, size) + Returns: + y, cache: NN output value and cache per `self.decoders`. + y.shape` is (batch, maxlen_out, token) + """ + return self.left_decoder.forward_one_step(memory, memory_mask, tgt, + tgt_mask, cache) diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/decoder_layer.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/decoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..cb9ddc7872c700c1f4c7b1dfdbac93cdaffd4251 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/decoder_layer.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +"""Decoder self-attention layer definition.""" +from typing import Optional, Tuple + +import torch +from torch import nn + + +class DecoderLayer(nn.Module): + """Single decoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + src_attn (torch.nn.Module): Inter-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: to use layer_norm after each sub-block. + concat_after (bool): Whether to concat attention layer's inpu + and output. 
+ True: x -> x + linear(concat(x, att(x))) + False: x -> x + att(x) + """ + def __init__( + self, + size: int, + self_attn: nn.Module, + src_attn: nn.Module, + feed_forward: nn.Module, + dropout_rate: float, + normalize_before: bool = True, + concat_after: bool = False, + ): + """Construct an DecoderLayer object.""" + super().__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.norm1 = nn.LayerNorm(size, eps=1e-5) + self.norm2 = nn.LayerNorm(size, eps=1e-5) + self.norm3 = nn.LayerNorm(size, eps=1e-5) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear1 = nn.Linear(size + size, size) + self.concat_linear2 = nn.Linear(size + size, size) + else: + self.concat_linear1 = nn.Identity() + self.concat_linear2 = nn.Identity() + + def forward( + self, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + memory_mask: torch.Tensor, + cache: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute decoded features. + + Args: + tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). + tgt_mask (torch.Tensor): Mask for input tensor + (#batch, maxlen_out). + memory (torch.Tensor): Encoded memory + (#batch, maxlen_in, size). + memory_mask (torch.Tensor): Encoded memory mask + (#batch, maxlen_in). + cache (torch.Tensor): cached tensors. + (#batch, maxlen_out - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, maxlen_out, size). + torch.Tensor: Mask for output tensor (#batch, maxlen_out). + torch.Tensor: Encoded memory (#batch, maxlen_in, size). + torch.Tensor: Encoded memory mask (#batch, maxlen_in). 
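+
+        Note: when `cache` is given, only the query for the last position
+        is computed (tgt_q = tgt[:, -1:, :]) and the cached outputs are
+        concatenated back at the end, so the returned tensor still has
+        shape (#batch, maxlen_out, size).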
+ + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + + if cache is None: + tgt_q = tgt + tgt_q_mask = tgt_mask + else: + # compute only the last frame query keeping dim: max_time_out -> 1 + assert cache.shape == ( + tgt.shape[0], + tgt.shape[1] - 1, + self.size, + ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + tgt_q_mask = tgt_mask[:, -1:, :] + + if self.concat_after: + tgt_concat = torch.cat( + (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1) + x = residual + self.concat_linear1(tgt_concat) + else: + x = residual + self.dropout( + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + if self.concat_after: + x_concat = torch.cat( + (x, self.src_attn(x, memory, memory, memory_mask)[0]), dim=-1) + x = residual + self.concat_linear2(x_concat) + else: + x = residual + self.dropout( + self.src_attn(x, memory, memory, memory_mask)[0]) + if not self.normalize_before: + x = self.norm2(x) + + residual = x + if self.normalize_before: + x = self.norm3(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm3(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, tgt_mask, memory, memory_mask diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/embedding.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..a47afd9db1ede1b5b74461b309118398521f7f53 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/embedding.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Mobvoi Inc. All Rights Reserved. +# Author: di.wu@mobvoi.com (DI WU) +"""Positonal Encoding Module.""" + +import math +from typing import Tuple + +import torch + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) + PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) + """ + def __init__(self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + reverse: bool = False): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.max_len = max_len + + self.pe = torch.zeros(self.max_len, self.d_model) + position = torch.arange(0, self.max_len, + dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * + -(math.log(10000.0) / self.d_model)) + self.pe[:, 0::2] = torch.sin(position * div_term) + self.pe[:, 1::2] = torch.cos(position * div_term) + self.pe = self.pe.unsqueeze(0) + + def forward(self, + x: torch.Tensor, + offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + offset (int): position offset + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) 
+ torch.Tensor: for compatibility to RelPositionalEncoding + """ + assert offset + x.size(1) < self.max_len + self.pe = self.pe.to(x.device) + pos_emb = self.pe[:, offset:offset + x.size(1)] + x = x * self.xscale + pos_emb + return self.dropout(x), self.dropout(pos_emb) + + def position_encoding(self, offset: int, size: int) -> torch.Tensor: + """ For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. + + Args: + offset (int): start offset + size (int): requried size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + assert offset + size < self.max_len + return self.dropout(self.pe[:, offset:offset + size]) + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): + """Initialize class.""" + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def forward(self, + x: torch.Tensor, + offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). + """ + assert offset + x.size(1) < self.max_len + self.pe = self.pe.to(x.device) + x = x * self.xscale + pos_emb = self.pe[:, offset:offset + x.size(1)] + return self.dropout(x), self.dropout(pos_emb) + + +class NoPositionalEncoding(torch.nn.Module): + """ No position encoding + """ + def __init__(self, d_model: int, dropout_rate: float): + super().__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + + def forward(self, + x: torch.Tensor, + offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: + """ Just return zero vector for interface compatibility + """ + pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) + return self.dropout(x), pos_emb + + def position_encoding(self, offset: int, size: int) -> torch.Tensor: + return torch.zeros(1, size, self.d_model) diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/encoder.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..96fe134a1a0b8fc60efeaeda08af2f0e11cc2f84 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/encoder.py @@ -0,0 +1,451 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Mobvoi Inc. All Rights Reserved. 
+# Author: di.wu@mobvoi.com (DI WU) +"""Encoder definition.""" +from typing import Tuple + +import torch +from typeguard import check_argument_types + +from wenet.transformer.attention import MultiHeadedAttention +from wenet.transformer.attention import RelPositionMultiHeadedAttention +from wenet.transformer.convolution import ConvolutionModule +from wenet.transformer.embedding import PositionalEncoding +from wenet.transformer.embedding import RelPositionalEncoding +from wenet.transformer.embedding import NoPositionalEncoding +from wenet.transformer.encoder_layer import TransformerEncoderLayer +from wenet.transformer.encoder_layer import ConformerEncoderLayer +from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward +from wenet.transformer.subsampling import Conv2dSubsampling4 +from wenet.transformer.subsampling import Conv2dSubsampling6 +from wenet.transformer.subsampling import Conv2dSubsampling8 +from wenet.transformer.subsampling import LinearNoSubsampling +from wenet.utils.common import get_activation +from wenet.utils.mask import make_pad_mask +from wenet.utils.mask import add_optional_chunk_mask + + +class BaseEncoder(torch.nn.Module): + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + concat_after: bool = False, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + ): + """ + Args: + input_size (int): input dim + output_size (int): dimension of attention + attention_heads (int): the number of heads of multi head attention + linear_units (int): the hidden units number of position-wise feed + forward + num_blocks (int): the number of decoder blocks + dropout_rate (float): dropout rate + attention_dropout_rate (float): dropout rate in attention + positional_dropout_rate (float): dropout rate after adding + positional encoding + input_layer (str): input layer type. + optional [linear, conv2d, conv2d6, conv2d8] + pos_enc_layer_type (str): Encoder positional encoding layer type. + opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] + normalize_before (bool): + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. + concat_after (bool): whether to concat attention layer's input + and output. 
+ True: x -> x + linear(concat(x, att(x))) + False: x -> x + att(x) + static_chunk_size (int): chunk size for static chunk training and + decoding + use_dynamic_chunk (bool): whether use dynamic chunk size for + training or not, You can only use fixed chunk(chunk_size > 0) + or dyanmic chunk size(use_dynamic_chunk = True) + global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module + use_dynamic_left_chunk (bool): whether use dynamic left chunk in + dynamic chunk training + """ + assert check_argument_types() + super().__init__() + self._output_size = output_size + + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "rel_pos": + pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == "no_pos": + pos_enc_class = NoPositionalEncoding + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + if input_layer == "linear": + subsampling_class = LinearNoSubsampling + elif input_layer == "conv2d": + subsampling_class = Conv2dSubsampling4 + elif input_layer == "conv2d6": + subsampling_class = Conv2dSubsampling6 + elif input_layer == "conv2d8": + subsampling_class = Conv2dSubsampling8 + else: + raise ValueError("unknown input_layer: " + input_layer) + + self.global_cmvn = global_cmvn + self.embed = subsampling_class( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + + self.normalize_before = normalize_before + self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed positions in tensor. + + Args: + xs: padded input tensor (B, T, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. 
+ >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: + encoder output tensor xs, and subsampled masks + xs: padded output tensor (B, T' ~= T/subsample_rate, D) + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + """ + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask(xs, masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size, + num_decoding_left_chunks) + for layer in self.encoders: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + if self.normalize_before: + xs = self.after_norm(xs) + # Here we assume the mask is not changed in encoder layers, so just + # return the masks before encoder layers, and the masks will be used + # for cross attention with decoder later + return xs, masks + + def forward_chunk( + self, + xs: torch.Tensor, + offset: int, + required_cache_size: int, + att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ Forward just one chunk + + Args: + xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + compuation + >=0: actual cache size + <0: means all history cache is required + att_cache (torch.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, + (elayers, b=1, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + + Returns: + torch.Tensor: output of current input xs, + with shape (b=1, chunk_size, hidden-dim). + torch.Tensor: new attention cache required for next chunk, with + dynamic shape (elayers, head, ?, d_k * 2) + depending on required_cache_size. + torch.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. 
+ + """ + assert xs.size(0) == 1 + # tmp_masks is just for interface compatibility + tmp_masks = torch.ones(1, + xs.size(1), + device=xs.device, + dtype=torch.bool) + tmp_masks = tmp_masks.unsqueeze(1) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) + # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) + elayers, cache_t1 = att_cache.size(0), att_cache.size(2) + chunk_size = xs.size(1) + attention_key_size = cache_t1 + chunk_size + pos_emb = self.embed.position_encoding( + offset=offset - cache_t1, size=attention_key_size) + if required_cache_size < 0: + next_cache_start = 0 + elif required_cache_size == 0: + next_cache_start = attention_key_size + else: + next_cache_start = max(attention_key_size - required_cache_size, 0) + r_att_cache = [] + r_cnn_cache = [] + for i, layer in enumerate(self.encoders): + # NOTE(xcsong): Before layer.forward + # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), + # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) + xs, _, new_att_cache, new_cnn_cache = layer( + xs, att_mask, pos_emb, + att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache + ) + # NOTE(xcsong): After layer.forward + # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), + # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) + r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) + r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) + if self.normalize_before: + xs = self.after_norm(xs) + + # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), + # ? may be larger than cache_t1, it depends on required_cache_size + r_att_cache = torch.cat(r_att_cache, dim=0) + # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) + r_cnn_cache = torch.cat(r_cnn_cache, dim=0) + + return (xs, r_att_cache, r_cnn_cache) + + def forward_chunk_by_chunk( + self, + xs: torch.Tensor, + decoding_chunk_size: int, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Forward input chunk by chunk with chunk_size like a streaming + fashion + + Here we should pay special attention to computation cache in the + streaming style forward chunk by chunk. Three things should be taken + into account for computation in the current network: + 1. transformer/conformer encoder layers output cache + 2. convolution in conformer + 3. convolution in subsampling + + However, we don't implement subsampling cache for: + 1. We can control subsampling module to output the right result by + overlapping input instead of cache left context, even though it + wastes some computation, but subsampling only takes a very + small fraction of computation in the whole model. + 2. Typically, there are several covolution layers with subsampling + in subsampling module, it is tricky and complicated to do cache + with different convolution layers with different subsampling + rate. + 3. Currently, nn.Sequential is used to stack all the convolution + layers in subsampling, we need to rewrite it to make it work + with cache, which is not prefered. 
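+
+        For illustration, with the default conv2d subsampling (subsampling
+        rate 4, right context 6) and an assumed decoding_chunk_size of 16,
+        each forward_chunk call below consumes a window of
+        (16 - 1) * 4 + 6 + 1 = 67 input frames and advances by a stride of
+        4 * 16 = 64 frames per chunk.
+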
+ Args: + xs (torch.Tensor): (1, max_len, dim) + chunk_size (int): decoding chunk size + """ + assert decoding_chunk_size > 0 + # The model is trained by static or dynamic chunk + assert self.static_chunk_size > 0 or self.use_dynamic_chunk + subsampling = self.embed.subsampling_rate + context = self.embed.right_context + 1 # Add current frame + stride = subsampling * decoding_chunk_size + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.size(1) + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + outputs = [] + offset = 0 + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + + # Feed forward overlap input step by step + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] + (y, att_cache, cnn_cache) = self.forward_chunk( + chunk_xs, offset, required_cache_size, att_cache, cnn_cache) + outputs.append(y) + offset += y.size(1) + ys = torch.cat(outputs, 1) + masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool) + return ys, masks + + +class TransformerEncoder(BaseEncoder): + """Transformer encoder module.""" + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "abs_pos", + normalize_before: bool = True, + concat_after: bool = False, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + ): + """ Construct TransformerEncoder + + See Encoder for the meaning of each parameter. 
+ """ + assert check_argument_types() + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + concat_after, static_chunk_size, use_dynamic_chunk, + global_cmvn, use_dynamic_left_chunk) + self.encoders = torch.nn.ModuleList([ + TransformerEncoderLayer( + output_size, + MultiHeadedAttention(attention_heads, output_size, + attention_dropout_rate), + PositionwiseFeedForward(output_size, linear_units, + dropout_rate), dropout_rate, + normalize_before, concat_after) for _ in range(num_blocks) + ]) + + +class ConformerEncoder(BaseEncoder): + """Conformer encoder module.""" + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + pos_enc_layer_type: str = "rel_pos", + normalize_before: bool = True, + concat_after: bool = False, + static_chunk_size: int = 0, + use_dynamic_chunk: bool = False, + global_cmvn: torch.nn.Module = None, + use_dynamic_left_chunk: bool = False, + positionwise_conv_kernel_size: int = 1, + macaron_style: bool = True, + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + cnn_module_kernel: int = 15, + causal: bool = False, + cnn_module_norm: str = "batch_norm", + ): + """Construct ConformerEncoder + + Args: + input_size to use_dynamic_chunk, see in BaseEncoder + positionwise_conv_kernel_size (int): Kernel size of positionwise + conv1d layer. + macaron_style (bool): Whether to use macaron style for + positionwise layer. + selfattention_layer_type (str): Encoder attention layer type, + the parameter has no effect now, it's just for configure + compatibility. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + causal (bool): whether to use causal convolution or not. 
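+            cnn_module_norm (str): normalization used inside the convolution
+                module (defaults to batch_norm).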
+ """ + assert check_argument_types() + super().__init__(input_size, output_size, attention_heads, + linear_units, num_blocks, dropout_rate, + positional_dropout_rate, attention_dropout_rate, + input_layer, pos_enc_layer_type, normalize_before, + concat_after, static_chunk_size, use_dynamic_chunk, + global_cmvn, use_dynamic_left_chunk) + activation = get_activation(activation_type) + + # self-attention module definition + if pos_enc_layer_type == "no_pos": + encoder_selfattn_layer = MultiHeadedAttention + else: + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + # feed-forward module definition + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = (output_size, cnn_module_kernel, activation, + cnn_module_norm, causal) + + self.encoders = torch.nn.ModuleList([ + ConformerEncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer( + *positionwise_layer_args) if macaron_style else None, + convolution_layer( + *convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + ) for _ in range(num_blocks) + ]) diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/encoder_layer.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..a08c1727d0bccdb9359b9104c5cb172c00acaf45 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/encoder_layer.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Mobvoi Inc. All Rights Reserved. +# Author: di.wu@mobvoi.com (DI WU) +"""Encoder self-attention layer definition.""" + +from typing import Optional, Tuple + +import torch +from torch import nn + + +class TransformerEncoderLayer(nn.Module): + """Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: to use layer_norm after each sub-block. + concat_after (bool): Whether to concat attention layer's input and + output. 
+ True: x -> x + linear(concat(x, att(x))) + False: x -> x + att(x) + + """ + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: torch.nn.Module, + dropout_rate: float, + normalize_before: bool = True, + concat_after: bool = False, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = nn.LayerNorm(size, eps=1e-5) + self.norm2 = nn.LayerNorm(size, eps=1e-5) + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if concat_after: + self.concat_linear = nn.Linear(size + size, size) + else: + self.concat_linear = nn.Identity() + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): just for interface compatibility + to ConformerEncoderLayer + mask_pad (torch.Tensor): does not used in transformer layer, + just for unified api with conformer. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2), not used here, it's for interface + compatibility to ConformerEncoderLayer. + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2). + + """ + residual = x + if self.normalize_before: + x = self.norm1(x) + + x_att, new_att_cache = self.self_attn( + x, x, x, mask, cache=att_cache) + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + return x, mask, new_att_cache, fake_cnn_cache + + +class ConformerEncoderLayer(nn.Module): + """Encoder layer module. + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module + instance. + `PositionwiseFeedForward` instance can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: use layer_norm after each sub-block. + concat_after (bool): Whether to concat attention layer's input and + output. 
+ True: x -> x + linear(concat(x, att(x))) + False: x -> x + att(x) + """ + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward: Optional[nn.Module] = None, + feed_forward_macaron: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + dropout_rate: float = 0.1, + normalize_before: bool = True, + concat_after: bool = False, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module + self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = nn.LayerNorm(size, + eps=1e-5) # for the CNN module + self.norm_final = nn.LayerNorm( + size, eps=1e-5) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + else: + self.concat_linear = nn.Identity() + + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): positional encoding, must not be None + for ConformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2) + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). 
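+
+        Note:
+            The sub-modules are applied in the order: optional macaron
+            feed-forward, self-attention, optional convolution module,
+            feed-forward, and a final layer norm when the convolution
+            module is present.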
+ """ + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + + x_att, new_att_cache = self.self_attn( + x, x, x, mask, pos_emb, att_cache) + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + # Fake new cnn cache here, and then change it in conv_module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + return x, mask, new_att_cache, new_cnn_cache diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/label_smoothing_loss.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/label_smoothing_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..36a74e89c5bc8d9b73f145ba549ae58742b4997b --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/label_smoothing_loss.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +"""Label smoothing module.""" + +import torch +from torch import nn + + +class LabelSmoothingLoss(nn.Module): + """Label-smoothing loss. + + In a standard CE loss, the label's data distribution is: + [0,1,2] -> + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ] + + In the smoothing version CE Loss,some probabilities + are taken from the true label prob (1.0) and are divided + among other labels. + + e.g. + smoothing=0.1 + [0,1,2] -> + [ + [0.9, 0.05, 0.05], + [0.05, 0.9, 0.05], + [0.05, 0.05, 0.9], + ] + + Args: + size (int): the number of class + padding_idx (int): padding class id which will be ignored for loss + smoothing (float): smoothing rate (0.0 means the conventional CE) + normalize_length (bool): + normalize loss by sequence length if True + normalize loss by batch size if False + """ + def __init__(self, + size: int, + padding_idx: int, + smoothing: float, + normalize_length: bool = False): + """Construct an LabelSmoothingLoss object.""" + super(LabelSmoothingLoss, self).__init__() + self.criterion = nn.KLDivLoss(reduction="none") + self.padding_idx = padding_idx + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.size = size + self.normalize_length = normalize_length + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Compute loss between x and target. + + The model outputs and data labels tensors are flatten to + (batch*seqlen, class) shape and a mask is applied to the + padding part which should not be calculated for loss. 
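+        The summed KL divergence is divided by the number of non-padding
+        tokens when normalize_length is True, otherwise by the batch size.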
+ + Args: + x (torch.Tensor): prediction (batch, seqlen, class) + target (torch.Tensor): + target signal masked with self.padding_id (batch, seqlen) + Returns: + loss (torch.Tensor) : The KL loss, scalar float value + """ + assert x.size(2) == self.size + batch_size = x.size(0) + x = x.view(-1, self.size) + target = target.view(-1) + # use zeros_like instead of torch.no_grad() for true_dist, + # since no_grad() can not be exported by JIT + true_dist = torch.zeros_like(x) + true_dist.fill_(self.smoothing / (self.size - 1)) + ignore = target == self.padding_idx # (B,) + total = len(target) - ignore.sum().item() + target = target.masked_fill(ignore, 0) # avoid -1 index + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) + denom = total if self.normalize_length else batch_size + return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/positionwise_feed_forward.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/positionwise_feed_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..1ee04f4d1ec4dad995c88f46ff4159a4cd188c19 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/positionwise_feed_forward.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +"""Positionwise feed forward layer definition.""" + +import torch + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. + + FeedForward are appied on each position of the sequence. + The output dim is same with the input dim. + + Args: + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. + activation (torch.nn.Module): Activation function + """ + def __init__(self, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU()): + """Construct a PositionwiseFeedForward object.""" + super(PositionwiseFeedForward, self).__init__() + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.activation = activation + self.dropout = torch.nn.Dropout(dropout_rate) + self.w_2 = torch.nn.Linear(hidden_units, idim) + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + """ + return self.w_2(self.dropout(self.activation(self.w_1(xs)))) diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/subsampling.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/subsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..c6735b201b7c190cae4f7ab49ddfb6c2402a3a1a --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/subsampling.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Mobvoi Inc. All Rights Reserved. 
+# Author: di.wu@mobvoi.com (DI WU) +"""Subsampling layer definition.""" + +from typing import Tuple + +import torch + + +class BaseSubsampling(torch.nn.Module): + def __init__(self): + super().__init__() + self.right_context = 0 + self.subsampling_rate = 1 + + def position_encoding(self, offset: int, size: int) -> torch.Tensor: + return self.pos_enc.position_encoding(offset, size) + + +class LinearNoSubsampling(BaseSubsampling): + """Linear transform the input without subsampling + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an linear object.""" + super().__init__() + self.out = torch.nn.Sequential( + torch.nn.Linear(idim, odim), + torch.nn.LayerNorm(odim, eps=1e-5), + torch.nn.Dropout(dropout_rate), + ) + self.pos_enc = pos_enc_class + self.right_context = 0 + self.subsampling_rate = 1 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Input x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: linear input tensor (#batch, time', odim), + where time' = time . + torch.Tensor: linear input mask (#batch, 1, time'), + where time' = time . + + """ + x = self.out(x) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask + + +class Conv2dSubsampling4(BaseSubsampling): + """Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling4 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) + self.pos_enc = pos_enc_class + # The right context for every conv layer is computed by: + # (kernel_size - 1) * frame_rate_of_this_layer + self.subsampling_rate = 4 + # 6 = (3 - 1) * 1 + (3 - 1) * 2 + self.right_context = 6 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 4. + torch.Tensor: positional encoding + + """ + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] + + +class Conv2dSubsampling6(BaseSubsampling): + """Convolutional 2D subsampling (to 1/6 length). + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. 
+ """ + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling6 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 5, 3), + torch.nn.ReLU(), + ) + self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), + odim) + self.pos_enc = pos_enc_class + # 10 = (3 - 1) * 1 + (5 - 1) * 2 + self.subsampling_rate = 6 + self.right_context = 10 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 6. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 6. + torch.Tensor: positional encoding + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3] + + +class Conv2dSubsampling8(BaseSubsampling): + """Convolutional 2D subsampling (to 1/8 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + + """ + def __init__(self, idim: int, odim: int, dropout_rate: float, + pos_enc_class: torch.nn.Module): + """Construct an Conv2dSubsampling8 object.""" + super().__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.linear = torch.nn.Linear( + odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim) + self.pos_enc = pos_enc_class + self.subsampling_rate = 8 + # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4 + self.right_context = 14 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 8. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 8. 
+ torch.Tensor: positional encoding + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2] diff --git a/speech/speech_recognition/transformer/pytorch/wenet/transformer/swish.py b/speech/speech_recognition/transformer/pytorch/wenet/transformer/swish.py new file mode 100644 index 0000000000000000000000000000000000000000..2571ad452c0724cd7efe19896d26080f628a21e9 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/transformer/swish.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +"""Swish() activation function for Conformer.""" + +import torch + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return Swish activation function.""" + return x * torch.sigmoid(x) diff --git a/speech/speech_recognition/transformer/pytorch/wenet/utils/checkpoint.py b/speech/speech_recognition/transformer/pytorch/wenet/utils/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..3623e2c0a7d441d87c40e8ad970ed6e75648b5d6 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/utils/checkpoint.py @@ -0,0 +1,92 @@ +# Copyright 2019 Mobvoi Inc. All Rights Reserved. +# Author: binbinzhang@mobvoi.com (Binbin Zhang) + +import logging +import os +import re + +import yaml +import torch +from collections import OrderedDict + + +def load_checkpoint(model: torch.nn.Module, path: str) -> dict: + if torch.cuda.is_available(): + logging.info('Checkpoint: loading from checkpoint %s for GPU' % path) + checkpoint = torch.load(path) + else: + logging.info('Checkpoint: loading from checkpoint %s for CPU' % path) + checkpoint = torch.load(path, map_location='cpu') + model.load_state_dict(checkpoint, strict=False) + info_path = re.sub('.pt$', '.yaml', path) + configs = {} + if os.path.exists(info_path): + with open(info_path, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + return configs + + +def save_checkpoint(model: torch.nn.Module, path: str, infos=None): + ''' + Args: + infos (dict or None): any info you want to save. 
+    '''
+    logging.info('Checkpoint: save to checkpoint %s' % path)
+    if isinstance(model, torch.nn.DataParallel):
+        state_dict = model.module.state_dict()
+    elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
+        state_dict = model.module.state_dict()
+    else:
+        state_dict = model.state_dict()
+    torch.save(state_dict, path)
+    info_path = re.sub('.pt$', '.yaml', path)
+    if infos is None:
+        infos = {}
+    with open(info_path, 'w') as fout:
+        data = yaml.dump(infos)
+        fout.write(data)
+
+
+def filter_modules(model_state_dict, modules):
+    new_mods = []
+    incorrect_mods = []
+    mods_model = model_state_dict.keys()
+    for mod in modules:
+        if any(key.startswith(mod) for key in mods_model):
+            new_mods += [mod]
+        else:
+            incorrect_mods += [mod]
+    if incorrect_mods:
+        logging.warning(
+            "module(s) %s don't match or (partially match) "
+            "available modules in model.",
+            incorrect_mods,
+        )
+        logging.warning("for information, the existing modules in model are:")
+        logging.warning("%s", mods_model)
+
+    return new_mods
+
+
+def load_trained_modules(model: torch.nn.Module, args: None):
+    # Load encoder modules with pre-trained model(s).
+    enc_model_path = args.enc_init
+    enc_modules = args.enc_init_mods
+    main_state_dict = model.state_dict()
+    logging.warning("model(s) found for pre-initialization")
+    if os.path.isfile(enc_model_path):
+        logging.info('Checkpoint: loading from checkpoint %s for CPU' %
+                     enc_model_path)
+        model_state_dict = torch.load(enc_model_path, map_location='cpu')
+        modules = filter_modules(model_state_dict, enc_modules)
+        partial_state_dict = OrderedDict()
+        for key, value in model_state_dict.items():
+            if any(key.startswith(m) for m in modules):
+                partial_state_dict[key] = value
+        main_state_dict.update(partial_state_dict)
+    else:
+        logging.warning("model was not found : %s", enc_model_path)
+
+    model.load_state_dict(main_state_dict)
+    configs = {}
+    return configs
diff --git a/speech/speech_recognition/transformer/pytorch/wenet/utils/cmvn.py b/speech/speech_recognition/transformer/pytorch/wenet/utils/cmvn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d262143210dde2c73b7dabd67eba87ecdbc2a7b4
--- /dev/null
+++ b/speech/speech_recognition/transformer/pytorch/wenet/utils/cmvn.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python3
+# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
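+
+# The json cmvn stats consumed by _load_json_cmvn below are expected to
+# provide the keys 'mean_stat', 'var_stat' and 'frame_num', e.g.
+# {"mean_stat": [...], "var_stat": [...], "frame_num": 1234}.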
+
+import json
+import logging
+import math
+import sys
+
+import numpy as np
+
+
+def _load_json_cmvn(json_cmvn_file):
+    """ Load the json format cmvn stats file and calculate cmvn
+
+    Args:
+        json_cmvn_file: cmvn stats file in json format
+
+    Returns:
+        a numpy array of [means, vars]
+    """
+    with open(json_cmvn_file) as f:
+        cmvn_stats = json.load(f)
+
+    means = cmvn_stats['mean_stat']
+    variance = cmvn_stats['var_stat']
+    count = cmvn_stats['frame_num']
+    for i in range(len(means)):
+        means[i] /= count
+        variance[i] = variance[i] / count - means[i] * means[i]
+        if variance[i] < 1.0e-20:
+            variance[i] = 1.0e-20
+        variance[i] = 1.0 / math.sqrt(variance[i])
+    cmvn = np.array([means, variance])
+    return cmvn
+
+
+def _load_kaldi_cmvn(kaldi_cmvn_file):
+    """ Load the kaldi format cmvn stats file and calculate cmvn
+
+    Args:
+        kaldi_cmvn_file: kaldi text style global cmvn file, which
+            is generated by:
+            compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
+
+    Returns:
+        a numpy array of [means, vars]
+    """
+    means = []
+    variance = []
+    with open(kaldi_cmvn_file, 'r') as fid:
+        # kaldi binary file start with '\0B'
+        if fid.read(2) == '\0B':
+            logging.error('kaldi cmvn binary file is not supported, please '
+                          'recompute it by: compute-cmvn-stats --binary=false '
+                          ' scp:feats.scp global_cmvn')
+            sys.exit(1)
+        fid.seek(0)
+        arr = fid.read().split()
+        assert (arr[0] == '[')
+        assert (arr[-2] == '0')
+        assert (arr[-1] == ']')
+        feat_dim = int((len(arr) - 2 - 2) / 2)
+        for i in range(1, feat_dim + 1):
+            means.append(float(arr[i]))
+        count = float(arr[feat_dim + 1])
+        for i in range(feat_dim + 2, 2 * feat_dim + 2):
+            variance.append(float(arr[i]))
+
+    for i in range(len(means)):
+        means[i] /= count
+        variance[i] = variance[i] / count - means[i] * means[i]
+        if variance[i] < 1.0e-20:
+            variance[i] = 1.0e-20
+        variance[i] = 1.0 / math.sqrt(variance[i])
+    cmvn = np.array([means, variance])
+    return cmvn
+
+
+def load_cmvn(cmvn_file, is_json):
+    if is_json:
+        cmvn = _load_json_cmvn(cmvn_file)
+    else:
+        cmvn = _load_kaldi_cmvn(cmvn_file)
+    return cmvn[0], cmvn[1]
diff --git a/speech/speech_recognition/transformer/pytorch/wenet/utils/common.py b/speech/speech_recognition/transformer/pytorch/wenet/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ddd4e2bdcd1ca8002cdb5e52f85b5a67bc1cb27
--- /dev/null
+++ b/speech/speech_recognition/transformer/pytorch/wenet/utils/common.py
@@ -0,0 +1,186 @@
+"""Utility functions for Transformer."""
+
+import math
+from typing import Tuple, List
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+IGNORE_ID = -1
+
+
+def pad_list(xs: List[torch.Tensor], pad_value: int):
+    """Perform padding for the list of tensors.
+
+    Args:
+        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
+        pad_value (float): Value for padding.
+
+    Returns:
+        Tensor: Padded tensor (B, Tmax, `*`).
+
+    Examples:
+        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
+        >>> x
+        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
+        >>> pad_list(x, 0)
+        tensor([[1., 1., 1., 1.],
+                [1., 1., 0., 0.],
+                [1., 0., 0., 0.]])
+
+    """
+    n_batch = len(xs)
+    max_len = max([x.size(0) for x in xs])
+    pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)
+    pad = pad.fill_(pad_value)
+    for i in range(n_batch):
+        pad[i, :xs[i].size(0)] = xs[i]
+
+    return pad
+
+
+def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
+                ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Add <sos> and <eos> labels.
+
+    Args:
+        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
+        sos (int): index of <sos>
+        eos (int): index of <eos>
+        ignore_id (int): index of padding
+
+    Returns:
+        ys_in (torch.Tensor) : (B, Lmax + 1)
+        ys_out (torch.Tensor) : (B, Lmax + 1)
+
+    Examples:
+        >>> sos_id = 10
+        >>> eos_id = 11
+        >>> ignore_id = -1
+        >>> ys_pad
+        tensor([[ 1,  2,  3,  4,  5],
+                [ 4,  5,  6, -1, -1],
+                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
+        >>> ys_in, ys_out = add_sos_eos(ys_pad, sos_id, eos_id, ignore_id)
+        >>> ys_in
+        tensor([[10,  1,  2,  3,  4,  5],
+                [10,  4,  5,  6, 11, 11],
+                [10,  7,  8,  9, 11, 11]])
+        >>> ys_out
+        tensor([[ 1,  2,  3,  4,  5, 11],
+                [ 4,  5,  6, 11, -1, -1],
+                [ 7,  8,  9, 11, -1, -1]])
+    """
+    _sos = torch.tensor([sos],
+                        dtype=torch.long,
+                        requires_grad=False,
+                        device=ys_pad.device)
+    _eos = torch.tensor([eos],
+                        dtype=torch.long,
+                        requires_grad=False,
+                        device=ys_pad.device)
+    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
+    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
+    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
+    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
+
+
+def reverse_pad_list(ys_pad: torch.Tensor,
+                     ys_lens: torch.Tensor,
+                     pad_value: float = -1.0) -> torch.Tensor:
+    """Reverse padding for the list of tensors.
+
+    Args:
+        ys_pad (tensor): The padded tensor (B, Tokenmax).
+        ys_lens (tensor): The lens of token seqs (B)
+        pad_value (int): Value for padding.
+
+    Returns:
+        Tensor: Padded tensor (B, Tokenmax).
+
+    Examples:
+        >>> x
+        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
+        >>> pad_list(x, 0)
+        tensor([[4, 3, 2, 1],
+                [7, 6, 5, 0],
+                [9, 8, 0, 0]])
+
+    """
+    r_ys_pad = pad_sequence([(torch.flip(y.int()[:i], [0]))
+                             for y, i in zip(ys_pad, ys_lens)], True,
+                            pad_value)
+    return r_ys_pad
+
+
+def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
+                ignore_label: int) -> float:
+    """Calculate accuracy.
+
+    Args:
+        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
+        pad_targets (LongTensor): Target label tensors (B, Lmax, D).
+        ignore_label (int): Ignore label id.
+
+    Returns:
+        float: Accuracy value (0.0 - 1.0).
+ + """ + pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1), + pad_outputs.size(1)).argmax(2) + mask = pad_targets != ignore_label + numerator = torch.sum( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + denominator = torch.sum(mask) + return float(numerator) / float(denominator) + + +def get_activation(act): + """Return activation function.""" + # Lazy load to avoid unused import + from wenet.transformer.swish import Swish + + activation_funcs = { + "hardtanh": torch.nn.Hardtanh, + "tanh": torch.nn.Tanh, + "relu": torch.nn.ReLU, + "selu": torch.nn.SELU, + "swish": getattr(torch.nn, "SiLU", Swish), + "gelu": torch.nn.GELU + } + + return activation_funcs[act]() + + +def get_subsample(config): + input_layer = config["encoder_conf"]["input_layer"] + assert input_layer in ["conv2d", "conv2d6", "conv2d8"] + if input_layer == "conv2d": + return 4 + elif input_layer == "conv2d6": + return 6 + elif input_layer == "conv2d8": + return 8 + + +def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: + new_hyp: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != 0: + new_hyp.append(hyp[cur]) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp + + +def log_add(args: List[int]) -> float: + """ + Stable log add + """ + if all(a == -float('inf') for a in args): + return -float('inf') + a_max = max(args) + lsp = math.log(sum(math.exp(a - a_max) for a in args)) + return a_max + lsp diff --git a/speech/speech_recognition/transformer/pytorch/wenet/utils/config.py b/speech/speech_recognition/transformer/pytorch/wenet/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..fad8632c1c1ac738fc5ade733e52e031194fb5a3 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/utils/config.py @@ -0,0 +1,24 @@ +import copy + +def override_config(configs, override_list): + new_configs = copy.deepcopy(configs) + for item in override_list: + arr = item.split() + if len(arr) != 2: + print(f"the overrive {item} format not correct, skip it") + continue + keys = arr[0].split('.') + s_configs = new_configs + for i, key in enumerate(keys): + if key not in s_configs: + print(f"the overrive {item} format not correct, skip it") + if i == len(keys) - 1: + param_type = type(s_configs[key]) + if param_type != bool: + s_configs[key] = param_type(arr[1]) + else: + s_configs[key] = arr[1] in ['true', 'True'] + print(f"override {arr[0]} with {arr[1]}") + else: + s_configs = s_configs[key] + return new_configs diff --git a/speech/speech_recognition/transformer/pytorch/wenet/utils/ctc_util.py b/speech/speech_recognition/transformer/pytorch/wenet/utils/ctc_util.py new file mode 100644 index 0000000000000000000000000000000000000000..416507c91d5497153e761edff2a66ee7a397d5bc --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/utils/ctc_util.py @@ -0,0 +1,72 @@ +# Copyright 2021 Mobvoi Inc. All Rights Reserved. +# Author: binbinzhang@mobvoi.com (Di Wu) + +import numpy as np +import torch + +def insert_blank(label, blank_id=0): + """Insert blank token between every two label token.""" + label = np.expand_dims(label, 1) + blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id + label = np.concatenate([blanks, label], axis=1) + label = label.reshape(-1) + label = np.append(label, label[0]) + return label + +def forced_align(ctc_probs: torch.Tensor, + y: torch.Tensor, + blank_id=0) -> list: + """ctc forced alignment. 
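+
+    The label sequence is first interleaved with blanks by insert_blank,
+    e.g. [1, 2] with blank_id=0 becomes [0, 1, 0, 2, 0]; a Viterbi-style
+    dynamic program then selects the most likely frame-level state sequence.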
+ + Args: + torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D) + torch.Tensor y: id sequence tensor 1d tensor (L) + int blank_id: blank symbol index + Returns: + torch.Tensor: alignment result + """ + y_insert_blank = insert_blank(y, blank_id) + + log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank))) + log_alpha = log_alpha - float('inf') # log of zero + state_path = (torch.zeros( + (ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1 + ) # state path + + # init start state + log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] + log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] + + for t in range(1, ctc_probs.size(0)): + for s in range(len(y_insert_blank)): + if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ + s] == y_insert_blank[s - 2]: + candidates = torch.tensor( + [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]]) + prev_state = [s, s - 1] + else: + candidates = torch.tensor([ + log_alpha[t - 1, s], + log_alpha[t - 1, s - 1], + log_alpha[t - 1, s - 2], + ]) + prev_state = [s, s - 1, s - 2] + log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]] + state_path[t, s] = prev_state[torch.argmax(candidates)] + + state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16) + + candidates = torch.tensor([ + log_alpha[-1, len(y_insert_blank) - 1], + log_alpha[-1, len(y_insert_blank) - 2] + ]) + prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2] + state_seq[-1] = prev_state[torch.argmax(candidates)] + for t in range(ctc_probs.size(0) - 2, -1, -1): + state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] + + output_alignment = [] + for t in range(0, ctc_probs.size(0)): + output_alignment.append(y_insert_blank[state_seq[t, 0]]) + + return output_alignment diff --git a/speech/speech_recognition/transformer/pytorch/wenet/utils/executor.py b/speech/speech_recognition/transformer/pytorch/wenet/utils/executor.py new file mode 100644 index 0000000000000000000000000000000000000000..b0289c28e4bad89f47c6eb353cfca1f9c51d9fdb --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/utils/executor.py @@ -0,0 +1,148 @@ +# Copyright 2019 Mobvoi Inc. All Rights Reserved. +# Author: binbinzhang@mobvoi.com (Binbin Zhang) + +import logging +try: + from contextlib import nullcontext +except: + from contextlib import suppress as nullcontext +# if your python version < 3.7 use the below one +# from contextlib import suppress as nullcontext +import torch +from torch.nn.utils import clip_grad_norm_ + + +class Executor: + def __init__(self): + self.step = 0 + + def train(self, model, optimizer, scheduler, data_loader, device, writer, + args, scaler): + ''' Train one epoch + ''' + model.train() + clip = args.get('grad_clip', 50.0) + log_interval = args.get('log_interval', 10) + rank = args.get('rank', 0) + epoch = args.get('epoch', 0) + accum_grad = args.get('accum_grad', 1) + is_distributed = args.get('is_distributed', True) + use_amp = args.get('use_amp', False) + logging.info('using accumulate grad, new batch size is {} times' + ' larger than before'.format(accum_grad)) + if use_amp: + assert scaler is not None + # A context manager to be used in conjunction with an instance of + # torch.nn.parallel.DistributedDataParallel to be able to train + # with uneven inputs across participating processes. 
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model_context = model.join + else: + model_context = nullcontext + num_seen_utts = 0 + with model_context(): + for batch_idx, batch in enumerate(data_loader): + key, feats, target, feats_lengths, target_lengths = batch + feats = feats.to(device) + target = target.to(device) + feats_lengths = feats_lengths.to(device) + target_lengths = target_lengths.to(device) + num_utts = target_lengths.size(0) + if num_utts == 0: + continue + context = None + # Disable gradient synchronizations across DDP processes. + # Within this context, gradients will be accumulated on module + # variables, which will later be synchronized. + if is_distributed and batch_idx % accum_grad != 0: + context = model.no_sync + # Used for single gpu training and DDP gradient synchronization + # processes. + else: + context = nullcontext + with context(): + # autocast context + # The more details about amp can be found in + # https://pytorch.org/docs/stable/notes/amp_examples.html + with torch.cuda.amp.autocast(scaler is not None): + loss, loss_att, loss_ctc = model( + feats, feats_lengths, target, target_lengths) + loss = loss / accum_grad + if use_amp: + scaler.scale(loss).backward() + else: + loss.backward() + + num_seen_utts += num_utts + if batch_idx % accum_grad == 0: + if rank == 0 and writer is not None: + writer.add_scalar('train_loss', loss, self.step) + # Use mixed precision training + if use_amp: + scaler.unscale_(optimizer) + grad_norm = clip_grad_norm_(model.parameters(), clip) + # Must invoke scaler.update() if unscale_() is used in + # the iteration to avoid the following error: + # RuntimeError: unscale_() has already been called + # on this optimizer since the last update(). + # We don't check grad here since that if the gradient + # has inf/nan values, scaler.step will skip + # optimizer.step(). 
+ scaler.step(optimizer) + scaler.update() + else: + grad_norm = clip_grad_norm_(model.parameters(), clip) + if torch.isfinite(grad_norm): + optimizer.step() + optimizer.zero_grad() + scheduler.step() + self.step += 1 + if batch_idx % log_interval == 0: + lr = optimizer.param_groups[0]['lr'] + log_str = 'TRAIN Batch {}/{} loss {:.6f} '.format( + epoch, batch_idx, + loss.item() * accum_grad) + if loss_att is not None: + log_str += 'loss_att {:.6f} '.format(loss_att.item()) + if loss_ctc is not None: + log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item()) + log_str += 'lr {:.8f} rank {}'.format(lr, rank) + logging.debug(log_str) + + def cv(self, model, data_loader, device, args): + ''' Cross validation on + ''' + model.eval() + rank = args.get('rank', 0) + epoch = args.get('epoch', 0) + log_interval = args.get('log_interval', 10) + # in order to avoid division by 0 + num_seen_utts = 1 + total_loss = 0.0 + with torch.no_grad(): + for batch_idx, batch in enumerate(data_loader): + key, feats, target, feats_lengths, target_lengths = batch + feats = feats.to(device) + target = target.to(device) + feats_lengths = feats_lengths.to(device) + target_lengths = target_lengths.to(device) + num_utts = target_lengths.size(0) + if num_utts == 0: + continue + loss, loss_att, loss_ctc = model(feats, feats_lengths, target, + target_lengths) + if torch.isfinite(loss): + num_seen_utts += num_utts + total_loss += loss.item() * num_utts + if batch_idx % log_interval == 0: + log_str = 'CV Batch {}/{} loss {:.6f} '.format( + epoch, batch_idx, loss.item()) + if loss_att is not None: + log_str += 'loss_att {:.6f} '.format(loss_att.item()) + if loss_ctc is not None: + log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item()) + log_str += 'history loss {:.6f}'.format(total_loss / + num_seen_utts) + log_str += ' rank {}'.format(rank) + logging.debug(log_str) + return total_loss, num_seen_utts diff --git a/speech/speech_recognition/transformer/pytorch/wenet/utils/file_utils.py b/speech/speech_recognition/transformer/pytorch/wenet/utils/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7e516cc61f759267f4ef09309ff0b45110a0c1 --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/utils/file_utils.py @@ -0,0 +1,66 @@ +# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + + +def read_lists(list_file): + lists = [] + with open(list_file, 'r', encoding='utf8') as fin: + for line in fin: + lists.append(line.strip()) + return lists + + +def read_non_lang_symbols(non_lang_sym_path): + """read non-linguistic symbol from file. + + The file format is like below: + + {NOISE}\n + {BRK}\n + ... + + + Args: + non_lang_sym_path: non-linguistic symbol file path, None means no any + syms. 
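+
+    Returns:
+        a list of symbol strings, or None if non_lang_sym_path is None.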
+ + """ + if non_lang_sym_path is None: + return None + else: + syms = read_lists(non_lang_sym_path) + non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})") + for sym in syms: + if non_lang_syms_pattern.fullmatch(sym) is None: + class BadSymbolFormat(Exception): + pass + raise BadSymbolFormat( + "Non-linguistic symbols should be " + "formatted in {xxx}//[xxx], consider" + " modify '%s' to meet the requirment. " + "More details can be found in discussions here : " + "https://github.com/wenet-e2e/wenet/pull/819" % (sym)) + return syms + + +def read_symbol_table(symbol_table_file): + symbol_table = {} + with open(symbol_table_file, 'r', encoding='utf8') as fin: + for line in fin: + arr = line.strip().split() + assert len(arr) == 2 + symbol_table[arr[0]] = int(arr[1]) + return symbol_table diff --git a/speech/speech_recognition/transformer/pytorch/wenet/utils/mask.py b/speech/speech_recognition/transformer/pytorch/wenet/utils/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc6e28d72064369750244761ec63c872cf1309c --- /dev/null +++ b/speech/speech_recognition/transformer/pytorch/wenet/utils/mask.py @@ -0,0 +1,287 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import torch + +''' +def subsequent_mask( + size: int, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size). + + This mask is used only in decoder which works in an auto-regressive mode. + This means the current step could only do attention with its left steps. + + In encoder, fully attention is used when streaming is not necessary and + the sequence is not long. In this case, no attention mask is needed. + + When streaming is need, chunk-based attention is used in encoder. See + subsequent_chunk_mask for the chunk-based attention mask. + + Args: + size (int): size of mask + str device (str): "cpu" or "cuda" or torch.Tensor.device + dtype (torch.device): result dtype + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_mask(3) + [[1, 0, 0], + [1, 1, 0], + [1, 1, 1]] + """ + ret = torch.ones(size, size, device=device, dtype=torch.bool) + return torch.tril(ret) +''' + +def subsequent_mask( + size: int, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size). + + This mask is used only in decoder which works in an auto-regressive mode. + This means the current step could only do attention with its left steps. + + In encoder, fully attention is used when streaming is not necessary and + the sequence is not long. In this case, no attention mask is needed. + + When streaming is need, chunk-based attention is used in encoder. See + subsequent_chunk_mask for the chunk-based attention mask. 
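+
+    This implementation builds the lower-triangular mask with an arange
+    comparison; it produces the same pattern as the torch.tril variant
+    kept commented out above.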
+
+    Args:
+        size (int): size of mask
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_mask(3)
+        [[1, 0, 0],
+         [1, 1, 0],
+         [1, 1, 1]]
+    """
+    arange = torch.arange(size, device=device)
+    mask = arange.expand(size, size)
+    arange = arange.unsqueeze(-1)
+    mask = mask <= arange
+    return mask
+
+
+def subsequent_chunk_mask(
+        size: int,
+        chunk_size: int,
+        num_left_chunks: int = -1,
+        device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Create mask for subsequent steps (size, size) with chunk size,
+       this is for the streaming encoder.
+
+    Args:
+        size (int): size of mask
+        chunk_size (int): size of chunk
+        num_left_chunks (int): number of left chunks
+            <0: use full chunk
+            >=0: use num_left_chunks
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+    for i in range(size):
+        if num_left_chunks < 0:
+            start = 0
+        else:
+            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+        ending = min((i // chunk_size + 1) * chunk_size, size)
+        ret[i, start:ending] = True
+    return ret
+
+
+def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
+                            use_dynamic_chunk: bool,
+                            use_dynamic_left_chunk: bool,
+                            decoding_chunk_size: int, static_chunk_size: int,
+                            num_decoding_left_chunks: int):
+    """ Apply optional mask for encoder.
+
+    Args:
+        xs (torch.Tensor): padded input, (B, L, D), L for max length
+        masks (torch.Tensor): mask for xs, (B, 1, L)
+        use_dynamic_chunk (bool): whether to use dynamic chunk or not
+        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+            training.
+        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+            0: default for training, use random dynamic chunk.
+            <0: for decoding, use full chunk.
+            >0: for decoding, use fixed chunk size as set.
+        static_chunk_size (int): chunk size for static chunk training/decoding
+            if it's greater than 0; ignored when use_dynamic_chunk is true.
+        num_decoding_left_chunks: number of left chunks, this is for decoding,
+            the chunk size is decoding_chunk_size.
+            >=0: use num_decoding_left_chunks
+            <0: use all left chunks
+
+    Returns:
+        torch.Tensor: chunk mask of the input xs.
+    """
+    # Whether to use chunk mask or not
+    if use_dynamic_chunk:
+        max_len = xs.size(1)
+        if decoding_chunk_size < 0:
+            chunk_size = max_len
+            num_left_chunks = -1
+        elif decoding_chunk_size > 0:
+            chunk_size = decoding_chunk_size
+            num_left_chunks = num_decoding_left_chunks
+        else:
+            # chunk size is either [1, 25] or full context (max_len).
+            # Since we use 4x subsampling and allow up to 1 s (100 frames) of
+            # delay, the maximum number of frames per chunk is 100 / 4 = 25.
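+            # The sampling below draws chunk_size uniformly from
+            # [1, max_len); roughly half of the draws fall above
+            # max_len // 2 and are mapped to full context, the rest are
+            # folded into [1, 25] via the modulo.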
+            chunk_size = torch.randint(1, max_len, (1, )).item()
+            num_left_chunks = -1
+            if chunk_size > max_len // 2:
+                chunk_size = max_len
+            else:
+                chunk_size = chunk_size % 25 + 1
+                if use_dynamic_left_chunk:
+                    max_left_chunks = (max_len - 1) // chunk_size
+                    num_left_chunks = torch.randint(0, max_left_chunks,
+                                                    (1, )).item()
+        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+                                            num_left_chunks,
+                                            xs.device)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    elif static_chunk_size > 0:
+        num_left_chunks = num_decoding_left_chunks
+        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+                                            num_left_chunks,
+                                            xs.device)  # (L, L)
+        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
+        chunk_masks = masks & chunk_masks  # (B, L, L)
+    else:
+        chunk_masks = masks
+    return chunk_masks
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+    """Make mask tensor containing indices of padded part.
+
+    See description of make_non_pad_mask.
+
+    Args:
+        lengths (torch.Tensor): Batch of lengths (B,).
+    Returns:
+        torch.Tensor: Mask tensor containing indices of padded part.
+
+    Examples:
+        >>> lengths = [5, 3, 2]
+        >>> make_pad_mask(lengths)
+        masks = [[0, 0, 0, 0, 0],
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+    """
+    batch_size = lengths.size(0)
+    max_len = max_len if max_len > 0 else lengths.max().item()
+    seq_range = torch.arange(0,
+                             max_len,
+                             dtype=torch.int64,
+                             device=lengths.device)
+    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+    seq_length_expand = lengths.unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+    return mask
+
+
+def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
+    """Make mask tensor containing indices of non-padded part.
+
+    The sequences in a batch may have different lengths. To enable batch
+    computing, padding is needed to make all sequences the same size. To keep
+    the padded part from passing values into context-dependent blocks such as
+    attention or convolution, the padded part is masked.
+
+    This pad_mask is used in both the encoder and the decoder.
+
+    1 for the non-padded part and 0 for the padded part.
+
+    Args:
+        lengths (torch.Tensor): Batch of lengths (B,).
+    Returns:
+        torch.Tensor: mask tensor containing indices of non-padded part.
+
+    Examples:
+        >>> lengths = [5, 3, 2]
+        >>> make_non_pad_mask(lengths)
+        masks = [[1, 1, 1, 1, 1],
+                 [1, 1, 1, 0, 0],
+                 [1, 1, 0, 0, 0]]
+    """
+    return ~make_pad_mask(lengths)
+
+
+def mask_finished_scores(score: torch.Tensor,
+                         flag: torch.Tensor) -> torch.Tensor:
+    """
+    If a sequence is finished, we only allow one alive branch. This function
+    aims to give one branch a zero score and the rest -inf score.
+
+    Args:
+        score (torch.Tensor): A real value array with shape
+            (batch_size * beam_size, beam_size).
+        flag (torch.Tensor): A bool array with shape
+            (batch_size * beam_size, 1).
+
+    Returns:
+        torch.Tensor: (batch_size * beam_size, beam_size).
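+
+    Examples (illustrative, beam_size = 2, first sequence finished):
+        >>> score = torch.tensor([[0.1, 0.2], [0.3, 0.4]])
+        >>> flag = torch.tensor([[True], [False]])
+        >>> mask_finished_scores(score, flag)
+        [[0.0, -inf],
+         [0.3, 0.4]]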
+    """
+    beam_size = score.size(-1)
+    zero_mask = torch.zeros_like(flag, dtype=torch.bool)
+    if beam_size > 1:
+        unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),
+                               dim=1)
+        finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),
+                             dim=1)
+    else:
+        unfinished = zero_mask
+        finished = flag
+    score.masked_fill_(unfinished, -float('inf'))
+    score.masked_fill_(finished, 0)
+    return score
+
+
+def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor,
+                        eos: int) -> torch.Tensor:
+    """
+    If a sequence is finished, all of its branches should be set to <eos>.
+
+    Args:
+        pred (torch.Tensor): An int array with shape
+            (batch_size * beam_size, beam_size).
+        flag (torch.Tensor): A bool array with shape
+            (batch_size * beam_size, 1).
+
+    Returns:
+        torch.Tensor: (batch_size * beam_size, beam_size).
+    """
+    beam_size = pred.size(-1)
+    finished = flag.repeat([1, beam_size])
+    return pred.masked_fill_(finished, eos)
diff --git a/speech/speech_recognition/transformer/pytorch/wenet/utils/scheduler.py b/speech/speech_recognition/transformer/pytorch/wenet/utils/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..0031b08a6b4c3388bf5d573f132343b6076d6017
--- /dev/null
+++ b/speech/speech_recognition/transformer/pytorch/wenet/utils/scheduler.py
@@ -0,0 +1,52 @@
+from typing import Union
+
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+
+from typeguard import check_argument_types
+
+
+class WarmupLR(_LRScheduler):
+    """The WarmupLR scheduler
+
+    This scheduler is almost the same as the NoamLR scheduler except for the
+    following difference:
+
+    NoamLR:
+        lr = optimizer.lr * model_size ** -0.5
+             * min(step ** -0.5, step * warmup_step ** -1.5)
+    WarmupLR:
+        lr = optimizer.lr * warmup_step ** 0.5
+             * min(step ** -0.5, step * warmup_step ** -1.5)
+
+    Note that the maximum lr equals optimizer.lr in this scheduler.
+
+    """
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        warmup_steps: Union[int, float] = 25000,
+        last_epoch: int = -1,
+    ):
+        assert check_argument_types()
+        self.warmup_steps = warmup_steps
+
+        # __init__() must be invoked before setting field
+        # because step() is also invoked in __init__()
+        super().__init__(optimizer, last_epoch)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
+
+    def get_lr(self):
+        step_num = self.last_epoch + 1
+        return [
+            lr
+            * self.warmup_steps ** 0.5
+            * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)
+            for lr in self.base_lrs
+        ]
+
+    def set_step(self, step: int):
+        self.last_epoch = step
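+
+
+# Example usage (illustrative sketch, mirroring conf/train_transformer.yaml
+# and the per-step update order in executor.py):
+#
+#   optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
+#   scheduler = WarmupLR(optimizer, warmup_steps=25000)
+#   for batch in data_loader:
+#       loss = compute_loss(batch)   # hypothetical helper
+#       loss.backward()
+#       optimizer.step()
+#       optimizer.zero_grad()
+#       scheduler.step()             # advance the warmup schedule per step
+#
+# When resuming from a checkpoint, scheduler.set_step(saved_step) restores
+# the schedule position before training continues.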