From d4578f9d8f3fe638273ec93cf041e0de02cbb648 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=95=87=E4=BA=AE?= Date: Wed, 6 Apr 2022 09:49:14 +0000 Subject: [PATCH 1/2] update ACL_PyTorch/built-in/cv/Res2Net_v1b_101_for_PyTorch/pth2onnx.py. --- .../Res2Net_v1b_101_for_PyTorch/pth2onnx.py | 67 +++++++++++-------- 1 file changed, 38 insertions(+), 29 deletions(-) diff --git a/ACL_PyTorch/built-in/cv/Res2Net_v1b_101_for_PyTorch/pth2onnx.py b/ACL_PyTorch/built-in/cv/Res2Net_v1b_101_for_PyTorch/pth2onnx.py index 852d370ee6..776b44620a 100644 --- a/ACL_PyTorch/built-in/cv/Res2Net_v1b_101_for_PyTorch/pth2onnx.py +++ b/ACL_PyTorch/built-in/cv/Res2Net_v1b_101_for_PyTorch/pth2onnx.py @@ -24,53 +24,62 @@ parser.add_argument('-m', '--trained_model', default=None, type=str, help='Trained state_dict file path to open') parser.add_argument('-o', '--output', default=None, type=str, help='ONNX model file') +parser.add_argument('--optimizer', default='store_ture', + help='ONNX model optimizer') args = parser.parse_args() model = res2net101_v1b_26w_4s() checkpoint = torch.load(args.trained_model, map_location=torch.device('cpu')) -presistent_buffers = {k: v for k, v in model.named_buffers() if k not in model._non_persistent_buffers_set} -local_name_params = itertools.chain(model.named_parameters(), presistent_buffers.items()) -local_state = {k: v for k, v in local_name_params if v is not None} -for name, param in checkpoint.items(): - if local_state[name].shape != param.shape: - if 'conv1' in name or 'conv3' in name: - n1, c1, h, w = local_state[name].shape - n2, c2, h, w = param.shape - if n1 == n2: - c = (c1 - c2) // 4 - cell = c2 // 4 - checkpoint[name] = torch.cat([torch.cat((param[:, i * cell: (i + 1) * cell, ...], - torch.zeros(n1, c, h, w, dtype=param.dtype)), - 1) for i in range(4)], 1) - else: - n = (n1 - n2) // 4 - cell = n2 // 4 - checkpoint[name] = torch.cat([torch.cat((param[i * cell: (i + 1) * cell, ...], - torch.zeros(n, c1, h, w, dtype=param.dtype)), - 0) for i in range(4)], 0) +def optimizer(model, checkpoint): + presistent_buffers = {k: v for k, v in model.named_buffers() if k not in model._non_persistent_buffers_set} + local_name_params = itertools.chain(model.named_parameters(), presistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in checkpoint.items(): + if local_state[name].shape != param.shape: + if 'conv1' in name or 'conv3' in name: + n1, c1, h, w = local_state[name].shape + n2, c2, h, w = param.shape + if n1 == n2: + c = (c1 - c2) // 4 + cell = c2 // 4 + checkpoint[name] = torch.cat([torch.cat((param[:, i * cell: (i + 1) * cell, ...], + torch.zeros(n1, c, h, w, dtype=param.dtype)), + 1) for i in range(4)], 1) + else: + n = (n1 - n2) // 4 + cell = n2 // 4 + checkpoint[name] = torch.cat([torch.cat((param[i * cell: (i + 1) * cell, ...], + torch.zeros(n, c1, h, w, dtype=param.dtype)), + 0) for i in range(4)], 0) elif 'bn1' in name or 'bn3' in name: cell = param.size(0) // 4 n = (local_state[name].size(0) - param.size(0)) // 4 checkpoint[name] = torch.cat([torch.cat((param[i * cell: (i + 1) * cell], - torch.zeros(n, dtype=param.dtype)), + torch.zeros(n, dtype=param.dtype)), 0) for i in range(4)]) else: if param.dim() == 1: checkpoint[name] = torch.cat((param, - torch.zeros(local_state[name].size(0) - param.size(0), dtype=param.dtype)), - 0) + torch.zeros(local_state[name].size(0) - param.size(0), dtype=param.dtype)), + 0) else: n1, c1, h, w = local_state[name].shape n2, c2, h, w = param.shape param = torch.cat((param, torch.zeros(n2, c1 - c2, h, w, dtype=param.dtype)), 1) checkpoint[name] = torch.cat((param, torch.zeros(n1 - n2, c1, h, w, dtype=param.dtype)), 0) + return checkpoint + -model.load_state_dict(checkpoint) -model.eval() +if __name__ == '__main__': + if args.optimizer: + checkpoint = optimizer(model, checkpoint) + model.load_state_dict(checkpoint) + model.eval() -inputs = torch.rand(1, 3, 224, 224) -torch.onnx.export(model, inputs, args.output, - input_names=["x"], output_names=["output"], - dynamic_axes={"x": {0: "-1"}}, opset_version=11) + inputs = torch.rand(1, 3, 224, 224) + torch.onnx.export(model, inputs, args.output, + input_names=["x"], output_names=["output"], + dynamic_axes={"x": {0: "-1"}}, opset_version=11) -- Gitee From a76379d7d5a1753388c560c06d31ee2da43f6f82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=95=87=E4=BA=AE?= Date: Wed, 6 Apr 2022 09:50:51 +0000 Subject: [PATCH 2/2] =?UTF-8?q?=E6=9B=B4=E6=96=B0readme=E5=90=8D=E5=AD=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cv/Res2Net_v1b_101_for_PyTorch/{ README.md => README.md} | 4 ++++ 1 file changed, 4 insertions(+) rename ACL_PyTorch/built-in/cv/Res2Net_v1b_101_for_PyTorch/{ README.md => README.md} (92%) diff --git a/ACL_PyTorch/built-in/cv/Res2Net_v1b_101_for_PyTorch/ README.md b/ACL_PyTorch/built-in/cv/Res2Net_v1b_101_for_PyTorch/README.md similarity index 92% rename from ACL_PyTorch/built-in/cv/Res2Net_v1b_101_for_PyTorch/ README.md rename to ACL_PyTorch/built-in/cv/Res2Net_v1b_101_for_PyTorch/README.md index ed4d548672..69fa898a11 100644 --- a/ACL_PyTorch/built-in/cv/Res2Net_v1b_101_for_PyTorch/ README.md +++ b/ACL_PyTorch/built-in/cv/Res2Net_v1b_101_for_PyTorch/README.md @@ -44,7 +44,11 @@ git apply diff.patch 通过pth2onnx.py脚本转化为onnx模型 ```shell +# 直接导出原始ONNX python3.7 pth2onnx.py -m ./res2net101_v1b_26w_4s-0812c246.pth -o ./res2net.onnx + +# 导出NPU上优化后的ONNX +python3.7 pth2onnx.py -m ./res2net101_v1b_26w_4s-0812c246.pth -o ./res2net.onnx --optimizer ``` 利用ATC工具转换为om模型 -- Gitee