2.3K Star 8K Fork 4.2K

GVPMindSpore / mindspore

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
deeplab_v3.py 9.09 KB
一键复制 编辑 原始数据 按行查看 历史
LiangZhibo 提交于 2021-01-26 12:50 . Change TensorAdd to Add
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P
def conv1x1(in_planes, out_planes, stride=1):
    """Build a 1x1 convolution with Xavier-uniform initialised weights.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.

    Returns:
        An ``nn.Conv2d`` cell.
    """
    conv_kwargs = dict(kernel_size=1, stride=stride, weight_init='xavier_uniform')
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
def conv3x3(in_planes, out_planes, stride=1, dilation=1, padding=1):
    """Build a 3x3 convolution with explicit padding and optional dilation.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        dilation: dilation (atrous) rate of the kernel.
        padding: explicit zero padding applied on every side ('pad' mode).

    Returns:
        An ``nn.Conv2d`` cell with Xavier-uniform initialised weights.
    """
    settings = {
        'kernel_size': 3,
        'stride': stride,
        'pad_mode': 'pad',
        'padding': padding,
        'dilation': dilation,
        'weight_init': 'xavier_uniform',
    }
    return nn.Conv2d(in_planes, out_planes, **settings)
class Resnet(nn.Cell):
    """ResNet backbone whose last stages trade stride for dilation (DeepLabV3).

    The stem (7x7 stride-2 conv, BN, ReLU, 3x3 stride-2 max-pool) is followed by
    four residual stages. ``output_stride`` selects how layer3/layer4 are built:

    * 16: layer3 uses stride 2; layer4 uses stride 1 with base dilation 2 and
      multi-grid (1, 2, 4).
    * 8: layer3 uses stride 1 with dilation 2; layer4 uses stride 1 with base
      dilation 4 and multi-grid (1, 2, 4).

    Args:
        block: residual block class. It must expose ``expansion`` and accept
            ``(inplanes, planes, stride, downsample, dilation=..., use_batch_statistics=...)``.
        block_num: number of blocks in each of the four stages.
        output_stride: 16 or 8.
        use_batch_statistics: forwarded to every ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``output_stride`` is neither 8 nor 16.
    """

    def __init__(self, block, block_num, output_stride, use_batch_statistics=True):
        super(Resnet, self).__init__()
        if output_stride not in (8, 16):
            # Otherwise layer3/layer4 would never be created and the failure
            # would only surface later, inside construct().
            raise ValueError("output_stride must be 8 or 16, got {}".format(output_stride))
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, pad_mode='pad', padding=3,
                               weight_init='xavier_uniform')
        self.bn1 = nn.BatchNorm2d(self.inplanes, use_batch_statistics=use_batch_statistics)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        self.layer1 = self._make_layer(block, 64, block_num[0], use_batch_statistics=use_batch_statistics)
        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2, use_batch_statistics=use_batch_statistics)
        if output_stride == 16:
            self.layer3 = self._make_layer(block, 256, block_num[2], stride=2,
                                           use_batch_statistics=use_batch_statistics)
            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=2, grids=[1, 2, 4],
                                           use_batch_statistics=use_batch_statistics)
        else:  # output_stride == 8
            self.layer3 = self._make_layer(block, 256, block_num[2], stride=1, base_dilation=2,
                                           use_batch_statistics=use_batch_statistics)
            self.layer4 = self._make_layer(block, 512, block_num[3], stride=1, base_dilation=4, grids=[1, 2, 4],
                                           use_batch_statistics=use_batch_statistics)

    def _make_layer(self, block, planes, blocks, stride=1, base_dilation=1, grids=None, use_batch_statistics=True):
        """Build one residual stage of ``blocks`` blocks.

        Only the first block receives ``stride`` and the projection shortcut.
        Block ``i`` uses dilation ``base_dilation * grids[i]``.

        Args:
            block: residual block class (see class docstring).
            planes: bottleneck width; the stage outputs ``planes * block.expansion`` channels.
            blocks: number of blocks in the stage.
            stride: stride of the first block.
            base_dilation: dilation multiplier for the stage.
            grids: per-block dilation multipliers; defaults to all ones.
                Must have at least ``blocks`` entries.
            use_batch_statistics: forwarded to every ``nn.BatchNorm2d``.

        Returns:
            An ``nn.SequentialCell`` with the stage's blocks. As a side effect,
            ``self.inplanes`` is updated to the stage's output width.
        """
        # Default to no projection. Previously this name was only bound inside
        # the `if`, which raised UnboundLocalError when no projection was needed.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut so the identity matches the block output shape.
            downsample = nn.SequentialCell([
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, use_batch_statistics=use_batch_statistics)
            ])
        if grids is None:
            grids = [1] * blocks
        layers = [
            block(self.inplanes, planes, stride, downsample, dilation=base_dilation * grids[0],
                  use_batch_statistics=use_batch_statistics)
        ]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(self.inplanes, planes, dilation=base_dilation * grids[i],
                      use_batch_statistics=use_batch_statistics))
        return nn.SequentialCell(layers)

    def construct(self, x):
        """Run the stem and the four residual stages; return the layer4 features."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out
class Bottleneck(nn.Cell):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 (optionally dilated) -> 1x1 expand.

    The block outputs ``planes * expansion`` channels. The shortcut branch goes
    through ``downsample`` when one is given. It is then added to the main
    branch before the final ReLU.

    Args:
        inplanes: number of input channels.
        planes: bottleneck width.
        stride: stride of the 3x3 convolution.
        downsample: optional cell applied to the shortcut, or None for identity.
        dilation: dilation of the 3x3 convolution (also used as its padding).
        use_batch_statistics: forwarded to every ``nn.BatchNorm2d``.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, use_batch_statistics=True):
        super(Bottleneck, self).__init__()
        widened = planes * self.expansion
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics)
        # padding == dilation keeps the spatial size unchanged when stride is 1.
        self.conv2 = conv3x3(planes, planes, stride=stride, dilation=dilation, padding=dilation)
        self.bn2 = nn.BatchNorm2d(planes, use_batch_statistics=use_batch_statistics)
        self.conv3 = conv1x1(planes, widened)
        self.bn3 = nn.BatchNorm2d(widened, use_batch_statistics=use_batch_statistics)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.add = P.Add()

    def construct(self, x):
        """Return relu(main_branch(x) + shortcut(x))."""
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        return self.relu(self.add(y, shortcut))
class ASPP(nn.Cell):
    """Atrous Spatial Pyramid Pooling head that also produces class logits.

    The head has four ``ASPPConv`` branches (one per atrous rate) and one
    ``ASPPPooling`` branch, each producing 256 channels. Their outputs are
    concatenated on axis 1 and projected back to 256 channels by 1x1 conv + BN + ReLU.
    Dropout is applied when ``phase == 'train'``, and a final 1x1 conv (with bias)
    maps the result to ``num_classes`` channels.

    Args:
        atrous_rates: exactly four rates, one per ``ASPPConv`` branch.
        phase: 'train' enables dropout after the projection.
        in_channels: channels of the incoming feature map.
        num_classes: number of output channels.
        use_batch_statistics: forwarded to every ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``atrous_rates`` does not contain exactly four entries.
    """

    def __init__(self, atrous_rates, phase='train', in_channels=2048, num_classes=21,
                 use_batch_statistics=True):
        super(ASPP, self).__init__()
        if len(atrous_rates) != 4:
            # The branches are hard-wired as aspp1..aspp4 (names kept so existing
            # checkpoints still load). conv1's input width is derived from the
            # number of branches, so any other length would fail with an
            # IndexError or a channel mismatch.
            raise ValueError("ASPP expects exactly 4 atrous rates, got {}".format(len(atrous_rates)))
        self.phase = phase
        out_channels = 256
        self.aspp1 = ASPPConv(in_channels, out_channels, atrous_rates[0], use_batch_statistics=use_batch_statistics)
        self.aspp2 = ASPPConv(in_channels, out_channels, atrous_rates[1], use_batch_statistics=use_batch_statistics)
        self.aspp3 = ASPPConv(in_channels, out_channels, atrous_rates[2], use_batch_statistics=use_batch_statistics)
        self.aspp4 = ASPPConv(in_channels, out_channels, atrous_rates[3], use_batch_statistics=use_batch_statistics)
        self.aspp_pooling = ASPPPooling(in_channels, out_channels, use_batch_statistics=use_batch_statistics)
        # Input width: one out_channels slice per atrous branch plus the pooling branch.
        self.conv1 = nn.Conv2d(out_channels * (len(atrous_rates) + 1), out_channels, kernel_size=1,
                               weight_init='xavier_uniform')
        self.bn1 = nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, num_classes, kernel_size=1, weight_init='xavier_uniform', has_bias=True)
        self.concat = P.Concat(axis=1)
        # NOTE(review): MindSpore 1.x nn.Dropout's first argument is keep_prob,
        # so 0.3 keeps only 30% of activations -- confirm this is intended.
        self.drop = nn.Dropout(0.3)

    def construct(self, x):
        """Return ``num_classes``-channel logits at the input feature map's resolution."""
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.aspp_pooling(x)
        # A single concat over all five branches; same result as chaining four concats.
        x = self.concat((x1, x2, x3, x4, x5))
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.phase == 'train':
            x = self.drop(x)
        x = self.conv2(x)
        return x
class ASPPPooling(nn.Cell):
    """Image-level pooling branch of ASPP.

    The branch averages the feature map over its full spatial extent and applies
    1x1 conv + BN + ReLU. It then resizes the 1x1 result back to the input's
    spatial size with nearest-neighbour interpolation (align_corners=True).
    The input is assumed to be NCHW, as the ``size[2]``/``size[3]`` indexing implies.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the branch.
        use_batch_statistics: forwarded to ``nn.BatchNorm2d``.
    """

    def __init__(self, in_channels, out_channels, use_batch_statistics=True):
        super(ASPPPooling, self).__init__()
        self.conv = nn.SequentialCell([
            nn.Conv2d(in_channels, out_channels, kernel_size=1, weight_init='xavier_uniform'),
            nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics),
            nn.ReLU()
        ])
        self.shape = P.Shape()
        # Global average pooling over H and W, keeping a 1x1 spatial map.
        self.mean = P.ReduceMean(keep_dims=True)

    def construct(self, x):
        """Return the pooled image-level features broadcast back to x's H x W."""
        size = self.shape(x)
        # ReduceMean over (H, W) is a true global average. A square
        # AvgPool2d(H) kernel with stride 1 only reduces to 1x1 when H == W.
        out = self.mean(x, (2, 3))
        out = self.conv(out)
        out = P.ResizeNearestNeighbor((size[2], size[3]), True)(out)
        return out
class ASPPConv(nn.Cell):
    """One ASPP branch: a convolution followed by BatchNorm and ReLU.

    A rate of 1 uses a plain 1x1 convolution without bias. Any other rate uses a
    3x3 convolution whose dilation and explicit padding both equal the rate,
    which preserves the spatial size.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the branch.
        atrous_rate: dilation rate of the branch.
        use_batch_statistics: forwarded to ``nn.BatchNorm2d``.
    """

    def __init__(self, in_channels, out_channels, atrous_rate=1, use_batch_statistics=True):
        super(ASPPConv, self).__init__()
        if atrous_rate != 1:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, pad_mode='pad', padding=atrous_rate,
                             dilation=atrous_rate, weight_init='xavier_uniform')
        else:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, has_bias=False,
                             weight_init='xavier_uniform')
        self.aspp_conv = nn.SequentialCell([
            conv,
            nn.BatchNorm2d(out_channels, use_batch_statistics=use_batch_statistics),
            nn.ReLU(),
        ])

    def construct(self, x):
        """Apply conv -> BN -> ReLU."""
        return self.aspp_conv(x)
class DeepLabV3(nn.Cell):
    """DeepLabV3 segmentation network: dilated ResNet backbone plus ASPP head.

    The backbone uses ``Bottleneck`` blocks in a [3, 4, 23, 3] stage layout
    (ResNet-101-style). The ASPP head uses rates (1, 6, 12, 18) on the
    2048-channel backbone output. The logits are bilinearly resized
    (align_corners=True) back to the input's height and width.

    Args:
        phase: passed to ``ASPP``; 'train' enables its dropout.
        num_classes: number of output channels.
        output_stride: 16 or 8, passed to the backbone.
        freeze_bn: when True, every BatchNorm2d is built with
            ``use_batch_statistics=False``.
    """

    def __init__(self, phase='train', num_classes=21, output_stride=16, freeze_bn=False):
        super(DeepLabV3, self).__init__()
        batch_stats = not freeze_bn
        self.resnet = Resnet(Bottleneck, [3, 4, 23, 3], output_stride=output_stride,
                             use_batch_statistics=batch_stats)
        self.aspp = ASPP([1, 6, 12, 18], phase=phase, in_channels=2048, num_classes=num_classes,
                         use_batch_statistics=batch_stats)
        self.shape = P.Shape()

    def construct(self, x):
        """Return per-pixel class logits with the same H x W as ``x``."""
        input_shape = self.shape(x)
        logits = self.aspp(self.resnet(x))
        return P.ResizeBilinear((input_shape[2], input_shape[3]), True)(logits)
Python
1
https://gitee.com/mindspore/mindspore.git
git@gitee.com:mindspore/mindspore.git
mindspore
mindspore
mindspore
r1.1

搜索帮助