# (Web-page residue, not code: "Code fetch complete; the page will refresh automatically.")
import json
import sys

import matplotlib.pyplot as plt
import torch
import torchvision
from PIL import Image
from torchvision import transforms

from backbone.mobilenetv2_model import MobileNetV2
from backbone.resnet50_fpn_model import resnet50_fpn_backbone
from draw_box_utils import draw_box
from network_files.faster_rcnn_framework import FasterRCNN, FastRCNNPredictor
from network_files.rpn_function import AnchorsGenerator
def create_model(num_classes, backbone_name="resnet50_fpn"):
    """Build a Faster R-CNN detector with the chosen backbone.

    Args:
        num_classes: number of classes passed straight to ``FasterRCNN``.
        backbone_name: ``"resnet50_fpn"`` (default) builds the backbone from
            ``resnet50_fpn_backbone()``; ``"mobilenetv2"`` uses the
            ``MobileNetV2().features`` extractor together with a custom
            anchor generator and RoI pooler.

    Returns:
        The constructed ``FasterRCNN`` model.

    Raises:
        ValueError: if ``backbone_name`` is not one of the supported values.
    """
    if backbone_name == "resnet50_fpn":
        # ResNet-50 + FPN + Faster R-CNN, using FasterRCNN's default
        # anchor generator and RoI pooler.
        backbone = resnet50_fpn_backbone()
        return FasterRCNN(backbone=backbone, num_classes=num_classes)

    if backbone_name == "mobilenetv2":
        # MobileNetV2 + Faster R-CNN. The feature extractor does not carry an
        # out_channels attribute of its own, so it is set by hand (1280);
        # presumably FasterRCNN reads it to size its heads — confirm.
        backbone = MobileNetV2().features
        backbone.out_channels = 1280

        # A single tuple of sizes and of aspect ratios: anchors for one
        # feature map.
        anchor_generator = AnchorsGenerator(sizes=((32, 64, 128, 256, 512),),
                                            aspect_ratios=((0.5, 1.0, 2.0),))

        # Pool RoIs from the one feature map, assumed to be keyed '0'
        # (confirm against how FasterRCNN wraps a non-dict backbone output).
        roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                        output_size=[7, 7],
                                                        sampling_ratio=2)

        return FasterRCNN(backbone=backbone,
                          num_classes=num_classes,
                          rpn_anchor_generator=anchor_generator,
                          box_roi_pool=roi_pooler)

    raise ValueError(
        f"unsupported backbone_name {backbone_name!r}; "
        "expected 'resnet50_fpn' or 'mobilenetv2'"
    )
# Select the device: the first CUDA GPU if one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Create the model. num_classes=21 presumably means the 20 PASCAL VOC
# classes plus background — confirm against the training configuration.
model = create_model(num_classes=21)

# Load trained weights.
# TODO: change this to the path of your own weights file.
train_weights = "./save_weights/model.pth"
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU still loads when the script falls back to the CPU.
checkpoint = torch.load(train_weights, map_location=device)
model.load_state_dict(checkpoint["model"])
model.to(device)

# Read the class mapping from JSON and swap keys and values so predicted
# labels can be looked up. Presumably the file maps class name -> integer
# index — verify against the JSON file.
# TODO: change this to your own class file.
category_index = {}
try:
    with open('./pascal_voc_classes.json', 'r', encoding='utf-8') as json_file:
        class_dict = json.load(json_file)
    category_index = {v: k for k, v in class_dict.items()}
except (OSError, ValueError) as e:
    # OSError: missing/unreadable file; ValueError also covers
    # json.JSONDecodeError (malformed JSON).
    print(e, file=sys.stderr)
    sys.exit(-1)

# Load the test image.
# TODO: change this to your own test image.
# convert("RGB") guarantees a 3-channel image: grayscale, palette or RGBA
# inputs would otherwise give ToTensor a different channel count, and an RGBA
# image cannot be saved as JPEG below. The `with` closes the source file.
with Image.open("./test.jpg") as raw_img:
    original_img = raw_img.convert("RGB")

# PIL image -> float tensor of shape (C, H, W) scaled to [0, 1]. The image is
# deliberately not normalized here (presumably the model normalizes
# internally — confirm in FasterRCNN).
data_transform = transforms.Compose([transforms.ToTensor()])
img = data_transform(original_img)
# Add a batch dimension: (C, H, W) -> (1, C, H, W).
img = torch.unsqueeze(img, dim=0)

model.eval()
with torch.no_grad():
    # The output is indexed per input image; take the results for the
    # first (and only) image.
    predictions = model(img.to(device))[0]
    predict_boxes = predictions["boxes"].to("cpu").numpy()
    predict_classes = predictions["labels"].to("cpu").numpy()
    predict_scores = predictions["scores"].to("cpu").numpy()

    if len(predict_boxes) == 0:
        # Message reads: "No objects detected!"
        print("没有检测到任何目标!")

    # Presumably draws the detections (thresh=0.5 presumably filters by
    # score) onto original_img in place, since that image is saved next.
    draw_box(original_img,
             predict_boxes,
             predict_classes,
             predict_scores,
             category_index,
             thresh=0.5,
             line_thickness=5)

    # TODO: change this to your own output path.
    original_img.save('/home/aistudio/work/faster_rcnn/homework/test1_pred.jpg')

    # To display the result interactively instead:
    # plt.imshow(original_img)
    # plt.show()
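# (Web-page residue, not code: the hosting site withheld some content at this
#  point as "possibly unsuitable for display" and invited the author to review,
#  edit, or appeal. Part of the original file may therefore be missing here.)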