Fetch the repository succeeded.
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import torch
import torchvision
from torchvision import datasets, transforms
from PIL import Image
from lib.PraNet_Res2Net import PraNet
import torch.nn.functional as F
import numpy as np
class test_dataset:
    """Load a single image / ground-truth mask pair for a quick inference check.

    The dataset root is expected to contain an ``images/`` and a ``masks/``
    sub-directory holding files with the same name, for example
    ``./data/TestDataset/Kvasir``.
    """

    def __init__(self, path, testsize=352,
                 image_name='cju0u82z3cuma0835wlxrnrjv.png'):
        """Record the sample's file paths and build the image preprocessing.

        Args:
            path: Dataset root containing ``images/`` and ``masks/``.
            testsize: Edge length (pixels) the input image is resized to.
            image_name: File name shared by the image and its mask. The
                default is the sample this script was originally written for.
        """
        self.testsize = testsize
        self.data_path = '{}/images/{}'.format(path, image_name)
        self.gt_path = '{}/masks/{}'.format(path, image_name)
        # Resize to a square, convert to a tensor, then normalise each
        # channel with the given per-channel mean / std.
        self.transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])])
        # Not applied by load_data (the mask is returned as a PIL image);
        # kept as a public attribute for callers that may rely on it.
        self.gt_transform = transforms.ToTensor()

    def load_data(self):
        """Return ``(image, gt)`` for the configured sample.

        ``image`` is the transformed RGB image with a leading batch
        dimension added; ``gt`` is the mask as an untransformed
        single-channel (mode ``'L'``) PIL image.
        """
        image = self.rgb_loader(self.data_path)
        image = self.transform(image).unsqueeze(0)
        gt = self.binary_loader(self.gt_path)
        return image, gt

    def rgb_loader(self, path):
        """Open the file at ``path`` and return it as an RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # Convert while the file handle is still open.
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open the file at ``path`` and return it as a mode-'L' PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # Convert while the file handle is still open.
            return img.convert('L')
def test():
    """Run PraNet on one Kvasir test image on NPU 0 and print the result.

    Steps: load the checkpoint on CPU, clean its keys, move the model to the
    NPU, run a single forward pass, resize the fourth model output to the
    ground-truth mask's size, apply a sigmoid and min-max normalise the
    resulting map before printing it.
    """
    path = './data/TestDataset/Kvasir'
    loc = 'npu:0'
    torch.npu.set_device(loc)

    model = PraNet()
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    pretrained_dict = torch.load("./snapshots/PraNet_Res2Net/PraNet-19.pth", map_location="cpu")

    # Strip only a *leading* 'module.' (as added by DataParallel wrappers) so
    # keys that merely contain that substring elsewhere are left intact.
    prefix = 'module.'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in pretrained_dict.items()
    }
    # Drop the classifier-head entries *before* the strict load; otherwise a
    # checkpoint that carries them fails with "unexpected key" errors.
    for key in ('fc.weight', 'fc.bias'):
        state_dict.pop(key, None)
    model.load_state_dict(state_dict)

    model = model.to(loc)
    model.eval()

    test_loader = test_dataset(path)
    image, gt = test_loader.load_data()
    # The mask is only needed for its (H, W) shape below; it is scaled by its
    # maximum, with an epsilon guarding against an all-zero mask.
    gt = np.asarray(gt, np.float32)
    gt /= (gt.max() + 1e-8)
    image = image.to(loc)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        # The model returns four maps; the last one is used as the prediction.
        _, _, _, res2 = model(image)
        # F.upsample is deprecated; F.interpolate takes the same arguments.
        res = F.interpolate(res2, size=gt.shape, mode='bilinear', align_corners=False)
        res = res.sigmoid().cpu().numpy().squeeze()

    # Min-max normalise the map; the epsilon avoids division by zero when the
    # map is constant.
    res = (res - res.min()) / (res.max() - res.min() + 1e-8)
    print('res is ', res)
# Script entry point: run the single-image inference check only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    test()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。