import colorsys
import copy
import time
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from nets.unet import Unet as unet
from utils.utils import cvtColor, preprocess_input, resize_image, show_config
#--------------------------------------------#
#   To predict with your own trained model, two
#   parameters must be changed: model_path and
#   num_classes. If a shape mismatch occurs,
#   make sure model_path and num_classes match
#   the values used during training.
#--------------------------------------------#
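#--------------------------------------------#
#   A minimal, hypothetical usage sketch (the weight path, class
#   count and image path below are placeholders, not files that
#   necessarily exist in this repository):
#       unet   = Unet(model_path='logs/best_epoch_weights.pth', num_classes=2)
#       result = unet.detect_image(Image.open('img/street.jpg'))
#--------------------------------------------#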
class Unet(object):
    _defaults = {
        #-------------------------------------------------------------------#
        #   model_path points to a weight file in the logs folder.
        #   After training, the logs folder contains several weight files;
        #   pick one with a low validation loss. A low validation loss does
        #   not guarantee a high miou, only that those weights generalize
        #   well on the validation set.
        #-------------------------------------------------------------------#
        "model_path"    : 'model_data/unet_vgg_voc.pth',
        #--------------------------------#
        #   Number of classes to separate, plus 1 (background)
        #--------------------------------#
        "num_classes"   : 21,
        #--------------------------------#
        #   Backbone network to use: vgg or resnet50
        #--------------------------------#
        "backbone"      : "vgg",
        #--------------------------------#
        #   Input image size
        #--------------------------------#
        "input_shape"   : [512, 512],
        #-------------------------------------------------#
        #   mix_type controls how the detection result is visualized.
        #
        #   mix_type = 0: blend the original image with the segmentation map
        #   mix_type = 1: keep only the segmentation map
        #   mix_type = 2: remove the background, keep only the foreground of the original image
        #-------------------------------------------------#
        "mix_type"      : 0,
        #--------------------------------#
        #   Whether to use CUDA.
        #   Set to False if no GPU is available.
        #--------------------------------#
        "cuda"          : True,
    }
    #---------------------------------------------------#
    #   Initialize the UNet
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        #---------------------------------------------------#
        #   Set a different color for each class
        #---------------------------------------------------#
        if self.num_classes <= 21:
            self.colors = [ (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
                            (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
                            (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
                            (128, 64, 12)]
        else:
            hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
            self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
            self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
        #---------------------------------------------------#
        #   Build the model
        #---------------------------------------------------#
        self.generate()

        show_config(**self._defaults)
    #---------------------------------------------------#
    #   Build the network and load the trained weights
    #---------------------------------------------------#
    def generate(self, onnx=False):
        self.net = unet(num_classes=self.num_classes, backbone=self.backbone)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))
        if not onnx:
            if self.cuda:
                self.net = nn.DataParallel(self.net)
                self.net = self.net.cuda()
    #---------------------------------------------------#
    #   Detect an image
    #---------------------------------------------------#
    def detect_image(self, image, count=False, name_classes=None):
        #---------------------------------------------------------#
        #   Convert the image to RGB here to avoid errors when
        #   predicting on grayscale images. The code only supports
        #   RGB prediction; all other image types are converted to RGB.
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------#
        #   Keep a copy of the input image for drawing later
        #---------------------------------------------------#
        old_img     = copy.deepcopy(image)
        orininal_h  = np.array(image).shape[0]
        orininal_w  = np.array(image).shape[1]
        #---------------------------------------------------------#
        #   Pad the image with gray bars for a distortion-free resize.
        #   A plain resize would also work.
        #---------------------------------------------------------#
        image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0]))
        #---------------------------------------------------------#
        #   Add the batch_size dimension
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()

            #---------------------------------------------------#
            #   Feed the image into the network for prediction
            #---------------------------------------------------#
            pr = self.net(images)[0]
            #---------------------------------------------------#
            #   Take the per-pixel class probabilities
            #---------------------------------------------------#
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()
            #--------------------------------------#
            #   Crop off the gray bars
            #--------------------------------------#
            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
            #---------------------------------------------------#
            #   Resize back to the original image size
            #---------------------------------------------------#
            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation=cv2.INTER_LINEAR)
            #---------------------------------------------------#
            #   Take the class of each pixel
            #---------------------------------------------------#
            pr = pr.argmax(axis=-1)

        #---------------------------------------------------------#
        #   Count pixels per class
        #---------------------------------------------------------#
        if count:
            classes_nums        = np.zeros([self.num_classes])
            total_points_num    = orininal_h * orininal_w
            print('-' * 63)
            print("|%25s | %15s | %15s|" % ("Key", "Value", "Ratio"))
            print('-' * 63)
            for i in range(self.num_classes):
                num     = np.sum(pr == i)
                ratio   = num / total_points_num * 100
                if num > 0:
                    print("|%25s | %15s | %14.2f%%|" % (str(name_classes[i]), str(num), ratio))
                    print('-' * 63)
                classes_nums[i] = num
            print("classes_nums:", classes_nums)

        if self.mix_type == 0:
            # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
            # for c in range(self.num_classes):
            #     seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
            #     seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
            #     seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')
            seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
            #------------------------------------------------#
            #   Convert the new image to PIL Image format
            #------------------------------------------------#
            image   = Image.fromarray(np.uint8(seg_img))
            #------------------------------------------------#
            #   Blend the new image with the original image
            #------------------------------------------------#
            image   = Image.blend(old_img, image, 0.7)

        elif self.mix_type == 1:
            # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
            # for c in range(self.num_classes):
            #     seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
            #     seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
            #     seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')
            seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
            #------------------------------------------------#
            #   Convert the new image to PIL Image format
            #------------------------------------------------#
            image   = Image.fromarray(np.uint8(seg_img))

        elif self.mix_type == 2:
            seg_img = (np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8')
            #------------------------------------------------#
            #   Convert the new image to PIL Image format
            #------------------------------------------------#
            image   = Image.fromarray(np.uint8(seg_img))

        return image
    def get_FPS(self, image, test_interval):
        #---------------------------------------------------------#
        #   Convert the image to RGB here to avoid errors when
        #   predicting on grayscale images. The code only supports
        #   RGB prediction; all other image types are converted to RGB.
        #---------------------------------------------------------#
        image = cvtColor(image)
        #---------------------------------------------------------#
        #   Pad the image with gray bars for a distortion-free resize.
        #   A plain resize would also work.
        #---------------------------------------------------------#
        image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0]))
        #---------------------------------------------------------#
        #   Add the batch_size dimension
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()

            #---------------------------------------------------#
            #   Feed the image into the network for prediction
            #---------------------------------------------------#
            pr = self.net(images)[0]
            #---------------------------------------------------#
            #   Take the class of each pixel
            #---------------------------------------------------#
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy().argmax(axis=-1)
            #--------------------------------------#
            #   Crop off the gray bars
            #--------------------------------------#
            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]

        t1 = time.time()
        for _ in range(test_interval):
            with torch.no_grad():
                #---------------------------------------------------#
                #   Feed the image into the network for prediction
                #---------------------------------------------------#
                pr = self.net(images)[0]
                #---------------------------------------------------#
                #   Take the class of each pixel
                #---------------------------------------------------#
                pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy().argmax(axis=-1)
                #--------------------------------------#
                #   Crop off the gray bars
                #--------------------------------------#
                pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                        int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]

        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time
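    #---------------------------------------------------#
    #   Note: the value returned above is the average time in seconds
    #   per forward pass, so frames per second is its reciprocal.
    #   A minimal sketch (test_interval value is only an example):
    #       tact_time = unet.get_FPS(image, test_interval=100)
    #       print('{:.3f} seconds/image, {:.2f} FPS'.format(tact_time, 1 / tact_time))
    #---------------------------------------------------#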
    def convert_to_onnx(self, simplify, model_path):
        import onnx
        self.generate(onnx=True)

        im                  = torch.zeros(1, 3, *self.input_shape).to('cpu')  # image size(1, 3, 512, 512) BCHW
        input_layer_names   = ["images"]
        output_layer_names  = ["output"]

        # Export the model
        print(f'Starting export with onnx {onnx.__version__}.')
        torch.onnx.export(self.net,
                          im,
                          f                   = model_path,
                          verbose             = False,
                          opset_version       = 12,
                          training            = torch.onnx.TrainingMode.EVAL,
                          do_constant_folding = True,
                          input_names         = input_layer_names,
                          output_names        = output_layer_names,
                          dynamic_axes        = None)

        # Checks
        model_onnx = onnx.load(model_path)      # load onnx model
        onnx.checker.check_model(model_onnx)    # check onnx model

        # Simplify onnx
        if simplify:
            import onnxsim
            print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
            model_onnx, check = onnxsim.simplify(
                model_onnx,
                dynamic_input_shape=False,
                input_shapes=None)
            assert check, 'assert check failed'
            onnx.save(model_onnx, model_path)

        print('ONNX model saved as {}'.format(model_path))
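    #---------------------------------------------------#
    #   A hedged usage sketch for the export above (the output path
    #   is only the repository's default, adjust as needed):
    #       unet = Unet()
    #       unet.convert_to_onnx(simplify=True, model_path='model_data/models.onnx')
    #   The exported graph can then be loaded by Unet_ONNX below.
    #---------------------------------------------------#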
    def get_miou_png(self, image):
        #---------------------------------------------------------#
        #   Convert the image to RGB here to avoid errors when
        #   predicting on grayscale images. The code only supports
        #   RGB prediction; all other image types are converted to RGB.
        #---------------------------------------------------------#
        image       = cvtColor(image)
        orininal_h  = np.array(image).shape[0]
        orininal_w  = np.array(image).shape[1]
        #---------------------------------------------------------#
        #   Pad the image with gray bars for a distortion-free resize.
        #   A plain resize would also work.
        #---------------------------------------------------------#
        image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0]))
        #---------------------------------------------------------#
        #   Add the batch_size dimension
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()

            #---------------------------------------------------#
            #   Feed the image into the network for prediction
            #---------------------------------------------------#
            pr = self.net(images)[0]
            #---------------------------------------------------#
            #   Take the per-pixel class probabilities
            #---------------------------------------------------#
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()
            #--------------------------------------#
            #   Crop off the gray bars
            #--------------------------------------#
            pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                    int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
            #---------------------------------------------------#
            #   Resize back to the original image size
            #---------------------------------------------------#
            pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation=cv2.INTER_LINEAR)
            #---------------------------------------------------#
            #   Take the class of each pixel
            #---------------------------------------------------#
            pr = pr.argmax(axis=-1)

        image = Image.fromarray(np.uint8(pr))
        return image
class Unet_ONNX(object):
    _defaults = {
        #--------------------------------------------------------------------------#
        #   onnx_path points to an onnx weight file in the model_data folder
        #--------------------------------------------------------------------------#
        "onnx_path"     : 'model_data/models.onnx',
        #--------------------------------#
        #   Number of classes to separate, plus 1 (background)
        #--------------------------------#
        "num_classes"   : 21,
        #--------------------------------#
        #   Backbone network to use: vgg or resnet50
        #--------------------------------#
        "backbone"      : "vgg",
        #--------------------------------#
        #   Input image size
        #--------------------------------#
        "input_shape"   : [512, 512],
        #-------------------------------------------------#
        #   mix_type controls how the detection result is visualized.
        #
        #   mix_type = 0: blend the original image with the segmentation map
        #   mix_type = 1: keep only the segmentation map
        #   mix_type = 2: remove the background, keep only the foreground of the original image
        #-------------------------------------------------#
        "mix_type"      : 0,
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialize the ONNX UNet
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
            self._defaults[name] = value
        import onnxruntime
        self.onnx_session   = onnxruntime.InferenceSession(self.onnx_path)
        # Get the names of all input nodes
        self.input_name     = self.get_input_name()
        # Get the names of all output nodes
        self.output_name    = self.get_output_name()

        #---------------------------------------------------#
        #   Set a different color for each class
        #---------------------------------------------------#
        if self.num_classes <= 21:
            self.colors = [ (0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128),
                            (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),
                            (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
                            (128, 64, 12)]
        else:
            hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
            self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
            self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))

        show_config(**self._defaults)

    def get_input_name(self):
        # Get the names of all input nodes
        input_name = []
        for node in self.onnx_session.get_inputs():
            input_name.append(node.name)
        return input_name

    def get_output_name(self):
        # Get the names of all output nodes
        output_name = []
        for node in self.onnx_session.get_outputs():
            output_name.append(node.name)
        return output_name

    def get_input_feed(self, image_tensor):
        # Build the input feed dict keyed by the input names
        input_feed = {}
        for name in self.input_name:
            input_feed[name] = image_tensor
        return input_feed

    #---------------------------------------------------#
    #   Resize the input image with gray-bar padding
    #---------------------------------------------------#
    def resize_image(self, image, size):
        iw, ih  = image.size
        w, h    = size

        scale   = min(w / iw, h / ih)
        nw      = int(iw * scale)
        nh      = int(ih * scale)

        image       = image.resize((nw, nh), Image.BICUBIC)
        new_image   = Image.new('RGB', size, (128, 128, 128))
        new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))

        return new_image, nw, nh

    #---------------------------------------------------#
    #   Detect an image
    #---------------------------------------------------#
    def detect_image(self, image, count=False, name_classes=None):
        #---------------------------------------------------------#
        #   Convert the image to RGB here to avoid errors when
        #   predicting on grayscale images. The code only supports
        #   RGB prediction; all other image types are converted to RGB.
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------#
        #   Keep a copy of the input image for drawing later
        #---------------------------------------------------#
        old_img     = copy.deepcopy(image)
        orininal_h  = np.array(image).shape[0]
        orininal_w  = np.array(image).shape[1]
        #---------------------------------------------------------#
        #   Pad the image with gray bars for a distortion-free resize.
        #   A plain resize would also work.
        #---------------------------------------------------------#
        image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0]))
        #---------------------------------------------------------#
        #   Add the batch_size dimension
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        input_feed  = self.get_input_feed(image_data)
        pr          = self.onnx_session.run(output_names=self.output_name, input_feed=input_feed)[0][0]

        def softmax(x, axis):
            x  -= np.max(x, axis=axis, keepdims=True)
            f_x = np.exp(x) / np.sum(np.exp(x), axis=axis, keepdims=True)
            return f_x

        print(np.shape(pr))
        #---------------------------------------------------#
        #   Take the per-pixel class probabilities
        #---------------------------------------------------#
        pr = softmax(np.transpose(pr, (1, 2, 0)), -1)
        #--------------------------------------#
        #   Crop off the gray bars
        #--------------------------------------#
        pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
        #---------------------------------------------------#
        #   Resize back to the original image size
        #---------------------------------------------------#
        pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation=cv2.INTER_LINEAR)
        #---------------------------------------------------#
        #   Take the class of each pixel
        #---------------------------------------------------#
        pr = pr.argmax(axis=-1)

        #---------------------------------------------------------#
        #   Count pixels per class
        #---------------------------------------------------------#
        if count:
            classes_nums        = np.zeros([self.num_classes])
            total_points_num    = orininal_h * orininal_w
            print('-' * 63)
            print("|%25s | %15s | %15s|" % ("Key", "Value", "Ratio"))
            print('-' * 63)
            for i in range(self.num_classes):
                num     = np.sum(pr == i)
                ratio   = num / total_points_num * 100
                if num > 0:
                    print("|%25s | %15s | %14.2f%%|" % (str(name_classes[i]), str(num), ratio))
                    print('-' * 63)
                classes_nums[i] = num
            print("classes_nums:", classes_nums)

        if self.mix_type == 0:
            # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
            # for c in range(self.num_classes):
            #     seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
            #     seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
            #     seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')
            seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
            #------------------------------------------------#
            #   Convert the new image to PIL Image format
            #------------------------------------------------#
            image   = Image.fromarray(np.uint8(seg_img))
            #------------------------------------------------#
            #   Blend the new image with the original image
            #------------------------------------------------#
            image   = Image.blend(old_img, image, 0.7)

        elif self.mix_type == 1:
            # seg_img = np.zeros((np.shape(pr)[0], np.shape(pr)[1], 3))
            # for c in range(self.num_classes):
            #     seg_img[:, :, 0] += ((pr[:, :] == c ) * self.colors[c][0]).astype('uint8')
            #     seg_img[:, :, 1] += ((pr[:, :] == c ) * self.colors[c][1]).astype('uint8')
            #     seg_img[:, :, 2] += ((pr[:, :] == c ) * self.colors[c][2]).astype('uint8')
            seg_img = np.reshape(np.array(self.colors, np.uint8)[np.reshape(pr, [-1])], [orininal_h, orininal_w, -1])
            #------------------------------------------------#
            #   Convert the new image to PIL Image format
            #------------------------------------------------#
            image   = Image.fromarray(np.uint8(seg_img))

        elif self.mix_type == 2:
            seg_img = (np.expand_dims(pr != 0, -1) * np.array(old_img, np.float32)).astype('uint8')
            #------------------------------------------------#
            #   Convert the new image to PIL Image format
            #------------------------------------------------#
            image   = Image.fromarray(np.uint8(seg_img))

        return image
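#---------------------------------------------------------------------#
#   A minimal end-to-end sketch, not part of the original script. It
#   assumes the default weight file model_data/unet_vgg_voc.pth exists,
#   that onnxruntime is installed, and uses hypothetical image paths
#   ('img/street.jpg' and the *_out.png outputs) purely for illustration.
#---------------------------------------------------------------------#
if __name__ == "__main__":
    # PyTorch inference (set cuda=False when no GPU is available)
    unet    = Unet(cuda=False)
    img     = Image.open('img/street.jpg')          # hypothetical input path
    result  = unet.detect_image(img, count=False)
    result.save('img/street_out.png')               # hypothetical output path

    # Rough speed estimate: average seconds per forward pass over 100 runs
    print('seconds per image:', unet.get_FPS(img, test_interval=100))

    # Export to ONNX and run the same image through onnxruntime
    unet.convert_to_onnx(simplify=False, model_path='model_data/models.onnx')
    unet_onnx   = Unet_ONNX(onnx_path='model_data/models.onnx')
    result_onnx = unet_onnx.detect_image(img)
    result_onnx.save('img/street_out_onnx.png')     # hypothetical output path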