diff --git a/ACL_PyTorch/contrib/cv/segmentation/Maskrcnn-mmdet/mmdet_postprocess.py b/ACL_PyTorch/contrib/cv/segmentation/Maskrcnn-mmdet/mmdet_postprocess.py index dbc978e895b9cfef9ea30a5ce7d3b0a952784054..beee9b5f05bc19b838e41fd19f8d3540a909628a 100644 --- a/ACL_PyTorch/contrib/cv/segmentation/Maskrcnn-mmdet/mmdet_postprocess.py +++ b/ACL_PyTorch/contrib/cv/segmentation/Maskrcnn-mmdet/mmdet_postprocess.py @@ -54,6 +54,7 @@ def postprocess_masks(masks, image_size, net_input_width, net_input_height): ws = int(pad_left + net_input_width - pad_w) masks = masks.to(dtype=torch.float32) res_append = torch.zeros(0, h, w) + tmp=[res_append] if torch.cuda.is_available(): res_append = res_append.to(device='cuda') for i in range(masks.size(0)): @@ -62,8 +63,11 @@ def postprocess_masks(masks, image_size, net_input_width, net_input_height): mask = F.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False) mask = mask[0][0] mask = mask.unsqueeze(0) - res_append = torch.cat((res_append, mask)) - + #res_append = torch.cat((res_append, mask)) + tmp.append(mask) + + tmp=tuple(tmp) + res_append = torch.cat(tmp) return res_append[:, None] import pickle