Ai
2 Star 20 Fork 12

monkey_cici/mmdetection

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
diffusiondet_resnet_to_mmdet.py 3.32 KB
一键复制 编辑 原始数据 按行查看 历史
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
import numpy as np
import torch
from mmengine.fileio import load
def _convert_key(k):
    """Translate one Detectron2 DiffusionDet key into its MMDetection name.

    Args:
        k (str): Parameter name taken from the source checkpoint.

    Returns:
        str | None: The renamed key, or ``None`` when the key has no
        mapping. In the ``None`` case a message has already been printed
        and the caller is expected to skip the entry.
    """
    parts = k.split('.')
    if 'backbone.fpn_lateral' in k:
        # The trailing digit of e.g. ``fpn_lateral2`` is shifted down by 2
        # to obtain the index into ``neck.lateral_convs``.
        lateral_id = int(parts[-2][-1])
        return f'neck.lateral_convs.{lateral_id - 2}.conv.{parts[-1]}'
    if 'backbone.fpn_output' in k:
        lateral_id = int(parts[-2][-1])
        return f'neck.fpn_convs.{lateral_id - 2}.conv.{parts[-1]}'
    # The norm check must come before the plain conv1 check, because the
    # conv1 pattern is a prefix of the norm pattern.
    if 'backbone.bottom_up.stem.conv1.norm.' in k:
        return f'backbone.bn1.{parts[-1]}'
    if 'backbone.bottom_up.stem.conv1.' in k:
        return f'backbone.conv1.{parts[-1]}'
    if 'backbone.bottom_up.res' in k:
        # ``resN`` in the source maps to ``layer{N - 1}`` in the target.
        res_id = int(parts[2][-1]) - 1
        prefix = f'backbone.layer{res_id}.{parts[3]}'
        # deal with short cut
        if 'shortcut' in parts[4]:
            if parts[-2] == 'shortcut':
                # ``shortcut.<param>`` -> first module of ``downsample``
                return f'{prefix}.downsample.0.{parts[-1]}'
            if parts[-3] == 'shortcut':
                # ``shortcut.norm.<param>`` -> second module of ``downsample``
                return f'{prefix}.downsample.1.{parts[-1]}'
            print(f'Unvalid key {k}')
            return None
        # deal with conv
        if 'conv' in parts[-2]:
            conv_id = int(parts[-2][-1])
            return f'{prefix}.conv{conv_id}.{parts[-1]}'
        # deal with BN
        if parts[-2] == 'norm':
            conv_id = int(parts[-3][-1])
            return f'{prefix}.bn{conv_id}.{parts[-1]}'
        print(f'{k} is invalid')
        return None
    if parts[0] == 'head':
        # d2: head.xxx -> mmdet: bbox_head.xxx
        return f'bbox_{k}'
    # some base parameters such as beta will not convert
    print(f'{k} is not converted!!')
    return None


def convert(src, dst):
    """Convert a Detectron2 DiffusionDet checkpoint to MMDetection key names.

    The source checkpoint must hold its parameters under the ``'model'``
    key. Every parameter whose name can be mapped is stored under the new
    name; parameters without a mapping are reported on stdout and skipped.
    The result is saved to ``dst`` as
    ``dict(state_dict=<OrderedDict>, meta=dict())``.

    Args:
        src (str): Path of the source checkpoint. Paths ending in ``pth``
            are read with ``torch.load``; anything else is read with
            ``mmengine.fileio.load``.
        dst (str): Path the converted checkpoint is written to.

    Raises:
        ValueError: If a convertible entry is neither a ``numpy.ndarray``
            nor a ``torch.Tensor``.
    """
    if src.endswith('pth'):
        # Load onto the CPU so that checkpoints saved from a GPU can be
        # converted on machines without CUDA.
        src_model = torch.load(src, map_location='cpu')
    else:
        src_model = load(src)

    dst_state_dict = OrderedDict()
    for k, v in src_model['model'].items():
        name = _convert_key(k)
        if name is None:
            # Skip unmapped keys; otherwise the value would be written
            # under the name left over from a previous iteration.
            continue
        if not isinstance(v, (np.ndarray, torch.Tensor)):
            raise ValueError(
                f'Unsupported type found in checkpoint! {k}: {type(v)}')
        if isinstance(v, torch.Tensor):
            dst_state_dict[name] = v
        else:
            dst_state_dict[name] = torch.from_numpy(v)

    mmdet_model = dict(state_dict=dst_state_dict, meta=dict())
    torch.save(mmdet_model, dst)
def main():
    """Read the ``src`` and ``dst`` paths from the CLI and run ``convert``."""
    arg_parser = argparse.ArgumentParser(description='Convert model keys')
    # Both arguments are required positionals, registered in order.
    positional_args = (
        ('src', 'src detectron model path'),
        ('dst', 'save path'),
    )
    for arg_name, arg_help in positional_args:
        arg_parser.add_argument(arg_name, help=arg_help)
    cli_args = arg_parser.parse_args()
    convert(cli_args.src, cli_args.dst)
# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/monkeycc/mmdetection.git
git@gitee.com:monkeycc/mmdetection.git
monkeycc
mmdetection
mmdetection
main

搜索帮助