代码拉取完成,页面将自动刷新
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from collections import OrderedDict
import numpy as np
import torch
from mmengine.fileio import load
def convert(src, dst):
    """Convert a detectron2 checkpoint into an MMDetection checkpoint.

    Every entry of ``src_model['model']`` is renamed to its MMDetection
    equivalent (see ``_convert_key``). Keys that are deliberately not
    converted, or that match no known pattern, are reported on stdout and
    skipped. They are never written under a name left over from a
    previously processed key.

    Args:
        src (str): Path of the detectron2 checkpoint. Paths ending in
            ``pth`` are read with ``torch.load`` (mapped to CPU). Anything
            else is read with mmengine's ``load``. The loaded object must
            contain a ``'model'`` mapping of parameter name -> value.
        dst (str): Output path. The result is written with ``torch.save``
            as ``dict(state_dict=OrderedDict(...), meta=dict())``.

    Raises:
        ValueError: If a converted entry is neither a ``numpy.ndarray`` nor
            a ``torch.Tensor``.
    """
    if src.endswith('pth'):
        # Map to CPU so checkpoints saved on GPU also load without CUDA.
        src_model = torch.load(src, map_location='cpu')
    else:
        src_model = load(src)

    dst_state_dict = OrderedDict()
    for k, v in src_model['model'].items():
        name = _convert_key(k)
        if name is None:
            # Unmapped key: skip it. Reusing the previous iteration's name
            # would overwrite an already-converted parameter.
            continue
        if not isinstance(v, (np.ndarray, torch.Tensor)):
            raise ValueError(
                'Unsupported type found in checkpoint! {}: {}'.format(
                    k, type(v)))
        if isinstance(v, np.ndarray):
            v = torch.from_numpy(v)
        dst_state_dict[name] = v

    mmdet_model = dict(state_dict=dst_state_dict, meta=dict())
    torch.save(mmdet_model, dst)


def _convert_key(k):
    """Return the MMDetection name for detectron2 parameter name ``k``.

    Returns ``None`` (after printing a message) when ``k`` is not converted
    or does not match any handled pattern.
    """
    key_name_split = k.split('.')
    if 'backbone.fpn_lateral' in k:
        # e.g. backbone.fpn_lateral3.weight -> neck.lateral_convs.1.conv.weight
        # The trailing level digit is shifted down by 2 to form the index.
        lateral_id = int(key_name_split[-2][-1])
        return f'neck.lateral_convs.{lateral_id - 2}.' \
               f'conv.{key_name_split[-1]}'
    if 'backbone.fpn_output' in k:
        # e.g. backbone.fpn_output4.bias -> neck.fpn_convs.2.conv.bias
        lateral_id = int(key_name_split[-2][-1])
        return f'neck.fpn_convs.{lateral_id - 2}.conv.' \
               f'{key_name_split[-1]}'
    if 'backbone.bottom_up.stem.conv1.norm.' in k:
        # Must be tested before the plain stem-conv pattern, which also
        # matches these keys.
        return f'backbone.bn1.{key_name_split[-1]}'
    if 'backbone.bottom_up.stem.conv1.' in k:
        return f'backbone.conv1.{key_name_split[-1]}'
    if 'backbone.bottom_up.res' in k:
        return _convert_res_key(k, key_name_split)
    if key_name_split[0] == 'head':
        # d2: head.xxx -> mmdet: bbox_head.xxx
        return f'bbox_{k}'
    # some base parameters such as beta will not convert
    print(f'{k} is not converted!!')
    return None


def _convert_res_key(k, key_name_split):
    """Map a ``backbone.bottom_up.resN.<block>.*`` key onto ``backbone.layer*``.

    ``key_name_split`` is ``k.split('.')``. Returns ``None`` (after printing
    a message) when the key is not a recognised shortcut, conv or norm
    parameter.
    """
    # resN -> layer{N-1}, e.g. res2 -> layer1
    res_id = int(key_name_split[2][-1]) - 1
    block_id = key_name_split[3]
    param = key_name_split[-1]
    # deal with short cut
    if 'shortcut' in key_name_split[4]:
        if 'shortcut' == key_name_split[-2]:
            # <block>.shortcut.<param> -> downsample.0
            return f'backbone.layer{res_id}.{block_id}.downsample.0.{param}'
        if 'shortcut' == key_name_split[-3]:
            # <block>.shortcut.<sub>.<param> (presumably the shortcut norm)
            # -> downsample.1
            return f'backbone.layer{res_id}.{block_id}.downsample.1.{param}'
        print(f'Unvalid key {k}')
        return None
    # deal with conv: <block>.convK.<param> -> convK
    if 'conv' in key_name_split[-2]:
        conv_id = int(key_name_split[-2][-1])
        return f'backbone.layer{res_id}.{block_id}.conv{conv_id}.{param}'
    # deal with BN: <block>.convK.norm.<param> -> bnK
    if key_name_split[-2] == 'norm':
        conv_id = int(key_name_split[-3][-1])
        return f'backbone.layer{res_id}.{block_id}.bn{conv_id}.{param}'
    print(f'{k} is invalid')
    return None
def main():
    """Parse the ``src``/``dst`` positional arguments and run ``convert``."""
    arg_parser = argparse.ArgumentParser(description='Convert model keys')
    for arg_name, arg_help in (('src', 'src detectron model path'),
                               ('dst', 'save path')):
        arg_parser.add_argument(arg_name, help=arg_help)
    cli_args = arg_parser.parse_args()
    convert(cli_args.src, cli_args.dst)
if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script.
    main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。