diff --git a/research/audio/gdprnn/README_CN.md b/research/audio/gdprnn/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..f13cd5f0f02260114c49722039022dde60716b87 --- /dev/null +++ b/research/audio/gdprnn/README_CN.md @@ -0,0 +1,262 @@ +# 目录 + + + +- [目录](#目录) +- [gdprnn介绍](#gdprnn介绍) +- [模型架构](#模型架构) +- [数据集](#数据集) +- [环境要求](#环境要求) +- [脚本说明](#脚本说明) + - [脚本及样例代码](#脚本及样例代码) + - [脚本参数](#脚本参数) + - [数据预处理过程](#数据预处理过程) + - [数据预处理](#数据预处理) + - [训练过程](#训练过程) + - [训练](#训练) + - [评估过程](#评估过程) + - [评估](#评估) + - [导出mindir模型](#导出mindir模型) + - [导出](#导出) + - [推理过程](#推理过程) + - [推理](#推理) +- [模型描述](#模型描述) + - [性能](#性能) + - [训练性能](#训练性能) + - [推理性能](#推理性能) +- [随机情况说明](#随机情况说明) +- [ModelZoo主页](#modelzoo主页) + +# gdprnn介绍 + +gdprnn由三个处理模块组成,编码器、分离模块和解码器。首先,编码器模块用于将混合波形的短段转换为它们在中间特征空间中的对应表示。然后,该表示用于主模型进行分离。最后,利用解码器模块对主模型输出结果重构,得到源波形。gdprnn被广泛地应用在语音分离等任务上,取得了显著的效果。 + +[论文](https://arxiv.org/pdf/2003.01531.pdf): gdprnn: Voice Separation with an Unknown Number of Multiple Speakers + +# 模型架构 + +模型包括 +encoder:类似fft,提取语音特征。 +decoder:类似ifft,得到语音波形。 +separation:对语音进行分离,得到单个语音的一个语谱图,通过decoder还原出语音波形。 + +# 数据集 + +使用的数据集为: [librimix](https://github.com/JorisCos/LibriMix),LibriMix 是一个开源数据集,用于在嘈杂环境中进行语音源分离。 +要生成 LibriMix,请参照开源项目:https://github.com/JorisCos/LibriMix + +# 环境要求 + +- 硬件(ASCEND) + - ASCEND处理器 +- 框架 + - [MindSpore](https://www.mindspore.cn/install/en) +- 通过下面网址可以获得更多信息: + - [MindSpore tutorials](https://www.mindspore.cn/tutorials/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/docs/zh-CN/master/index.html) +- 依赖 + - 见requirements.txt文件,使用方法如下: + +```bash +pip install -r requirements.txt +``` + +# 脚本说明 + +## 脚本及样例代码 + +```path +gdprnn +├─ README.md # descriptions +├── scripts + ├─ run_distribute_train.sh # launch ascend training(8 pcs) + └─ run_standalone_train.sh # launch ascend training(1 pcs) + └─ run_infer_310.sh # launch infer 310 + └─ run_eval.sh # launch ascend eval +├── src + ├── data + ├─ data.py # postprocess data + └─ preprocess.py # preprocess json + ├── 
models + ├─ loss.py # loss function + └─ swave.py # module + ├─ generatorloss.py # generate loss + ├─ network_define.py # define network + └─ trainonestep.py # trainonestepcell +├─ train.py # train +├─ evaluate.py # eval +├─ export.py # export mindir script +├─ model.py # dptnet_tasnet +├─ preprocess.py # preprocess of 310 +├─ postprocess.py # postprocess of 310 +├─ requirements.txt # requirements +``` + +## 脚本参数 + +数据预处理、训练、评估的相关参数在`train.py`等文件 + +```text +数据预处理相关参数 +in-dir 预处理前加载原始数据集目录 +out-dir 预处理后的json文件的目录 +sample-rate 采样率 +train_name 预处理后的训练MindRecord文件的名称 +test_name 预处理后的测试MindRecord文件的名称 +``` + +```text +训练和模型相关参数 +train_dir 训练集 +valid_dir 测试集 +segment 取得音频的长度 +sr 采样率 +N 输入通道数 +L 卷积核大小 +H 分离模块卷积块通道数 +R 分离层中重复次数 +C 说话者数量 +lr 学习率 +``` + +```text +评估相关参数 +model_path ckpt文件 +data_dir 测试集路径 +batch_size 测试集batch大小 +``` + +```text +配置相关参数 +device_target 硬件,只支持ASCEND +device_id 设备号 +``` + +# 数据预处理过程 + +## 数据预处理 + +数据预处理运行示例: + +```text +python preprocess.py +``` + +数据预处理过程很快,大约需要三分钟时间 + +# 训练过程 + +## 训练 + +- ### 单卡训练 + +运行示例: + +```text +python train.py +参数: +train_dir 训练集 +valid_dir 测试集 +segment 取得音频的长度 +sr 采样率 +N 输入通道数 +L 卷积核大小 +H 分离模块卷积块通道数 +R 分离层中重复次数 +C 说话者数量 +lr 学习率 +``` + +或者可以运行脚本: + +```bash +bash run_standalone_train.sh [DEVICE_ID] [SAVE_FOLDER] +``` + +上述命令将在后台运行,可以通过train.log查看结果 +每个epoch将运行10小时左右 + +- ### 分布式训练 + +分布式训练脚本如下 + +```bash +bash run_distribute_train.sh 8 1 /home/hccl_2p_01_127.0.0.1.json /home/heu_MEDAI/test/gdprnn/scripts/output +``` + +# 评估过程 + +## 评估 + +运行示例: + +```text +python evaluate.py +参数: +model_path ckpt文件 +data_dir 测试集路径 +batch_size 测试集batch大小 +``` + +或者可以运行脚本: + +```bash +bash run_eval.sh [DEVICE_ID] [MODEL_PATH] [DATA_DIR] +``` + +上述命令在后台运行,可以通过eval.log查看结果,测试结果如下 + +# 导出mindir模型 + +## 导出 + +```bash +python export.py +``` + +# 推理过程 + +## 推理 + +### 用法 + +```bash +./scripts/run_infer_310.sh [MINDIR_PATH] [TEST_PATH] [NEED_PREPROCESS] +``` + +### 结果 + +```text +Average SISNR improvement: 11.28 +``` + +# 模型描述 + +## 性能 + +### 训练性能 +
# Build script for the Ascend 310 inference executable `swave`.
cmake_minimum_required(VERSION 3.14.1)
project(310infer)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fPIE -Wl,--allow-shlib-undefined")

# MINDSPORE_PATH is a directory, not an ON/OFF switch, so it is declared as a
# PATH cache entry instead of option() (which only models booleans). A value
# passed with -DMINDSPORE_PATH=... on the command line still takes precedence.
set(MINDSPORE_PATH "" CACHE PATH "mindspore install path")
include_directories(${MINDSPORE_PATH})
include_directories(${MINDSPORE_PATH}/include)
find_library(MS_LIB libmindspore.so ${MINDSPORE_PATH}/lib)
if(NOT MS_LIB)
    # Fail at configure time instead of producing an obscure link error later.
    message(FATAL_ERROR "libmindspore.so not found under ${MINDSPORE_PATH}/lib; set -DMINDSPORE_PATH")
endif()
file(GLOB_RECURSE MD_LIB ${MINDSPORE_PATH}/_c_dataengine*)

add_executable(swave main.cc)
target_link_libraries(swave ${MS_LIB} ${MD_LIB})
# Configure and build the Ascend 310 inference binary inside ./build.
# The MindSpore install prefix is derived from the pip metadata of the
# mindspore-ascend package.
mkdir -p build
cd build || exit 1

MINDSPORE_PATH="$(pip show mindspore-ascend | grep Location | awk '{print $2"/mindspore"}' | xargs realpath)"

# Stop here when configuration fails; running make on a half-configured tree
# only hides the real error.
cmake .. -DMINDSPORE_PATH="${MINDSPORE_PATH}" || exit 1
make
// Returns the current CLOCK_REALTIME reading in microseconds.
// If clock_gettime() fails the timespec keeps its zero initialisation, so the
// function returns 0 in that case.
uint64_t GetTimeMicroSeconds() {
  struct timespec now = {0, 0};
  clock_gettime(CLOCK_REALTIME, &now);
  const uint64_t seconds_part = static_cast<uint64_t>(now.tv_sec) * 1000000ULL;
  return seconds_part + now.tv_nsec / 1000L;
}
<< std::endl; + return 1; + } + + std::vector feats = GetAllFiles(data_path); + uint64_t Time1 = GetTimeMicroSeconds(); + for (const auto &feat_file : feats) { + // prepare input + std::vector outputs; + std::vector inputs; + + // read image file and preprocess + auto feat = ReadFile(feat_file); + + inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(), + feat.Data().get(), feat.DataSize()); + ret = swave.Predict(inputs, &outputs); + if (ret != ms::kSuccess) { + std::cout << "Predict model failed." << std::endl; + return 1; + } + int ret1 = WriteResult(feat_file, outputs); + if (ret1 != 0) { + std::cout << "write result failed." << std::endl; + return ret1; + } + } + uint64_t end = GetTimeMicroSeconds(); + printf("The total run time is: %f ms \n", static_cast(end - Time1) / 1000); + return 0; +} + +std::vector GetAllFiles(std::string_view dir_name) { + struct dirent *filename; + DIR *dir = OpenDir(dir_name); + if (dir == nullptr) { + return {}; + } + + /* read all the files in the dir ~ */ + std::vector res; + while ((filename = readdir(dir)) != nullptr) { + std::string d_name = std::string(filename->d_name); + // get rid of "." and ".." + if (d_name == "." || d_name == ".." || filename->d_type != DT_REG) + continue; + res.emplace_back(std::string(dir_name) + "/" + filename->d_name); + } + + std::sort(res.begin(), res.end()); + return res; +} + +DIR *OpenDir(std::string_view dir_name) { + // check the parameter ! + if (dir_name.empty()) { + std::cout << " dir_name is null ! " << std::endl; + return nullptr; + } + + std::string real_path = RealPath(dir_name); + + // check if dir_name is a valid dir + struct stat s; + lstat(real_path.c_str(), &s); + if (!S_ISDIR(s.st_mode)) { + std::cout << "dir_name is not a valid directory !" 
// Resolves `path` to a canonical absolute path with realpath(3).
//
// Returns the resolved path, or an empty string (after printing a message)
// when the path cannot be resolved.
std::string RealPath(std::string_view path) {
  char real_path_mem[PATH_MAX] = {0};
  // A string_view is not guaranteed to be NUL-terminated, while realpath()
  // needs a C string, so the view is copied into an owning std::string first.
  const std::string path_str(path);
  if (realpath(path_str.c_str(), real_path_mem) == nullptr) {
    std::cout << "File: " << path << " is not exist." << std::endl;
    return "";
  }
  return std::string(real_path_mem);
}
fwrite(netOutput.get(), sizeof(char), outputSize, outputFile); + if (size != outputSize) { + fclose(outputFile); + outputFile = nullptr; + std::cout << "write result file " << outFileName << " failed, write size[" << size << + "] is smaller than output size[" << outputSize << "], maybe the disk is full." << std::endl; + return ERROR; + } + fclose(outputFile); + outputFile = nullptr; + } + return 0; +} diff --git a/research/audio/gdprnn/evaluate.py b/research/audio/gdprnn/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..b748b4f290531bb5b0bf322d7e3826fceee832a2 --- /dev/null +++ b/research/audio/gdprnn/evaluate.py @@ -0,0 +1,188 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def evaluate(args):
    """Evaluate a trained SWave checkpoint and print the average SI-SNRi (and SDRi).

    Args:
        args: parsed command-line namespace. Reads N, L, H, R, C, sr, segment,
            model_path, data_dir, batch_size, sample_rate and cal_sdr.
    """
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0

    # Load model
    model = SWave(args.N, args.L, args.H, args.R, args.C, args.sr, args.segment, input_normalize=False)
    model.set_train(mode=False)
    param_dict = load_checkpoint(args.model_path)
    load_param_into_net(model, param_dict)

    # Load data
    tt_dataset = DatasetGenerator(args.data_dir, args.batch_size,
                                  sample_rate=args.sample_rate, segment=args.segment)
    tt_loader = ds.GeneratorDataset(tt_dataset, ["mixture", "lens", "sources"], shuffle=False)
    # NOTE(review): the batch size is fixed at 4 here while args.batch_size is
    # only forwarded to DatasetGenerator -- confirm whether that is intended.
    tt_loader = tt_loader.batch(batch_size=4)

    # Built once here instead of once per batch.
    cast = ops.Cast()
    my_loss = MyLoss()

    for data in tt_loader.create_dict_iterator():
        mixture_lengths = data["lens"]
        padded_mixture = cast(data["mixture"], mindspore.float32)
        padded_source = cast(data["sources"], mindspore.float32)
        mixture_lengths_with_list = mixture_lengths.asnumpy().tolist()
        estimate_source = model(padded_mixture)[-1]  # [B, C, T]
        _, _, reorder_estimate_source = \
            my_loss(padded_source, estimate_source, mixture_lengths)
        # Remove padding and flat
        mixture = remove_pad(padded_mixture, mixture_lengths_with_list)
        source = remove_pad(padded_source, mixture_lengths_with_list)
        # NOTE: use reorder estimate source
        estimate_source = remove_pad(reorder_estimate_source,
                                     mixture_lengths_with_list)
        # for each utterance
        for mix, src_ref, src_est in zip(mixture, source, estimate_source):
            print("Utt", total_cnt + 1)
            # Compute SDRi
            if args.cal_sdr:
                avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                total_SDRi += avg_SDRi
                print("\tSDRi={0:.2f}".format(avg_SDRi))
            # Compute SI-SNRi
            avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
            print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
            total_SISNRi += avg_SISNRi
            total_cnt += 1

    if total_cnt == 0:
        # An empty dataset used to end in a ZeroDivisionError below.
        print("No utterances were evaluated; check --data_dir.")
        return
    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi / total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi / total_cnt))
def cal_SISNRi(src_ref, src_est, mix):
    """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi).

    The improvement of one source is SI-SNR(reference, estimate) minus
    SI-SNR(reference, mixture); the result is the mean over all sources.
    The original averaged exactly two sources; any number C >= 1 is
    accepted now and C == 2 gives the same value as before.

    Args:
        src_ref: numpy.ndarray, [C, T]
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T]
    Returns:
        average_SISNRi
    """
    improvements = [cal_SISNR(ref, est) - cal_SISNR(ref, mix)
                    for ref, est in zip(src_ref, src_est)]
    return sum(improvements) / len(improvements)


def cal_SISNR(ref_sig, out_sig, eps=1e-8):
    """Calculate Scale-Invariant Source-to-Noise Ratio (SI-SNR) in dB.

    Args:
        ref_sig: numpy.ndarray, [T]
        out_sig: numpy.ndarray, [T]
        eps: small constant keeping the divisions and the logarithm finite
    Returns:
        SISNR
    """
    assert len(ref_sig) == len(out_sig)
    # Both signals are made zero-mean before projecting.
    ref_sig = ref_sig - np.mean(ref_sig)
    out_sig = out_sig - np.mean(out_sig)
    ref_energy = np.sum(ref_sig ** 2) + eps
    # Component of the estimate that lies along the reference ...
    proj = np.sum(ref_sig * out_sig) * ref_sig / ref_energy
    # ... and everything else, which counts as noise.
    noise = out_sig - proj
    ratio = np.sum(proj ** 2) / (np.sum(noise ** 2) + eps)
    return 10 * np.log10(ratio + eps)
== 2: # [B, T] + results.append(input1[:inputs_lengths[i]].view(-1).asnumpy()) + return results + +if __name__ == '__main__': + arg = parser.parse_args() + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", device_id=arg.device_id) + evaluate(arg) diff --git a/research/audio/gdprnn/export.py b/research/audio/gdprnn/export.py new file mode 100644 index 0000000000000000000000000000000000000000..cf047b6ab8aa03d9ee1060b14cc79bcf137475e1 --- /dev/null +++ b/research/audio/gdprnn/export.py @@ -0,0 +1,57 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def export_gdprnn():
    """Restore the checkpoint into SWave and export the network as a MINDIR file."""
    cfg = parser.parse_args()
    network = SWave(cfg.N, cfg.L, cfg.H, cfg.R, cfg.C, cfg.sr, cfg.segment, input_normalize=False)
    load_param_into_net(network, load_checkpoint(cfg.ckpt_path))

    # Random waveform of shape [1, 32000], passed to export() as the example input.
    dummy_wave = np.random.uniform(0.0, 1.0, size=[1, 32000]).astype(np.float32)
    export(network, Tensor(dummy_wave), file_name='SWave', file_format='MINDIR')
    print("export success")
def evaluate(args, list1):
    """Score pre-computed network outputs against the references.

    Args:
        args: parsed command-line namespace. Reads test_dir, batch_size,
            sample_rate, segment and cal_sdr.
        list1: list of estimated-source tensors, one per batch of the loader,
            in loader order.
    """
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0

    # Load data
    tt_dataset = DatasetGenerator(args.test_dir, args.batch_size,
                                  sample_rate=args.sample_rate, segment=args.segment)
    tt_loader = ds.GeneratorDataset(tt_dataset, ["mixture", "lens", "sources"], shuffle=False)
    tt_loader = tt_loader.batch(batch_size=1)

    # Built once here instead of once per batch.
    cast = ops.Cast()

    for i, data in enumerate(tt_loader.create_dict_iterator()):
        mixture_lengths = data["lens"]
        padded_mixture = cast(data["mixture"], mindspore.float32)
        padded_source = cast(data["sources"], mindspore.float32)
        mixture_lengths_with_list = mixture_lengths.asnumpy().tolist()
        # The i-th stored output belongs to the i-th batch.
        estimate_source = list1[i]

        _, _, reorder_estimate_source = \
            cal_loss(padded_source, estimate_source, mixture_lengths)
        mixture = remove_pad(padded_mixture, mixture_lengths_with_list)
        source = remove_pad(padded_source, mixture_lengths_with_list)
        # NOTE: use reorder estimate source
        estimate_source = remove_pad(reorder_estimate_source,
                                     mixture_lengths_with_list)
        # for each utterance
        for mix, src_ref, src_est in zip(mixture, source, estimate_source):
            print("Utt", total_cnt + 1)
            # Compute SDRi
            if args.cal_sdr:
                avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                total_SDRi += avg_SDRi
                print("\tSDRi={0:.2f}".format(avg_SDRi))
            # Compute SI-SNRi
            avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
            print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
            total_SISNRi += avg_SISNRi
            total_cnt += 1

    if total_cnt == 0:
        # An empty dataset used to end in a ZeroDivisionError below.
        print("No utterances were evaluated; check --test_dir.")
        return
    if args.cal_sdr:
        print("Average SDR improvement: {0:.2f}".format(total_SDRi / total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi / total_cnt))
def cal_SISNRi(src_ref, src_est, mix):
    """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi).

    The improvement of one source is SI-SNR(reference, estimate) minus
    SI-SNR(reference, mixture); the result is the mean over all sources.
    The original averaged exactly two sources; any number C >= 1 is
    accepted now and C == 2 gives the same value as before.

    Args:
        src_ref: numpy.ndarray, [C, T]
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T]
    Returns:
        average_SISNRi
    """
    improvements = [cal_SISNR(ref, est) - cal_SISNR(ref, mix)
                    for ref, est in zip(src_ref, src_est)]
    return sum(improvements) / len(improvements)


def cal_SISNR(ref_sig, out_sig, eps=1e-8):
    """Calculate Scale-Invariant Source-to-Noise Ratio (SI-SNR) in dB.

    Args:
        ref_sig: numpy.ndarray, [T]
        out_sig: numpy.ndarray, [T]
        eps: small constant keeping the divisions and the logarithm finite
    Returns:
        SISNR
    """
    assert len(ref_sig) == len(out_sig)
    # Both signals are made zero-mean before projecting.
    ref_sig = ref_sig - np.mean(ref_sig)
    out_sig = out_sig - np.mean(out_sig)
    ref_energy = np.sum(ref_sig ** 2) + eps
    # Component of the estimate that lies along the reference ...
    proj = np.sum(ref_sig * out_sig) * ref_sig / ref_energy
    # ... and everything else, which counts as noise.
    noise = out_sig - proj
    ratio = np.sum(proj ** 2) / (np.sum(noise ** 2) + eps)
    return 10 * np.log10(ratio + eps)
def cal_si_snr_with_pit(source, estimate_source_0, source_lengths):
    """Calculate SI-SNR with PIT training.

    NOTE: the permutation table below is written out for two speakers only.

    Args:
        source: [B, C, T], B is batch size
        estimate_source_0: [B, C, T]
        source_lengths: [B], each item is between [0, T]
    Returns:
        max_snr: best summed SI-SNR per utterance divided by C (keep_dims=True)
        perms: [C!, C], the candidate permutations
        max_snr_idx: [B], index into perms of the best permutation
    """
    B, C, _ = source.shape
    # mask padding position along T
    mask = get_mask(source, source_lengths)
    estimate_source_1 = estimate_source_0 * mask

    # Step 1. Zero-mean norm
    num_samples = source_lengths.view(-1, 1, 1).astype(mindspore.float32)
    ops_sum = ops.ReduceSum(keep_dims=True)
    mean_target = ops_sum(source, 2) / num_samples
    mean_estimate = ops_sum(estimate_source_1, 2) / num_samples
    zero_mean_target_0 = source - mean_target
    zero_mean_estimate_0 = estimate_source_1 - mean_estimate
    # mask padding position along T
    zero_mean_target = zero_mean_target_0 * mask
    zero_mean_estimate = zero_mean_estimate_0 * mask

    # Step 2. SI-SNR with PIT
    # reshape to use broadcast
    expand_dims_0 = ops.ExpandDims()
    s_target = expand_dims_0(zero_mean_target, 1)  # [B, 1, C, T]
    s_estimate = expand_dims_0(zero_mean_estimate, 2)  # [B, C, 1, T]
    # s_target = <s', s>s / ||s||^2
    pair_wise_dot_0 = ops_sum(s_estimate * s_target, 3)  # [B, C, C, 1]
    s_target_energy_0 = ops_sum(s_target * s_target, 3) + EPS  # [B, 1, C, 1]
    pair_wise_proj = pair_wise_dot_0 * s_target / s_target_energy_0  # [B, C, C, T]
    # e_noise = s' - s_target
    e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
    # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    _sum = ops.ReduceSum(keep_dims=False)
    pair_wise_si_snr_0 = _sum(pair_wise_proj * pair_wise_proj, 3) / (_sum(e_noise * e_noise, 3) + EPS)
    log = ops.Log()
    log_10 = Tensor(np.array([10.0]), mindspore.float32)
    # ln(x) / (ln(10) / 10) == 10 * log10(x)
    temp = log(log_10) / 10
    pair_wise_si_snr_1 = log(pair_wise_si_snr_0 + EPS)
    pair_wise_si_snr = pair_wise_si_snr_1 / temp  # [B, C, C]

    # Get max_snr of each utterance
    # permutations, [C!, C]
    perms = Tensor(list(permutations(range(2))), dtype=mindspore.int64)
    # Rows follow the flattened (estimate, target) pairs (0,0), (0,1), (1,0),
    # (1,1); column k sums the pairs that make up permutation k.
    perms_one_hot = Tensor(np.array([[1, 0], [0, 1], [0, 1], [1, 0]]), mindspore.float32)
    matmul = ops.MatMul()
    snr_set = matmul(pair_wise_si_snr.view(B, -1), perms_one_hot)
    Argmax = ops.Argmax(axis=1, output_type=mindspore.int32)
    max_snr_idx = Argmax(snr_set)  # [B]
    argmax = ops.ArgMaxWithValue(axis=1, keep_dims=True)
    _, max_snr = argmax(snr_set)
    max_snr /= C
    return max_snr, perms, max_snr_idx
+ Returns: + reorder_source: [B, C, T] + """ + B, C, _ = source.shape + # [B, C], permutation whose SI-SNR is max of each utterance + # for each utterance, reorder estimate source according this permutation + max_snr_perm = perms[max_snr_idx, :] + zeros_like = ops.ZerosLike() + reorder_sources = zeros_like(source) + for b in range(B): + for c in range(C): + if max_snr_perm[b][c] == 1: + reorder_sources[b, c] = source[b, 1] + else: + reorder_sources[b, c] = source[b, 0] + return reorder_sources + + +def get_mask(source, source_lengths): + """ + Args: + source: [B, C, T] + source_lengths: [B] + Returns: + mask: [B, 1, T] + """ + B, _, T = source.shape + ones = ops.Ones() + mask = ones((B, 1, T), mindspore.float32) + for i in range(B): + mask[i, :, 32000:] = 0 + return mask + +if __name__ == "__main__": + arg = parser.parse_args() + dataset_path = arg.bin_path + context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") + audio_files = os.listdir(dataset_path) + audio_files = sorted(audio_files, key=lambda x: int(os.path.splitext(x)[0])) + list0 = [] + for f in audio_files: + f_name = os.path.join(dataset_path, f.split('.')[0] + '.bin') + try: + logits = np.fromfile(f_name, np.float32).reshape(1, 2, 32000) + except ValueError as e: + logits = np.fromfile(f_name, np.float32).reshape(6, 1, 2, 32000)[-1] + logits = Tensor(logits) + list0.append(logits) + evaluate(arg, list0) diff --git a/research/audio/gdprnn/preprocess.py b/research/audio/gdprnn/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..91693a037537c40b7ab2f0ae766e646e1034b029 --- /dev/null +++ b/research/audio/gdprnn/preprocess.py @@ -0,0 +1,201 @@ +#2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# Command-line options used when this module is run as a script to dump the
# test mixtures as raw .bin files for Ascend 310 inference.
parser = argparse.ArgumentParser(
    "Dual-path transformer"
    "with Permutation Invariant Training")
parser.add_argument('--test_dir', type=str, default='/mass_data/dataset/LS-2mix/Libri2Mix/tt/',
                    help='directory including mix.json, s1.json and s2.json')
parser.add_argument('--out_dir', type=str, default='/cv',
                    help='directory where the preprocessed .bin files are written')
parser.add_argument('--batch_size', default=2, type=int,
                    help='Batch size')
parser.add_argument('--sample_rate', default=8000, type=int,
                    help='Sample rate')
parser.add_argument('--segment', default=4, type=float,
                    help='Segment length (seconds)')


def load_mixtures_and_sources(batch):
    """Load the wav files of one mini-batch and cut them into segments.

    Args:
        batch: ``[mix_infos, s1_infos, s2_infos, sample_rate, segment_len]``;
            every ``*_infos`` item is ``(wav_path, num_samples)``.

    Returns:
        mixtures: list of 1-D arrays, one per segment.
        sources: list of ``T x C`` arrays aligned with ``mixtures``.
    """
    mixtures, sources = [], []
    mix_infos, s1_infos, s2_infos, sample_rate, segment_len = batch
    for mix_info, s1_info, s2_info in zip(mix_infos, s1_infos, s2_infos):
        mix_path, s1_path, s2_path = mix_info[0], s1_info[0], s2_info[0]
        # The mixture and both sources must have the same number of samples.
        assert mix_info[1] == s1_info[1] and s1_info[1] == s2_info[1]
        mix, _ = librosa.load(mix_path, sr=sample_rate)
        s1, _ = librosa.load(s1_path, sr=sample_rate)
        s2, _ = librosa.load(s2_path, sr=sample_rate)
        # Stack the two sources column-wise: [T, 2].
        s = np.dstack((s1, s2))[0]
        utt_len = mix.shape[-1]
        if segment_len >= 0:
            # Non-overlapping full segments ...
            for i in range(0, utt_len - segment_len + 1, segment_len):
                mixtures.append(mix[i:i + segment_len])
                sources.append(s[i:i + segment_len])
            # ... plus one segment aligned to the end when there is a remainder.
            if utt_len % segment_len != 0:
                mixtures.append(mix[-segment_len:])
                sources.append(s[-segment_len:])
        else:
            # Negative segment length: keep the full utterance.
            mixtures.append(mix)
            sources.append(s)
    return mixtures, sources


def pad_list(xs):
    """Zero-pad arrays along their first axis and stack them into one float32 array.

    Args:
        xs: non-empty list of arrays that share every dimension except the first.

    Returns:
        Array of shape ``(len(xs), max_len, *trailing_dims)``.
    """
    n_batch = len(xs)
    # Only the first axis varies; take the trailing dims from the first item
    # instead of relying on a lexicographic max over whole shape tuples.
    max_len = max(x.shape[0] for x in xs)
    pad = np.zeros((n_batch, max_len) + tuple(xs[0].shape[1:]), np.float32)
    for h, x in enumerate(xs):
        pad[h, :x.shape[0]] = x
    return pad


class DatasetGenerator:
    """In-memory dataset of (mixture, length, sources) segments."""

    def __init__(self, json_dir, batch_size, sample_rate=8000, segment=4.0, cv_maxlen=8.0):
        """Read the json indexes, then load and segment every usable utterance.

        Args:
            json_dir: directory including mix.json, s1.json and s2.json
            batch_size: maximum number of segments loaded per group
            sample_rate: sample rate passed to the wav loader
            segment: duration of audio segment, when set to -1, use full audio
            cv_maxlen: unused, kept for interface compatibility

        xxx_infos is a list and each item is a tuple (wav_file, #samples)
        """
        super(DatasetGenerator, self).__init__()

        def load_sorted(name):
            # Longest utterances first (bucketing by #samples).
            with open(os.path.join(json_dir, name + '.json'), 'r') as f:
                infos = json.load(f)
            return sorted(infos, key=lambda info: int(info[1]), reverse=True)

        sorted_mix_infos = load_sorted('mix')
        sorted_s1_infos = load_sorted('s1')
        sorted_s2_infos = load_sorted('s2')

        # Segment length in samples, e.g. 4 s * 8000 Hz = 32000.
        segment_len = int(segment * sample_rate)
        drop_utt, drop_len = 0, 0
        for _, sample in sorted_mix_infos:
            if sample < segment_len:
                drop_utt += 1
                drop_len += sample
        # samples / sample_rate = seconds; seconds / 3600 = hours.
        print("Drop {} utts({:.2f} h) which is short than {} samples".format(
            drop_utt, drop_len / sample_rate / 3600, segment_len))

        mixture_pad = []
        lens1 = []
        source_pad = []
        start = 0
        while True:
            num_segments = 0
            end = start
            part_mix, part_s1, part_s2 = [], [], []
            while num_segments < batch_size and end < len(sorted_mix_infos):
                utt_len = int(sorted_mix_infos[end][1])
                if utt_len >= segment_len:  # skip too short utt
                    num_segments += math.ceil(utt_len / segment_len)
                    # Ensure num_segments is not larger than batch_size
                    if num_segments > batch_size:
                        # if the 1st audio alone exceeds batch_size, skip it
                        if start == end:
                            end += 1
                        break
                    part_mix.append(sorted_mix_infos[end])
                    part_s1.append(sorted_s1_infos[end])
                    part_s2.append(sorted_s2_infos[end])
                end += 1
            if part_mix:
                meta = [part_mix, part_s1, part_s2, sample_rate, segment_len]
                mixtures_pad, ilens, sources_pad = self.sort_and_pad(meta)
                mixture_pad.extend(mixtures_pad)
                lens1.extend(ilens)
                source_pad.extend(sources_pad)
            if end == len(sorted_mix_infos):
                break
            start = end
        self.mixture = mixture_pad
        self.len = lens1
        self.sources = source_pad

    def __getitem__(self, index):
        return (self.mixture[index], self.len[index], self.sources[index])

    def __len__(self):
        return len(self.mixture)

    def sort_and_pad(self, batch):
        """Load one group of utterances and pad the segments to a common length.

        Returns:
            mixtures_pad: [N, T] float32
            ilens: [N] int32 segment lengths
            sources_pad: [N, C, T] float32
        """
        mixtures, sources = load_mixtures_and_sources(batch)

        # int32 rather than int16: a full-length utterance can exceed 32767 samples.
        ilens = np.array([mix.shape[0] for mix in mixtures]).astype(np.int32)

        mixtures_pad = pad_list(mixtures)
        sources_pad = pad_list(sources)
        # [N, T, C] -> [N, C, T]
        sources_pad = sources_pad.transpose((0, 2, 1))
        return mixtures_pad, ilens, sources_pad


if __name__ == "__main__":
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", device_id=0)
    args = parser.parse_args()
    print(args)
    tr_dataset = DatasetGenerator(args.test_dir, args.batch_size,
                                  sample_rate=args.sample_rate, segment=args.segment)
    output_path = args.out_dir
    os.makedirs(output_path, exist_ok=True)
    # One raw float32 file per segment, numbered from 1.
    for j, (mixture, _, _) in enumerate(tr_dataset, start=1):
        # Join directory and file name properly so out_dir works without a trailing slash.
        mixture.tofile(os.path.join(output_path, str(j) + '.bin'))
# ---- research/audio/gdprnn/scripts/run_distribute_train.sh ----
# Launch one background training process per device.
# Arguments: DEVICE_NUM DISTRIBUTE RANK_TABLE_FILE SAVE_FOLDER

if [ $# -ne 4 ]
then
  echo "==========================================================================="
  echo "Please run the script as: "
  echo "For example:"
  echo "Usage: bash run_distribute_train.sh [DEVICE_NUM] [DISTRIBUTE] [RANK_TABLE_FILE] [SAVE_FOLDER]"
  echo "bash run_distribute_train.sh 8 1 ./hccl_8p.json /home"
  echo "Using absolute path is recommended"
  echo "==========================================================================="
  exit 1
fi

export RANK_TABLE_FILE=$3
export RANK_START_ID=0
export RANK_SIZE=$1
echo "lets begin!!!!XD"

for((i=0;i<$1;i++))
do
  export DEVICE_ID=$((i + RANK_START_ID))
  export RANK_ID=$i
  echo "start training for rank $i, device $DEVICE_ID"
  env > env.log

  # Each rank runs from its own fresh copy of the sources.
  rm -rf "./train_parallel$i"
  mkdir "./train_parallel$i"
  cp -r ../*.py "./train_parallel$i"
  cp -r ../src/ "./train_parallel$i"
  cd "./train_parallel$i" || exit
  # Quote the arguments so paths containing spaces survive word splitting.
  python train.py --device_num="$1" --is_distribute="$2" --save_folder="$4" --device_id="$DEVICE_ID" > paralletrain.log 2>&1 &
  cd ..
done

# ---- research/audio/gdprnn/scripts/run_eval.sh ----
#!/bin/bash
# Run evaluation in the background from a fresh ./eval working copy.
# Arguments: DEVICE_ID MODEL_PATH DATA_DIR

if [ $# -ne 3 ]
then
  echo "==========================================================================="
  echo "Please run the script as:"
  echo "For example:"
  echo "Usage: bash run_eval.sh [DEVICE_ID] [MODEL_PATH] [DATA_DIR]"
  echo "bash run_eval.sh 0 weights/gdprnn.ckpt /dataset/"
  echo "Using absolute path is recommended"
  echo "==========================================================================="
  exit 1
fi

export DEVICE_ID=$1
export RANK_SIZE=1

rm -rf ./eval
mkdir ./eval
cp -r ../*.py ./eval
cp -r ../src ./eval
# Stop here if the working copy is unusable instead of evaluating in the wrong directory.
cd ./eval || exit

python evaluate.py --device_id="$1" --model_path="$2" --data_dir="$3" > eval.log 2>&1 &
# ---- research/audio/gdprnn/scripts/run_infer_310.sh ----
# Preprocess (optional), build the Ascend 310 inference app, run it and score the output.
# Arguments: MODEL_PATH TEST_PATH NEED_PREPROCESS(y|n)

if [[ $# -lt 2 || $# -gt 3 ]]; then
    echo "Usage: bash run_infer_310.sh [MODEL_PATH] [TEST_PATH] [NEED_PREPROCESS]
    NEED_PREPROCESS means weather need preprocess or not, it's value is 'y' or 'n'."
    exit 1
fi

# Print an absolute path for $1; an empty argument yields an empty string.
get_real_path() {
    if [ -z "$1" ]; then
        echo ""
    elif [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}
model=$(get_real_path "$1")
test_path=$(get_real_path "$2")

if [ "$3" == "y" ] || [ "$3" == "n" ]; then
    need_preprocess=$3
else
    echo "weather need preprocess or not, it's value must be in [y, n]"
    exit 1
fi

echo "mindir name: $model"
echo "test_path: $test_path"
echo "need preprocess: $need_preprocess"

# Select the toolkit layout that is installed on this machine.
export ASCEND_HOME=/usr/local/Ascend/
if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then
    export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/ascend-toolkit/latest/atc/bin:$PATH
    export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/ascend-toolkit/latest/atc/lib64:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export TBE_IMPL_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe
    export PYTHONPATH=$ASCEND_HOME/fwkacllib/python/site-packages:${TBE_IMPL_PATH}:$ASCEND_HOME/ascend-toolkit/latest/fwkacllib/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/ascend-toolkit/latest/opp
else
    export ASCEND_HOME=/usr/local/Ascend/latest/
    export PATH=$ASCEND_HOME/fwkacllib/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/fwkacllib/ccec_compiler/bin:$ASCEND_HOME/atc/bin:$PATH
    export LD_LIBRARY_PATH=$ASCEND_HOME/fwkacllib/lib64:/usr/local/lib:$ASCEND_HOME/atc/lib64:$ASCEND_HOME/acllib/lib64:$ASCEND_HOME/driver/lib64:$ASCEND_HOME/add-ons:$LD_LIBRARY_PATH
    export PYTHONPATH=$ASCEND_HOME/atc/python/site-packages:$PYTHONPATH
    export ASCEND_OPP_PATH=$ASCEND_HOME/opp
fi

# Regenerate the .bin inputs from the test set.
function preprocess_data() {
    if [ -d preprocess_310_Result ];then
        rm -rf ./preprocess_310_Result
    fi
    mkdir preprocess_310_Result
    python ./preprocess.py --test_dir="$test_path" --out_dir=./preprocess_310_Result/
}

# Build the C++ inference app; leaves the shell inside ./ascend310_infer.
function compile_app() {
    cd ./ascend310_infer || exit
    bash build.sh &>build.log
}

# Return to the directory compile_app left and run inference.
function infer() {
    cd - || exit
    if [ -d result_Files ]; then
        rm -rf ./result_Files
    fi
    mkdir result_Files

    ./ascend310_infer/build/swave "$model" ./preprocess_310_Result/ &> infer.log

}

# Score the inference output.
function cal_acc() {
    python ./postprocess.py --test_dir="$test_path" --bin_path=./result_Files &> acc.log
}

if [ "$need_preprocess" == "y" ]; then
    preprocess_data
    if [ $? -ne 0 ]; then
        echo "preprocess dataset failed"
        exit 1
    fi
fi
compile_app
if [ $? -ne 0 ]; then
    echo "compile app code failed"
    exit 1
fi
infer
if [ $? -ne 0 ]; then
    echo " execute inference failed"
    exit 1
fi
cal_acc
if [ $? -ne 0 ]; then
    echo "calculate accuracy failed"
    exit 1
fi

# ---- research/audio/gdprnn/scripts/run_standalone_train.sh ----
#!/bin/bash
# Launch single-device training in the background from a fresh ./train_gdprnn copy.
# Arguments: DEVICE_ID SAVE_FOLDER

if [ $# -ne 2 ]
then
  echo "==========================================================================="
  echo "Please run the script as: "
  echo "For example:"
  echo "bash run_standalone_train.sh [DEVICE_ID] [save_folder]"
  echo "bash run_standalone_train.sh 0 /home/"
  echo "Using absolute path is recommended"
  echo "==========================================================================="
  exit 1
fi

export DEVICE_ID=$1
export RANK_ID=0
export RANK_SIZE=1
export SLOG_PRINT_TO_STDOUT=0


rm -rf ./train_gdprnn
mkdir ./train_gdprnn
cp -r ../*.py ./train_gdprnn
cp -r ../src/ ./train_gdprnn
cd ./train_gdprnn || exit
# Quote the arguments so a save folder containing spaces is passed intact.
python train.py --device_id="$DEVICE_ID" --save_folder="$2" > train.log 2>&1 &
# Command-line options shared by the training entry points.
parser = argparse.ArgumentParser(
    "Dual-path transformer"
    "with Permutation Invariant Training")
parser.add_argument('--train_dir', type=str, default='/home/heu_MEDAI/liwenjie/svoice-main-mindspore/egs/debug/tt',
                    help='directory including mix.json, s1.json and s2.json')
parser.add_argument('--valid_dir', type=str, default='/home/heu_MEDAI/fanruibo/project/out_dir/cv',
                    help='directory including mix.json, s1.json and s2.json')
parser.add_argument('--batch_size', default=2, type=int,
                    help='Batch size')
parser.add_argument('--sample_rate', default=8000, type=int,
                    help='Sample rate')
parser.add_argument('--segment', default=4, type=float,
                    help='Segment length (seconds)')


def load_mixtures_and_sources(batch):
    """Load the wav files of one mini-batch and cut them into segments.

    Args:
        batch: ``[mix_infos, s1_infos, s2_infos, sample_rate, segment_len]``;
            every ``*_infos`` item is ``(wav_path, num_samples)``.

    Returns:
        mixtures: list of 1-D arrays, one per segment.
        sources: list of ``T x C`` arrays aligned with ``mixtures``.
    """
    mixtures, sources = [], []
    mix_infos, s1_infos, s2_infos, sample_rate, segment_len = batch
    for mix_info, s1_info, s2_info in zip(mix_infos, s1_infos, s2_infos):
        mix_path, s1_path, s2_path = mix_info[0], s1_info[0], s2_info[0]
        # The mixture and both sources must have the same number of samples.
        assert mix_info[1] == s1_info[1] and s1_info[1] == s2_info[1]
        mix, _ = librosa.load(mix_path, sr=sample_rate)
        s1, _ = librosa.load(s1_path, sr=sample_rate)
        s2, _ = librosa.load(s2_path, sr=sample_rate)

        # Stack the two sources column-wise: [T, 2].
        s = np.dstack((s1, s2))[0]
        utt_len = mix.shape[-1]
        if segment_len >= 0:
            # Non-overlapping full segments ...
            for i in range(0, utt_len - segment_len + 1, segment_len):
                mixtures.append(mix[i:i + segment_len])
                sources.append(s[i:i + segment_len])
            # ... plus one segment aligned to the end when there is a remainder.
            if utt_len % segment_len != 0:
                mixtures.append(mix[-segment_len:])
                sources.append(s[-segment_len:])
        else:  # full utterance
            mixtures.append(mix)
            sources.append(s)
    return mixtures, sources


def pad_list(xs):
    """Zero-pad arrays along their first axis and stack them into one float32 array.

    Args:
        xs: non-empty list of arrays that share every dimension except the first.

    Returns:
        Array of shape ``(len(xs), max_len, *trailing_dims)``.
    """
    n_batch = len(xs)
    # Only the first axis varies; take the trailing dims from the first item
    # instead of relying on a lexicographic max over whole shape tuples.
    max_len = max(x.shape[0] for x in xs)
    pad = np.zeros((n_batch, max_len) + tuple(xs[0].shape[1:]), np.float32)
    for i, x in enumerate(xs):
        pad[i, :x.shape[0]] = x
    return pad


class DatasetGenerator:
    """In-memory dataset of (mixture, length, sources) segments."""

    def __init__(self, json_dir, batch_size, sample_rate=8000, segment=4.0, cv_maxlen=8.0):
        """Read the json indexes, then load and segment every usable utterance.

        Args:
            json_dir: directory including mix.json, s1.json and s2.json
            batch_size: maximum number of segments loaded per group
            sample_rate: sample rate passed to the wav loader
            segment: duration of audio segment, when set to -1, use full audio
            cv_maxlen: unused, kept for interface compatibility

        xxx_infos is a list and each item is a tuple (wav_file, #samples)
        """
        super(DatasetGenerator, self).__init__()

        def load_sorted(name):
            # Longest utterances first (bucketing by #samples).
            with open(os.path.join(json_dir, name + '.json'), 'r') as f:
                infos = json.load(f)
            return sorted(infos, key=lambda info: int(info[1]), reverse=True)

        sorted_mix_infos = load_sorted('mix')
        sorted_s1_infos = load_sorted('s1')
        sorted_s2_infos = load_sorted('s2')

        # Segment length in samples; count the utterances that are too short.
        segment_len = int(segment * sample_rate)
        drop_utt, drop_len = 0, 0
        for _, sample in sorted_mix_infos:
            if sample < segment_len:
                drop_utt += 1
                drop_len += sample
        # samples / sample_rate = seconds; seconds / 3600 = hours.
        print("Drop {} utts({:.2f} h) which is short than {} samples".format(
            drop_utt, drop_len / sample_rate / 3600, segment_len))

        # Build the segment lists group by group.
        mixture_pad = []
        lens = []
        source_pad = []
        start = 0
        while True:
            num_segments = 0
            end = start
            part_mix, part_s1, part_s2 = [], [], []
            while num_segments < batch_size and end < len(sorted_mix_infos):
                utt_len = int(sorted_mix_infos[end][1])
                if utt_len >= segment_len:  # skip too short utt
                    num_segments += math.ceil(utt_len / segment_len)
                    # Ensure num_segments is not larger than batch_size
                    if num_segments > batch_size:
                        # if the 1st audio alone exceeds batch_size, skip it
                        if start == end:
                            end += 1
                        break
                    part_mix.append(sorted_mix_infos[end])
                    part_s1.append(sorted_s1_infos[end])
                    part_s2.append(sorted_s2_infos[end])
                end += 1
            if part_mix:
                meta = [part_mix, part_s1, part_s2, sample_rate, segment_len]
                mixtures_pad, ilens, sources_pad = self.sort_and_pad(meta)
                mixture_pad.extend(mixtures_pad)
                lens.extend(ilens)
                source_pad.extend(sources_pad)
            if end == len(sorted_mix_infos):
                break
            start = end
        self.mixture = mixture_pad
        self.len = lens
        self.sources = source_pad

    def __getitem__(self, index):
        return (self.mixture[index], self.len[index], self.sources[index])

    def __len__(self):
        return len(self.mixture)

    def sort_and_pad(self, batch):
        """Load one group of utterances and pad the segments to a common length.

        Returns:
            mixtures_pad: [N, T] float32
            ilens: [N] segment lengths
            sources_pad: [N, C, T] float32
        """
        mixtures, sources = load_mixtures_and_sources(batch)
        # get batch of lengths of input sequences
        ilens = np.array([mix.shape[0] for mix in mixtures])

        mixtures_pad = pad_list(mixtures)
        sources_pad = pad_list(sources)
        # [N, T, C] -> [N, C, T]
        sources_pad = sources_pad.transpose((0, 2, 1))
        return mixtures_pad, ilens, sources_pad
def list_wav_files(in_dir):
    """Return the absolute paths of the ``.wav`` files in *in_dir*, sorted by name."""
    in_dir = os.path.abspath(in_dir)
    # os.listdir order is arbitrary; sorting keeps the mix/s1/s2 indexes,
    # which are built from separate directories, in the same file order.
    return [os.path.join(in_dir, name)
            for name in sorted(os.listdir(in_dir))
            if name.endswith('.wav')]


def preprocess_one_dir(in_dir, out_dir, out_filename, sample_rate=8000):
    """Write ``<out_dir>/<out_filename>.json`` listing every wav file in *in_dir*.

    Args:
        in_dir: directory that holds the wav files.
        out_dir: directory for the json index; created when missing.
        out_filename: json file name without extension.
        sample_rate: sample rate the audio is loaded at before counting samples.

    The json content is a list of ``[wav_path, num_samples]`` entries.
    """
    file_infos = []
    for wav_path in list_wav_files(in_dir):
        samples, _ = librosa.load(wav_path, sr=sample_rate)
        file_infos.append((wav_path, len(samples)))
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, out_filename + '.json'), 'w') as f:
        json.dump(file_infos, f, indent=4)


def preprocess(args):
    """Index the mix/s1/s2 folders of the tr, cv and tt splits under ``args.in_dir``."""
    for data_type in ['tr', 'cv', 'tt']:
        for speaker in ['mix', 's1', 's2']:
            preprocess_one_dir(os.path.join(args.in_dir, data_type, speaker),
                               os.path.join(args.out_dir, data_type),
                               speaker,
                               sample_rate=args.sr)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("WSJ0 data preprocessing")
    parser.add_argument('--in-dir', type=str, default="dataset/debug",
                        help='Directory path of LS-2mix including tr, cv and tt')
    parser.add_argument('--out-dir', type=str, default="egs/debug",
                        help='Directory path to put output files')
    parser.add_argument('--sr', type=int, default=8000,
                        help='Sample rate of audio file')
    arg = parser.parse_args()
    print(arg)
    preprocess(arg)
class Generatorloss(nn.Cell):
    """Cell that pairs a generator network with ``MyLoss`` and returns the loss."""

    def __init__(self, generator):
        super().__init__()
        self.generator = generator
        self.my_loss = MyLoss()
        # WithLossCell runs the generator and feeds its output to the loss.
        self.net_with_loss = WithLossCell(self.generator, self.my_loss)

    def construct(self, maxture, lens, source):
        """Forward the batch through the wrapped cell and return its result."""
        return self.net_with_loss(maxture, lens, source)
# Small constant added to denominators and log arguments to avoid division by
# zero and log(0).
EPS = 1e-8

class MyLoss(nn.Cell):
    """Negative SI-SNR loss with permutation-invariant training (PIT).

    The permutation tables built in ``__init__`` are fixed for two speakers
    (``permutations(range(2))`` and a 4 x 2 one-hot matrix).
    """

    def __init__(self):
        super(MyLoss, self).__init__()
        self.mean = ops.ReduceMean()
        self.cast = ops.Cast()
        # Reduction that keeps the reduced axis (used for broadcasting).
        self.sum = ops.ReduceSum(keep_dims=True)
        self.expand_dims = ops.ExpandDims()
        # Reduction that drops the reduced axis.
        self._sum = ops.ReduceSum(keep_dims=False)
        self.log = ops.Log()
        self.scatter = ops.ScatterNd()
        self.matmul = ops.MatMul()
        self.transpose = ops.Transpose()
        self.Argmax = ops.Argmax(axis=1, output_type=mindspore.int32)
        self.argmax = ops.ArgMaxWithValue(axis=1, keep_dims=True)
        self.ones = ops.Ones()
        self.zeros_like = ops.ZerosLike()
        # Constant 10, used to turn the natural log into log10.
        self.log10 = Tensor(np.array([10.0]), mindspore.float32)
        # [C!, C] = [[0, 1], [1, 0]] for C = 2.
        self.perms = Tensor(list(permutations(range(2))), dtype=mindspore.int64)
        # [C*C, C!]: column k sums the entries of the flattened [C, C] pairwise
        # SI-SNR matrix that belong to permutation k
        # (column 0: pairs (0,0)+(1,1); column 1: pairs (0,1)+(1,0)).
        self.perms_one_hot = Tensor(np.array([[1, 0], [0, 1], [0, 1], [1, 0]]), mindspore.float32)


    def construct(self, source, estimate_source, source_lengths):
        """Forward pass; returns whatever ``cal_loss`` returns."""
        return self.cal_loss(source, estimate_source, source_lengths)

    def cal_loss(self, source, estimate_source, source_lengths):
        """
        Args:
            source: [B, C, T], B is batch size
            estimate_source: [B, C, T]
            source_lengths: [B]
        Returns:
            loss: negative mean of the best-permutation SI-SNR
            estimate_source: the estimate after masking inside cal_si_snr_with_pit
                (NOTE(review): relies on ``*=`` updating the tensor in place -- confirm)
            reorder_estimate_source: estimate reordered by the best permutation
        """
        max_snr, perms, max_snr_idx = self.cal_si_snr_with_pit(source, estimate_source, source_lengths)
        loss = 0 - self.mean(max_snr)
        reorder_estimate_source = self.reorder_source(estimate_source, perms, max_snr_idx)
        return loss, estimate_source, reorder_estimate_source


    def cal_si_snr_with_pit(self, source, estimate_source, source_lengths):
        """Calculate SI-SNR with PIT training.
        Args:
            source: [B, C, T], B is batch size
            estimate_source: [B, C, T]
            source_lengths: [B], each item is between [0, T]
        Returns:
            max_snr: [B, 1], best permutation score divided by C
            perms: [C!, C], the permutation table
            max_snr_idx: [B], row index into perms of the best permutation
        """
        B, C, _ = source.shape
        # mask padding position along T
        mask = self.get_mask(source, source_lengths)
        estimate_source *= mask

        # Step 1. Zero-mean norm
        # num_samples = self.cast(source_lengths.view(-1, 1, 1), mindspore.float32)  # [B, 1, 1]
        num_samples = source_lengths.view(-1, 1, 1).astype(mindspore.float32)
        mean_target = self.sum(source, 2) / num_samples
        mean_estimate = self.sum(estimate_source, 2) / num_samples
        zero_mean_target = source - mean_target
        zero_mean_estimate = estimate_source - mean_estimate
        # mask padding position along T
        zero_mean_target *= mask
        zero_mean_estimate *= mask

        # Step 2. SI-SNR with PIT
        # reshape to use broadcast
        s_target = self.expand_dims(zero_mean_target, 1)  # [B, 1, C, T]
        s_estimate = self.expand_dims(zero_mean_estimate, 2)  # [B, C, 1, T]
        # s_target = <s', s>s / ||s||^2
        pair_wise_dot = self.sum(s_estimate * s_target, 3)  # [B, C, C, 1]
        s_target_energy = self.sum(s_target ** 2, 3) + EPS  # [B, 1, C, 1]
        pair_wise_proj = pair_wise_dot * s_target / s_target_energy  # [B, C, C, T]
        # e_noise = s' - s_target
        e_noise = s_estimate - pair_wise_proj  # [B, C, C, T]
        # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
        pair_wise_si_snr = self._sum(pair_wise_proj ** 2, 3) / (self._sum(e_noise ** 2, 3) + EPS)
        # log10(x) is computed as ln(x) / ln(10)
        pair_wise_si_snr = 10 * self.log(pair_wise_si_snr + EPS) / self.log(
            self.log10)  # [B, C, C]

        # Get max_snr of each utterance
        # permutations, [C!, C]
        perms = self.perms
        perms_one_hot = self.perms_one_hot
        # [B, C*C] x [C*C, C!] -> summed SI-SNR of every permutation, [B, C!]
        snr_set = self.matmul(pair_wise_si_snr.view(B, -1), perms_one_hot)
        max_snr_idx = self.Argmax(snr_set)  # [B]
        # the first output is discarded; the second is used as the maximum value
        _, max_snr = self.argmax(snr_set)
        max_snr /= C
        return max_snr, perms, max_snr_idx

    def reorder_source(self, source, perms, max_snr_idx):
        """
        Args:
            source: [B, C, T]
            perms: [C!, C], permutations
            max_snr_idx: [B], each item is between [0, C!)
        Returns:
            reorder_source: [B, C, T]
        """
        B, C, _ = source.shape
        # [B, C], permutation whose SI-SNR is max of each utterance
        # for each utterance, reorder estimate source according this permutation
        max_snr_perm = perms[max_snr_idx, :]
        reorder_source = self.zeros_like(source)
        # Only channels 0 and 1 are ever read, so this handles two speakers only.
        for b in range(B):
            for c in range(C):
                if max_snr_perm[b][c] == 1:
                    reorder_source[b, c] = source[b, 1]
                else:
                    reorder_source[b, c] = source[b, 0]
        return reorder_source


    def get_mask(self, source, source_lengths):
        """
        Args:
            source: [B, C, T]
            source_lengths: [B]
        Returns:
            mask: [B, 1, T]
        """
        B, _, T = source.shape
        mask = self.ones((B, 1, T), mindspore.float32)
        # NOTE(review): source_lengths is not read here; positions from index
        # 32000 onwards are zeroed for every item, so when T <= 32000 the mask
        # is all ones. Presumably 32000 = 4 s * 8000 Hz -- confirm.
        for i in range(B):
            mask[i, :, 32000:] = 0
        return mask
import functools
import numpy as np
import mindspore
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.initializer import initializer, HeNormal


class MulCatBlock(nn.Cell):
    """Gated recurrent block.

    Two parallel LSTMs read the same input; their projected outputs are
    multiplied element-wise, concatenated with the input along the feature
    axis and projected back to ``input_size``.

    Args:
        input_size: Size of the last (feature) axis of the input.
        hidden_size: Hidden size of both LSTMs.
        dropout: Dropout passed to both LSTMs.
        bidirectional: Whether both LSTMs are bidirectional.
    """

    def __init__(self, input_size, hidden_size, dropout=0.0, bidirectional=False):
        super(MulCatBlock, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_direction = int(bidirectional) + 1

        self.rnn = nn.LSTM(input_size, hidden_size, 1, dropout=dropout,
                           batch_first=True, bidirectional=bidirectional)
        self.rnn_proj = nn.Dense(hidden_size * self.num_direction, input_size, weight_init="HeNormal")

        self.gate_rnn = nn.LSTM(input_size, hidden_size, num_layers=1,
                                batch_first=True, dropout=dropout, bidirectional=bidirectional)
        self.gate_rnn_proj = nn.Dense(
            hidden_size * self.num_direction, input_size, weight_init="HeNormal")

        self.block_projection = nn.Dense(input_size * 2, input_size, weight_init="HeNormal")
        self.mul = ops.Mul()
        self.op2 = ops.Concat(2)

    def construct(self, inputs):
        """Apply the gated block to a 3-D, batch-first input; the output has the input's shape."""
        in_shape = inputs.shape
        # main rnn branch, projected back to input_size
        rnn_output, _ = self.rnn(inputs)
        rnn_output = self.rnn_proj(rnn_output.view(-1, rnn_output.shape[2])).view(in_shape)
        # gate rnn branch
        gate_rnn_output, _ = self.gate_rnn(inputs)
        gate_rnn_output = self.gate_rnn_proj(
            gate_rnn_output.view(-1, gate_rnn_output.shape[2])).view(in_shape)
        # gate, then concatenate with the input on the feature axis and project
        gated_output = self.mul(rnn_output, gate_rnn_output)
        gated_output = self.op2([gated_output, inputs])
        gated_output = self.block_projection(
            gated_output.view(-1, gated_output.shape[2])).view(in_shape)
        return gated_output


class ByPass(nn.Cell):
    """Identity cell, used where normalization is disabled."""

    def construct(self, inputs):
        return inputs


class DPMulCat(nn.Cell):
    """Dual-path stack of ``MulCatBlock`` pairs (rows, then columns).

    Args:
        input_size: Channel size of the 4-D input (axis 1).
        hidden_size: Hidden size of the LSTMs inside each block.
        output_size: Per-speaker channel size produced by the output conv.
        num_spk: Number of speakers; the output conv emits
            ``output_size * num_spk`` channels.
        dropout: Dropout passed to every block.
        num_layers: Number of (row block, column block) pairs.
        bidirectional: Whether the LSTMs are bidirectional.
        input_normalize: If True, apply ``GroupNorm(1, input_size)`` after
            each block; otherwise an identity cell is used.
    """

    def __init__(self, input_size, hidden_size, output_size, num_spk,
                 dropout=0.0, num_layers=1, bidirectional=True, input_normalize=False):
        super(DPMulCat, self).__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.in_norm = input_normalize
        self.num_layers = num_layers

        self.rows_grnn = nn.CellList([])
        self.cols_grnn = nn.CellList([])
        self.rows_normalization = nn.CellList([])
        self.cols_normalization = nn.CellList([])

        # create the dual path pipeline
        for _ in range(num_layers):
            self.rows_grnn.append(MulCatBlock(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            self.cols_grnn.append(MulCatBlock(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            if self.in_norm:
                self.rows_normalization.append(
                    nn.GroupNorm(1, input_size, eps=1e-8))
                self.cols_normalization.append(
                    nn.GroupNorm(1, input_size, eps=1e-8))
            else:
                # used to disable normalization
                self.rows_normalization.append(ByPass())
                self.cols_normalization.append(ByPass())
        self.output = nn.SequentialCell(
            nn.PReLU(), nn.Conv2d(input_size, output_size * num_spk, 1, has_bias=True))

    def construct(self, inputs):
        """Run all layers on a 4-D input and return a list of output tensors."""
        batch_size, _, d1, d2 = inputs.shape
        output = inputs
        output_all = []
        for i in range(self.num_layers):
            row_input = output.transpose(0, 3, 2, 1).view(
                batch_size * d2, d1, -1)
            row_output = self.rows_grnn[i](row_input)
            row_output = row_output.view(
                batch_size, d2, d1, -1).transpose(0, 3, 2, 1)
            row_output = self.rows_normalization[i](row_output)
            # apply a skip connection
            output = output + row_output

            col_input = output.transpose(0, 2, 3, 1).view(
                batch_size * d1, d2, -1)
            col_output = self.cols_grnn[i](col_input)
            col_output = col_output.view(
                batch_size, d1, d2, -1).transpose(0, 3, 1, 2)
            col_output = self.cols_normalization[i](col_output)
            # apply a skip connection
            output = output + col_output

            output_i = self.output(output)
            # NOTE(review): every layer's output is kept while `self.training`
            # is False, only the last layer's otherwise. Confirm this matches
            # how the training/eval scripts toggle `set_train` before changing it.
            if not self.training or i == (self.num_layers - 1):
                output_all.append(output_i)
        return output_all


class Separator(nn.Cell):
    """Chunk the encoded features, run ``DPMulCat`` on them and merge back.

    Args:
        input_dim, feature_dim, hidden_dim, output_dim: Stored sizes;
            ``feature_dim`` and ``hidden_dim`` configure the ``DPMulCat``.
        num_spk: Number of speakers.
        layer: Number of dual-path layers.
        segment_size: Chunk length along the time axis.
        input_normalize: Forwarded to ``DPMulCat``.
        bidirectional: Forwarded to ``DPMulCat``.
    """

    def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2,
                 layer=4, segment_size=100, input_normalize=False, bidirectional=True):
        super(Separator, self).__init__()

        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.layer = layer
        self.segment_size = segment_size
        self.num_spk = num_spk
        self.input_normalize = input_normalize
        self.zeros = ops.Zeros()
        self.op2 = ops.Concat(2)
        self.op3 = ops.Concat(3)
        self.transpose = ops.Transpose()
        self.input_perm = (0, 1, 3, 2)

        self.rnn_model = DPMulCat(self.feature_dim, self.hidden_dim,
                                  self.feature_dim, self.num_spk,
                                  num_layers=layer, bidirectional=bidirectional,
                                  input_normalize=input_normalize)

    def pad_segment(self, inputs, segment_size):
        """Zero-pad (B, N, T) features so they split into half-overlapping chunks.

        Returns:
            (padded_inputs, rest): ``rest`` is the number of zeros appended to
            the end before the half-chunk padding added on both sides.
        """
        batch_size, dim, seq_len = inputs.shape
        segment_stride = segment_size // 2
        rest = segment_size - (segment_stride + seq_len %
                               segment_size) % segment_size

        if rest > 0:
            pad = self.zeros((batch_size, dim, rest), mindspore.float32)
            inputs = self.op2([inputs, pad])

        pad_aux = self.zeros((
            batch_size, dim, segment_stride), mindspore.float32)
        inputs = self.op2([pad_aux, inputs, pad_aux])
        return inputs, rest

    def create_chuncks(self, inputs, segment_size):
        """Split (B, N, T) features into chunks of ``segment_size``.

        Returns:
            (segments, rest): 4-D chunk tensor and the padding amount reported
            by ``pad_segment``.
        """
        input0, rest = self.pad_segment(inputs, segment_size)
        batch_size, dim, _ = input0.shape
        segment_stride = segment_size // 2

        # two chunkings of the same signal, offset by half a chunk
        segments1 = input0[:, :, :-segment_stride].view(batch_size, dim, -1, segment_size)
        segments2 = input0[:, :, segment_stride:].view(batch_size, dim, -1, segment_size)

        segments = self.transpose(self.op3([segments1, segments2]).view(
            batch_size, dim, -1, segment_size), self.input_perm)
        return segments, rest

    def merge_chuncks(self, inputs, rest):
        """Merge 4-D chunked features back into a 3-D sequence.

        Sums the two half-overlapping chunkings, drops the half-chunk padding
        and the trailing ``rest`` padded samples.
        """
        batch_size, dim, segment_size, _ = inputs.shape
        segment_stride = segment_size // 2
        input0 = self.transpose(inputs, self.input_perm).view(batch_size,
                                                               dim, -1, segment_size*2)

        input1 = input0[:, :, :, :segment_size].view(
            batch_size, dim, -1)[:, :, segment_stride:]
        input2 = input0[:, :, :, segment_size:].view(
            batch_size, dim, -1)[:, :, :-segment_stride]

        output = input1 + input2
        if rest > 0:
            output = output[:, :, :-rest]
        return output

    def construct(self, inputs):
        """Return one merged tensor per output produced by the dual-path model."""
        # create chunks
        enc_segments, enc_rest = self.create_chuncks(
            inputs, self.segment_size)
        # separate
        output_all = self.rnn_model(enc_segments)

        # merge back audio files
        output_all_wav = []
        for chunked in output_all:
            output_all_wav.append(self.merge_chuncks(chunked, enc_rest))
        return output_all_wav


def capture_init(init):
    """
    Decorate `__init__` with this, and you can then
    recover the *args and **kwargs passed to it in `self.kwarg`
    """
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self.kwarg = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__


class SWave(nn.Cell):
    """Encoder -> separator -> decoder speech separation network.

    Args:
        N: Number of encoder output channels.
        L: Encoder kernel size (the stride is ``L // 2``).
        H: Hidden size of the separator's LSTMs.
        R: Number of dual-path layers.
        C: Number of speakers.
        sr: Sample rate.
        segment: Segment length used to derive the chunk size.
        input_normalize: Forwarded to the separator.
    """

    @capture_init
    def __init__(self, N, L, H, R, C, sr, segment, input_normalize):
        super(SWave, self).__init__()
        # hyper-parameter
        self.N, self.L, self.H, self.R, self.C, self.sr, self.segment = N, L, H, R, C, sr, segment
        self.input_normalize = input_normalize
        self.context_len = 2 * self.sr / 1000
        self.context = int(self.sr * self.context_len / 1000)
        self.layer = self.R
        self.filter_dim = self.context * 2 + 1
        self.num_spk = self.C
        self.stack = ops.Stack()
        # similar to dprnn paper, setting chancksize to sqrt(2*L)
        self.segment_size = int(
            np.sqrt(2 * self.sr * self.segment / (self.L/2)))
        # model sub-networks
        self.encoder = Encoder(L, N)
        self.decoder = Decoder(L)
        self.separator = Separator(self.filter_dim + self.N, self.N, self.H,
                                   self.filter_dim, self.num_spk, self.layer,
                                   self.segment_size, self.input_normalize)
        # init
        # NOTE(review): the tensor returned by `initializer` is discarded, so
        # this loop does not modify any parameter (the Dense/Conv1d layers pass
        # weight_init="HeNormal" themselves). Kept unchanged so existing
        # checkpoints/training results are unaffected - confirm the intent.
        for p in self.get_parameters():
            if p.dim() > 1:
                initializer(HeNormal(), p.shape, mindspore.float32)

    def construct(self, mixture):
        """Separate ``mixture`` and return the stacked per-layer estimates."""
        mixture_w = self.encoder(mixture)
        output_all = self.separator(mixture_w)
        # fix time dimension, might change due to convolution operations
        T_mix = mixture.shape[-1]
        # generate wav after each RNN block and optimize the loss
        outputs = []
        for separated in output_all:
            output_ii = separated.view(
                mixture.shape[0], self.C, self.N, mixture_w.shape[2])
            output_ii = self.decoder(output_ii)
            # Keep at most T_mix samples. The previous `0:T_mix-T_est` slice
            # gave the same result when the estimate was longer than the
            # mixture, but produced an empty tensor when the lengths matched.
            output_ii = output_ii[:, :, :T_mix]
            outputs.append(output_ii)
        return self.stack(outputs)


class Encoder(nn.Cell):
    """1-D convolutional encoder followed by ReLU.

    Args:
        L: Kernel size; the stride is ``L // 2`` (50% overlap).
        N: Number of output channels.
    """

    def __init__(self, L, N):
        super(Encoder, self).__init__()
        self.L, self.N = L, N
        self.relu = ops.ReLU()
        self.expand_dims = ops.ExpandDims()
        # setting 50% overlap
        self.conv = nn.Conv1d(
            1, N, kernel_size=L, stride=L // 2, has_bias=False, pad_mode="pad", weight_init="HeNormal")

    def construct(self, mixture):
        """Insert a channel axis at position 1 and encode the mixture."""
        mixture = self.expand_dims(mixture, 1)
        mixture_w = self.relu(self.conv(mixture))
        return mixture_w


def _overlap_add_selector(num_subframes, subframes_per_frame):
    """Build the dense 0/1 matrix that scatters sub-frames onto output slots.

    Sub-frame ``i`` is routed to slot ``i // subframes_per_frame +
    i % subframes_per_frame``; the result has shape
    ``(num_subframes, max_slot + 1)`` and dtype float32.
    """
    if num_subframes < 1 or subframes_per_frame < 1:
        raise ValueError("num_subframes and subframes_per_frame must be positive")
    sub_idx = np.arange(num_subframes)
    slot = sub_idx // subframes_per_frame + sub_idx % subframes_per_frame
    selector = np.zeros((num_subframes, int(slot.max()) + 1), dtype=np.float32)
    selector[sub_idx, slot] = 1.0
    return selector


def matrix(num_subframes=31996, subframes_per_frame=4):
    """Return the overlap-add matrix used by ``Decoder`` as a float32 Tensor.

    With the defaults the result has shape (31996, 8002), the sizes that were
    previously hard-coded here. It is a dense matrix of roughly 1 GB in
    float32, so it is built once with vectorised NumPy indexing.

    Args:
        num_subframes: Number of sub-frames (rows of the matrix).
        subframes_per_frame: Number of sub-frames each frame is split into.
    """
    return Tensor(_overlap_add_selector(num_subframes, subframes_per_frame),
                  dtype=mindspore.float32)


class Decoder(nn.Cell):
    """Average-pool the separated features and overlap-add them into waveforms.

    Args:
        L: Pooling window/stride along the last axis; ``L // 2`` is passed as
            the frame step of the overlap-add.
    """

    def __init__(self, L):
        super(Decoder, self).__init__()
        self.L = L
        self.zeros = ops.Zeros()
        self.transpose = ops.Transpose()
        self.op2 = ops.Concat(2)
        # built once here instead of on every construct() call
        self.pool = nn.AvgPool2d((1, L), stride=(1, L))
        self.mat = matrix()

    def gcd(self, a, b):
        """Return the greatest common divisor of two positive integers."""
        if a < b:
            m = b
            n = a
        else:
            m = a
            n = b
        r = m % n
        while r != 0:
            m = n
            n = r
            r = m % n
        return n

    def overlap_and_add(self, signal, frame_step):
        """Reconstructs a signal from a framed representation.

        Adds potentially overlapping frames of a signal with shape
        `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
        The resulting tensor has shape `[..., output_size]` where

            output_size = (frames - 1) * frame_step + frame_length

        Args:
            signal: A [..., frames, frame_length] Tensor. This implementation
                transposes with a 4-axis permutation, so `signal` must be 4-D,
                and its total number of sub-frames must equal the number of
                rows of the precomputed `self.mat`.
            frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.

        Returns:
            A Tensor with shape [..., output_size] containing the overlap-added
            frames of signal's inner-most two dimensions.

        Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
        """
        outer_dimensions = signal.shape[:-2]
        _, frame_length = signal.shape[-2:]

        # gcd=Greatest Common Divisor
        subframe_length = self.gcd(frame_length, frame_step)

        subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
        subframe_signal = self.transpose(subframe_signal, (0, 1, 3, 2))
        # the scatter-add of sub-frames is expressed as one matmul
        result = ops.matmul(subframe_signal, self.mat)
        result = self.transpose(result, (0, 1, 3, 2))
        result = result.view(*outer_dimensions, -1)
        return result

    def construct(self, est_source):
        """Decode a 4-D tensor of separated features into waveforms."""
        est_source = self.transpose(est_source, (0, 1, 3, 2))
        est_source = self.pool(est_source)
        est_source = self.overlap_and_add(est_source, self.L//2)
        return est_source
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops


class WithLossCell(nn.Cell):
    """
    Combine a network and a loss function into a single cell returning the loss.

    The network output is iterated estimate by estimate; each estimate's loss
    is weighted by its 1-based position divided by the number of estimates,
    and the weighted sum is divided by the number of estimates.

    Args:
        net (Cell): The target network to wrap.
        loss_fn (Cell): The loss function used to compute loss.
    """

    def __init__(self, net, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._net = net
        self._loss = loss_fn
        self.cast = ops.Cast()
        self.ones = ops.Ones()
        self.zeros = ops.Zeros()

    def construct(self, padded_mixture, mixture_lengths, padded_source):
        """Return the weighted loss of the network's estimates for one batch."""
        mixture = padded_mixture.astype(mindspore.float32)
        reference = padded_source.astype(mindspore.float32)
        estimates = self._net(mixture).astype(mindspore.float32)

        num_estimates = len(estimates)
        total = 0
        for position, estimate in enumerate(estimates):
            # later estimates get a larger weight
            weight = (position + 1) * (1.0 / num_estimates)
            # only the first of the three values returned by the loss is used
            sisnr_loss, _, _ = self._loss(reference, estimate, mixture_lengths)
            total += (weight * sisnr_loss)
        total /= num_estimates
        return total
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context
from mindspore.communication.management import get_group_size
from mindspore.context import ParallelMode
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel._auto_parallel_context import auto_parallel_context

class TrainOneStep(nn.TrainOneStepCell):
    """One training step: forward, backward, optional all-reduce, optional
    global-norm gradient clipping, then the optimizer update.

    Args:
        network: Cell called with (padded_mixture, mixture_lengths,
            padded_source) that returns the loss.
        optimizer: Optimizer; its ``parameters`` are the weights that get
            gradients.
        sens: Value used to fill the sensitivity tensor passed to the
            gradient operation.
        use_global_norm: If True, clip gradients by global norm.
        clip_global_norm_value: ``clip_norm`` used when clipping.
    """
    def __init__(self, network, optimizer, sens=1.0, use_global_norm=True, clip_global_norm_value=5.0):
        super(TrainOneStep, self).__init__(network, optimizer, sens)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        # gradients w.r.t. the weight list, with an explicit sensitivity input
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = float(sens)
        self.reducer_flag = False
        self.grad_reducer = None
        self.use_global_norm = use_global_norm
        self.clip_global_norm_value = clip_global_norm_value
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        # a gradient reducer is only created for data/hybrid parallel modes
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            # prefer the explicitly configured device_num, else ask the communication group
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)


    def construct(self, padded_mixture, mixture_lengths, padded_source):
        """Run one optimization step on a batch and return the loss."""
        loss = self.network(padded_mixture, mixture_lengths, padded_source)
        # sensitivity tensor with the loss's dtype/shape, filled with self.sens
        sens = P.Fill()(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.network, self.weights)(padded_mixture, mixture_lengths, padded_source, sens)
        if self.reducer_flag:
            grads = self.grad_reducer(grads)
        if self.use_global_norm:
            grads = C.clip_by_global_norm(grads, clip_norm=self.clip_global_norm_value)
        # make the returned loss depend on the optimizer update so it is executed
        loss = ops.depend(loss, self.optimizer(grads))
        return loss
import os
import argparse
import time
from src.data.preprocess import preprocess
from src.models.swave import SWave
from src.data.data import DatasetGenerator
from src.generatorloss import Generatorloss
from src.trainonestep import TrainOneStep
from mindspore import save_checkpoint, load_checkpoint, load_param_into_net
import mindspore.dataset as ds
from mindspore import nn
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size

parser = argparse.ArgumentParser()
parser.add_argument('--in-dir', type=str, default=r"/home/work/user-job-dir/inputs/data/",
                    help='Directory path of LS-2mix including tr, cv and tt')
parser.add_argument('--out-dir', type=str, default=r"/home/work/user-job-dir/inputs/data_json",
                    help='Directory path to put output files')
parser.add_argument('--sample-rate', type=int, default=8000,
                    help='Sample rate of audio file')
parser.add_argument('--data_url',
                    help='path to training/inference dataset folder',
                    default='/home/work/user-job-dir/inputs/data/')
parser.add_argument('--train_url',
                    help='Model folder to save/load',
                    default='/home/work/user-job-dir/model/')
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'GPU', 'CPU'],
    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--segment', type=int, default=4,
                    help='Segment size')
parser.add_argument('--batch_size', type=int, default=6,
                    help='Batch size')
parser.add_argument('--epochs', type=int, default=120,
                    help='Epoch')
parser.add_argument('--device_num', type=int, default=8,
                    help='Device num')
parser.add_argument('--device_id', type=int, default=0,
                    help='Device id')
parser.add_argument('--is_distribute', type=int, default=0,
                    help='Distribute or not')
parser.add_argument('--data_batch_size', type=int, default=3,
                    help='Data num')
parser.add_argument('--train', type=str, default='/home/heu_MEDAI/liwenjie/littledata/tr',
                    help='path to training/inference dataset folder')
# NOTE(review): --lr is parsed but never read below; main() uses a fixed schedule.
parser.add_argument('--lr', type=float, default=1e-3,
                    help='Learning rate')
parser.add_argument('--modelArts', default=0, type=int,
                    help='Cload')
parser.add_argument('--continue_train', default=0, type=int,
                    help='Continue from checkpoint model')
parser.add_argument('--model_type', type=str, default='swave')
parser.add_argument('--save_folder', default='output',
                    help='Location to save epoch models')
parser.add_argument('--ckpt_path', default='1_gdprnn.ckpt',
                    help='Path to model file created by training')
parser.add_argument('--N', default=128, type=int,
                    help='The number of expected features in the input')
parser.add_argument('--L', default=8, type=int,
                    help='kernel sizes')
parser.add_argument('--H', default=128, type=int,
                    help='The hidden size of RNN')
parser.add_argument('--R', default=6, type=int,
                    help='Model layers')
parser.add_argument('--C', default=2, type=int,
                    help='Maximum number of speakers')
parser.add_argument('--sr', default=8000, type=int,
                    help='Sample rate of audio file')
# NOTE(review): type=bool turns any non-empty string (including "False") into
# True, and main() builds the model with input_normalize=False regardless.
parser.add_argument('--input_normalize', default=False, type=bool,
                    help='Normalize or not')


def train(trainoneStep, dataset, train_dir, obs_train_url, args):
    """Run the training loop and save a rotating set of five checkpoints.

    Args:
        trainoneStep: Cell called as ``trainoneStep(mixture, lens, source)``
            returning the loss; its ``network`` attribute is checkpointed.
        dataset: Dict holding the training dataset under ``'tr_loader'``.
        train_dir: Local checkpoint root (used when ``args.modelArts`` is set).
        obs_train_url: Upload destination (used when ``args.modelArts`` is set).
        args: Parsed command-line arguments.
    """
    if args.modelArts:
        # main() imports moxing only in its own scope, so it has to be
        # imported here as well; otherwise the upload below raises NameError.
        import moxing as mox

    tr_loader = dataset['tr_loader']
    step = tr_loader.get_dataset_size()
    for epoch in range(args.epochs):
        total_loss = 0
        j = 0
        for data in tr_loader:
            mixture, lens, source = data
            t0 = time.time()
            loss = trainoneStep(mixture, lens, source)
            t1 = time.time()
            print("epoch[{}]({}/{}),loss:{:.4f},stepTime:{}".format(epoch + 1, j+1, step, loss.asnumpy(), t1 - t0))
            j = j + 1
            total_loss += loss
        train_loss = total_loss/j
        print("epoch[{}]:trainAvgLoss:{:.4f}".format(epoch + 1, train_loss.asnumpy()))

        if args.modelArts:
            # fall back to --device_id when DEVICE_ID is not exported
            device_id = os.getenv('DEVICE_ID', str(args.device_id))
            save_checkpoint_path = train_dir + '/device_' + device_id + '/'
        else:
            save_checkpoint_path = args.save_folder
        os.makedirs(save_checkpoint_path, exist_ok=True)

        # On ModelArts and in single-device runs every process saves; in a
        # local distributed run only device 0 does.
        if args.modelArts or not args.is_distribute or args.device_id == 0:
            # file names cycle through 0..4, so only five checkpoints are kept
            save_ckpt = os.path.join(save_checkpoint_path, '{}_gdprnn.ckpt'.format(epoch % 5))
            save_checkpoint(trainoneStep.network, save_ckpt)

        if args.modelArts:
            mox.file.copy_parallel(train_dir, obs_train_url)
            print("Successfully Upload {} to {}".format(train_dir,
                                                        obs_train_url))


def main(args):
    """Set up the execution context, data, model and optimizer, then train."""
    device_num = int(os.environ.get("RANK_SIZE", 1))
    # Previously left unbound (UnboundLocalError) when RANK_SIZE was below 1.
    is_distributed = device_num > 1

    if is_distributed:
        print("parallel init", flush=True)
        init()
        rank_id = get_rank()
        context.reset_auto_parallel_context()
        parallel_mode = ParallelMode.DATA_PARALLEL
        rank_size = get_group_size()
        context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True,
                                          device_num=args.device_num)
        context.set_auto_parallel_context(parameter_broadcast=True)
        print("Starting traning on multiple devices...")
    else:
        if args.modelArts:
            init()
            rank_id = get_rank()
            rank_size = get_group_size()
        else:
            context.set_context(device_id=args.device_id)

    if args.modelArts:
        import moxing as mox
        home = os.path.dirname(os.path.realpath(__file__))
        obs_data_url = args.data_url
        args.data_url = '/home/work/user-job-dir/inputs/data/'
        train_dir = os.path.join(home, 'checkpoints') + str(rank_id)

        obs_train_url = args.train_url
        os.makedirs(train_dir, exist_ok=True)

        mox.file.copy_parallel(obs_data_url, args.data_url)
        print("Successfully Download {} to {}".format(obs_data_url,
                                                      args.data_url))
        preprocess(args)

    net = SWave(args.N, args.L, args.H, args.R, args.C, args.sr, args.segment, input_normalize=False)

    if args.continue_train:
        if args.modelArts:
            # on ModelArts the checkpoint path is relative to this script
            home = os.path.dirname(os.path.realpath(__file__))
            ckpt = os.path.join(home, args.ckpt_path)
        else:
            ckpt = args.ckpt_path
        params = load_checkpoint(ckpt)
        load_param_into_net(net, params)

    tr_dataset = DatasetGenerator(args.train, args.data_batch_size,
                                  sample_rate=args.sample_rate, segment=args.segment)
    if is_distributed:
        tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"],
                                        shuffle=False, num_shards=rank_size, shard_id=rank_id)
    else:
        tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"],
                                        shuffle=False)
    tr_loader = tr_loader.batch(args.batch_size)
    num_steps = tr_loader.get_dataset_size()
    data = {"tr_loader": tr_loader}

    loss_network = Generatorloss(net)
    # piecewise-constant schedule with boundaries after epochs 45, 78 and 120
    milestone = [45 * num_steps, 78 * num_steps, 120 * num_steps]
    learning_rates = [1e-3, 5e-4, 2e-4]
    lr = nn.piecewise_constant_lr(milestone, learning_rates)
    optimizer = nn.Adam(net.trainable_params(), learning_rate=lr, beta1=0.9, beta2=0.999)
    trainonestepNet = TrainOneStep(loss_network, optimizer, sens=1.0)
    if not args.modelArts:
        # placeholders; train() only reads them when args.modelArts is set
        train_dir = '/home/'
        obs_train_url = '/home/'
    train(trainonestepNet, data, train_dir, obs_train_url, args)


if __name__ == '__main__':
    arg = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=arg.device_target)
    main(arg)