From 76e8b02e5615a75980c55f904f6f2169e99d8d38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=85=B4=E9=BE=99?= <13651641606@163.com> Date: Wed, 22 Oct 2025 03:03:36 +0000 Subject: [PATCH 1/3] update ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 杨兴龙 <13651641606@163.com> --- .../SwinTransformer_for_Pytorch/pth2onnx.py | 66 +++++++++++++------ 1 file changed, 46 insertions(+), 20 deletions(-) diff --git a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py index 3b5263295e..50b9d81637 100644 --- a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py +++ b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py @@ -17,54 +17,80 @@ import os import argparse import torch import timm +import numpy as np def pth2onnx(args): pth_path = args.input_path batch_size = args.batch_size - model_name = args.model_name out_path = args.out_path - # get size + checkpoint = torch.load(pth_path, map_location='cpu') + + config = checkpoint['config'] + state_dict = checkpoint['model'] + + model_name = config.MODEL.NAME + + model = timm.create_model( + model_name, + pretrained=False, + num_classes=config.MODEL.NUM_CLASSES + ) + + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith('module.'): + new_state_dict[k[7:]] = v + else: + new_state_dict[k] = v + + # 修复relative_position_index的形状不匹配问题 + for key in list(new_state_dict.keys()): + if 'relative_position_index' in key: + # 原始形状是 [2401],需要重塑为 [49, 49] + if new_state_dict[key].shape == torch.Size([2401]): + new_state_dict[key] = new_state_dict[key].view(49, 49) + + model.load_state_dict(new_state_dict, strict=False) + + model.eval() + if 's3' in model_name: - size = int(model_name.split('_')[3]) + input_size = int(model_name.split('_')[3]) else: - size = int(model_name.split('_')[4]) - input_data = torch.randn([batch_size, 3, size, size]).to(torch.float32) - input_names = ["image"] - output_names = ["out"] + input_size = int(model_name.split('_')[4]) - # build model - model = timm.create_model(model_name, checkpoint_path=pth_path) - model.eval() + input_data = torch.randn([batch_size, 3, input_size, input_size], dtype=torch.float32) + print(f"输入数据形状: {input_data.shape}") + + print("开始导出ONNX...") + # 导出ONNX torch.onnx.export( model, input_data, out_path, verbose=True, opset_version=11, - input_names=input_names, - output_names=output_names + input_names=["image"], + output_names=["output"], ) + print(f"✅ ONNX模型已保存到: {out_path}") + def parse_arguments(): - parser = argparse.ArgumentParser(description='SwinTransformer onnx export.') + parser = argparse.ArgumentParser(description='Convert Swin-Tiny pth to onnx') parser.add_argument('-i', '--input_path', type=str, required=True, help='input path for pth model') parser.add_argument('-o', '--out_path', type=str, required=True, help='save path for output onnx model') - parser.add_argument('-n', '--model_name', type=str, default='swin_base_patch4_window12_384', - help='model name for swintransformer') parser.add_argument('-b', '--batch_size', type=int, default=1, help='batch size for output model') - args = parser.parse_args() - args.out_path = os.path.abspath(args.out_path) - os.makedirs(os.path.dirname(args.out_path), exist_ok=True) - return args - + return parser.parse_args() if __name__ == '__main__': args = parse_arguments() pth2onnx(args) + -- Gitee From d7c7234a63939b2d42064b7df84211bf0280692b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=85=B4=E9=BE=99?= <13651641606@163.com> Date: Wed, 22 Oct 2025 03:24:30 +0000 Subject: [PATCH 2/3] update ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 杨兴龙 <13651641606@163.com> --- .../built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py index 50b9d81637..160abe6e24 100644 --- a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py +++ b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py @@ -44,13 +44,13 @@ def pth2onnx(args): new_state_dict[k[7:]] = v else: new_state_dict[k] = v - + # 修复relative_position_index的形状不匹配问题 for key in list(new_state_dict.keys()): if 'relative_position_index' in key: - # 原始形状是 [2401],需要重塑为 [49, 49] - if new_state_dict[key].shape == torch.Size([2401]): - new_state_dict[key] = new_state_dict[key].view(49, 49) + tensor_value = new_state_dict.get(key) + if tensor_value is not None and tensor_value.shape == torch.Size([2401]): + new_state_dict[key] = tensor_value.view(49, 49) model.load_state_dict(new_state_dict, strict=False) -- Gitee From d1fd6e56bbe6ca3158f166ee6e13dade72050a83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=85=B4=E9=BE=99?= <13651641606@163.com> Date: Wed, 22 Oct 2025 08:29:51 +0000 Subject: [PATCH 3/3] update ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 杨兴龙 <13651641606@163.com> --- .../SwinTransformer_for_Pytorch/pth2onnx.py | 134 +++++++++++------- 1 file changed, 86 insertions(+), 48 deletions(-) diff --git a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py index 160abe6e24..4db88e7b04 100644 --- a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py +++ b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py @@ -24,73 +24,111 @@ def pth2onnx(args): pth_path = args.input_path batch_size = args.batch_size out_path = args.out_path + model_name = args.model_name + is_open_source = args.open_source - checkpoint = torch.load(pth_path, map_location='cpu') - - config = checkpoint['config'] - state_dict = checkpoint['model'] - - model_name = config.MODEL.NAME - - model = timm.create_model( - model_name, - pretrained=False, - num_classes=config.MODEL.NUM_CLASSES - ) - - new_state_dict = {} - for k, v in state_dict.items(): - if k.startswith('module.'): - new_state_dict[k[7:]] = v + if is_open_source: + if 's3' in model_name: + size = int(model_name.split('_')[3]) else: - new_state_dict[k] = v - - # 修复relative_position_index的形状不匹配问题 - for key in list(new_state_dict.keys()): - if 'relative_position_index' in key: - tensor_value = new_state_dict.get(key) - if tensor_value is not None and tensor_value.shape == torch.Size([2401]): - new_state_dict[key] = tensor_value.view(49, 49) + size = int(model_name.split('_')[4]) - model.load_state_dict(new_state_dict, strict=False) + input_data = torch.randn([batch_size, 3, size, size]).to(torch.float32) + input_names = ["image"] + output_names = ["output"] - model.eval() + model = timm.create_model(model_name, checkpoint_path=pth_path) + model.eval() - if 's3' in model_name: - input_size = int(model_name.split('_')[3]) - else: - input_size = int(model_name.split('_')[4]) - - input_data = torch.randn([batch_size, 3, input_size, input_size], dtype=torch.float32) - print(f"输入数据形状: {input_data.shape}") + torch.onnx.export( + model, + input_data, + out_path, + verbose=True, + opset_version=11, + input_names=input_names, + output_names=output_names + ) - print("开始导出ONNX...") + else: + checkpoint = torch.load(pth_path, map_location='cpu') + + if 'config' in checkpoint and 'model' in checkpoint: + config = checkpoint['config'] + state_dict = checkpoint['model'] + if hasattr(config, 'MODEL') and hasattr(config.MODEL, 'NAME'): + model_name_from_config = config.MODEL.NAME + if model_name == "default": + model_name = model_name_from_config + else: + raise ValueError("Checkpoint文件结构不符合预期,应包含'config'和'model'键") + + model = timm.create_model( + model_name, + pretrained=False, + num_classes=config.MODEL.NUM_CLASSES + ) + + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith('module.'): + new_state_dict[k[7:]] = v + else: + new_state_dict[k] = v + + # 修复relative_position_index的形状不匹配问题 + for key in list(new_state_dict.keys()): + if 'relative_position_index' in key: + tensor_value = new_state_dict.get(key) + if tensor_value is not None and tensor_value.shape == torch.Size([2401]): + new_state_dict[key] = tensor_value.view(49, 49) + + model.load_state_dict(new_state_dict, strict=False) + model.eval() + + if 's3' in model_name: + input_size = int(model_name.split('_')[3]) + else: + input_size = int(model_name.split('_')[4]) - # 导出ONNX - torch.onnx.export( - model, - input_data, - out_path, - verbose=True, - opset_version=11, - input_names=["image"], - output_names=["output"], - ) + input_data = torch.randn([batch_size, 3, input_size, input_size], dtype=torch.float32) + input_names = ["image"] + output_names = ["output"] - print(f"✅ ONNX模型已保存到: {out_path}") + torch.onnx.export( + model, + input_data, + out_path, + verbose=True, + opset_version=11, + input_names=input_names, + output_names=output_names, + ) def parse_arguments(): - parser = argparse.ArgumentParser(description='Convert Swin-Tiny pth to onnx') + parser = argparse.ArgumentParser(description='Convert SwinTransformer pth to onnx') parser.add_argument('-i', '--input_path', type=str, required=True, help='input path for pth model') parser.add_argument('-o', '--out_path', type=str, required=True, help='save path for output onnx model') + parser.add_argument('-n', '--model_name', type=str, default='default', + help='model name for swintransformer (e.g., swin_base_patch4_window12_384)') parser.add_argument('-b', '--batch_size', type=int, default=1, help='batch size for output model') - return parser.parse_args() + parser.add_argument('--open_source', action='store_false', + help='whether the model is from open source (use timm direct loading)') + args = parser.parse_args() + + os.makedirs(os.path.dirname(os.path.abspath(args.out_path)), exist_ok=True) + + return args + if __name__ == '__main__': args = parse_arguments() + if args.open_source and args.model_name == 'default': + raise ValueError("使用开源模型时必须指定 --model_name 参数") + pth2onnx(args) -- Gitee