2.3K Star 8K Fork 4.2K

GVPMindSpore / mindspore

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
resnet.py 14.20 KB
一键复制 编辑 原始数据 按行查看 历史
LiangZhibo 提交于 2021-01-26 12:50 . Change TensorAdd to Add
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from scipy.stats import truncnorm
def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
    """
    Sample a convolution weight from a fan-in scaled truncated normal distribution.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        kernel_size (int): Side length of the square kernel.

    Returns:
        Tensor, float32 weight of shape (out_channel, in_channel, kernel_size, kernel_size).
    """
    weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
    fan_in = in_channel * kernel_size * kernel_size
    # NOTE(review): the divisor is presumably the std of a unit normal truncated
    # to [-2, 2], so the truncated samples end up with the intended spread - confirm.
    stddev = (1.0 / max(1., fan_in)) ** 0.5 / .87962566103423978
    distribution = truncnorm(-2, 2, loc=0, scale=stddev)
    # Draw a flat vector first, then fold it into the 4-D kernel layout.
    flat_samples = distribution.rvs(out_channel * in_channel * kernel_size * kernel_size)
    return Tensor(np.reshape(flat_samples, weight_shape), dtype=mstype.float32)
def _weight_variable(shape, factor=0.01):
    """
    Create a weight tensor from standard-normal samples scaled by `factor`.

    Args:
        shape (tuple): Shape of the weight to create.
        factor (float): Multiplier applied to the float32 samples. Default: 0.01.

    Returns:
        Tensor, randomly initialized weight with the given shape.
    """
    samples = np.random.randn(*shape).astype(np.float32)
    return Tensor(samples * factor)
def _conv3x3(in_channel, out_channel, stride=1, use_se=False):
    """
    Build a 3x3 convolution ('same' padding) with an explicitly sampled weight.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride of the convolution. Default: 1.
        use_se (bool): Use the variance-scaling initializer instead of the
            scaled-normal one. Default: False.

    Returns:
        Cell, the configured nn.Conv2d.
    """
    if not use_se:
        weight = _weight_variable((out_channel, in_channel, 3, 3))
    else:
        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
    return nn.Conv2d(in_channel,
                     out_channel,
                     kernel_size=3,
                     stride=stride,
                     padding=0,
                     pad_mode='same',
                     weight_init=weight)
def _conv1x1(in_channel, out_channel, stride=1, use_se=False):
    """
    Build a 1x1 convolution ('same' padding) with an explicitly sampled weight.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride of the convolution. Default: 1.
        use_se (bool): Use the variance-scaling initializer instead of the
            scaled-normal one. Default: False.

    Returns:
        Cell, the configured nn.Conv2d.
    """
    if not use_se:
        weight = _weight_variable((out_channel, in_channel, 1, 1))
    else:
        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
    return nn.Conv2d(in_channel,
                     out_channel,
                     kernel_size=1,
                     stride=stride,
                     padding=0,
                     pad_mode='same',
                     weight_init=weight)
def _conv7x7(in_channel, out_channel, stride=1, use_se=False):
    """
    Build a 7x7 convolution ('same' padding) with an explicitly sampled weight.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride of the convolution. Default: 1.
        use_se (bool): Use the variance-scaling initializer instead of the
            scaled-normal one. Default: False.

    Returns:
        Cell, the configured nn.Conv2d.
    """
    if not use_se:
        weight = _weight_variable((out_channel, in_channel, 7, 7))
    else:
        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
    return nn.Conv2d(in_channel,
                     out_channel,
                     kernel_size=7,
                     stride=stride,
                     padding=0,
                     pad_mode='same',
                     weight_init=weight)
def _bn(channel):
    """
    Build the batch normalization layer used throughout the network.

    Args:
        channel (int): Number of channels to normalize.

    Returns:
        Cell, nn.BatchNorm2d with gamma initialized to 1 and beta to 0.
    """
    bn_options = dict(eps=1e-4,
                      momentum=0.9,
                      gamma_init=1,
                      beta_init=0,
                      moving_mean_init=0,
                      moving_var_init=1)
    return nn.BatchNorm2d(channel, **bn_options)
def _bn_last(channel):
    """
    Build a batch normalization layer whose scale (gamma) starts at zero.

    Args:
        channel (int): Number of channels to normalize.

    Returns:
        Cell, nn.BatchNorm2d with both gamma and beta initialized to 0.
    """
    bn_options = dict(eps=1e-4,
                      momentum=0.9,
                      gamma_init=0,
                      beta_init=0,
                      moving_mean_init=0,
                      moving_var_init=1)
    return nn.BatchNorm2d(channel, **bn_options)
def _fc(in_channel, out_channel, use_se=False):
    """
    Build a fully connected layer with an explicitly sampled weight and zero bias.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        use_se (bool): Sample the weight from N(0, 0.01) via np.random.normal
            instead of using the scaled-normal helper. Default: False.

    Returns:
        Cell, the configured nn.Dense.
    """
    weight_shape = (out_channel, in_channel)
    if use_se:
        # Draw a flat vector first, then fold it into the (out, in) matrix.
        flat_samples = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
        weight = Tensor(np.reshape(flat_samples, weight_shape), dtype=mstype.float32)
    else:
        weight = _weight_variable(weight_shape)
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
class ResidualBlock(nn.Cell):
    """
    ResNet V1 bottleneck residual block (1x1 -> 3x3 -> 1x1 plus shortcut).

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride of the block, applied in the 3x3 stage and in the
            shortcut branch. Default: 1.
        use_se (bool): Build the block the SE-ResNet50 way: variance-scaling
            initializers, and max pooling instead of strided convolutions
            when `stride` is not 1. Default: False.
        se_block (bool): Gate the residual branch with a squeeze-and-excitation
            sub-network. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, use_se=False, se_block=False):
        super().__init__()
        self.stride = stride
        self.use_se = use_se
        self.se_block = se_block
        mid_channel = out_channel // self.expansion

        # NOTE: the cells below are created in a fixed order on purpose; their
        # weights are drawn from NumPy's global random state at build time.
        self.conv1 = _conv1x1(in_channel, mid_channel, stride=1, use_se=self.use_se)
        self.bn1 = _bn(mid_channel)
        if self.use_se and self.stride != 1:
            # SE variant: stride-1 conv followed by a 2x2 max pool; the pool
            # always uses stride 2, whatever non-unit `stride` was given.
            se_stage = [_conv3x3(mid_channel, mid_channel, stride=1, use_se=True),
                        _bn(mid_channel),
                        nn.ReLU(),
                        nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')]
            self.e2 = nn.SequentialCell(se_stage)
        else:
            self.conv2 = _conv3x3(mid_channel, mid_channel, stride=stride, use_se=self.use_se)
            self.bn2 = _bn(mid_channel)

        self.conv3 = _conv1x1(mid_channel, out_channel, stride=1, use_se=self.use_se)
        self.bn3 = _bn_last(out_channel)

        if self.se_block:
            squeeze_channel = int(out_channel / 4)
            self.se_global_pool = P.ReduceMean(keep_dims=False)
            self.se_dense_0 = _fc(out_channel, squeeze_channel, use_se=self.use_se)
            self.se_dense_1 = _fc(squeeze_channel, out_channel, use_se=self.use_se)
            self.se_sigmoid = nn.Sigmoid()
            self.se_mul = P.Mul()
        self.relu = nn.ReLU()

        # A projection shortcut is required whenever the block changes the
        # spatial size or the channel count.
        self.down_sample = stride != 1 or in_channel != out_channel
        self.down_sample_layer = None
        if self.down_sample:
            if self.use_se and stride != 1:
                shortcut_cells = [nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
                                  _conv1x1(in_channel, out_channel, 1, use_se=self.use_se),
                                  _bn(out_channel)]
            else:
                shortcut_cells = [_conv1x1(in_channel, out_channel, stride, use_se=self.use_se),
                                  _bn(out_channel)]
            self.down_sample_layer = nn.SequentialCell(shortcut_cells)
        self.add = P.Add()

    def construct(self, x):
        shortcut = x

        # 1x1 reduction.
        y = self.relu(self.bn1(self.conv1(x)))

        # 3x3 stage; self.e2 already contains its own batch norm and ReLU.
        if self.use_se and self.stride != 1:
            y = self.e2(y)
        else:
            y = self.relu(self.bn2(self.conv2(y)))

        # 1x1 expansion.
        y = self.bn3(self.conv3(y))

        if self.se_block:
            features = y
            # Average over axes (2, 3), run the two dense layers, then append
            # two unit axes so the gate broadcasts against `features`.
            gate = self.se_global_pool(y, (2, 3))
            gate = self.relu(self.se_dense_0(gate))
            gate = self.se_sigmoid(self.se_dense_1(gate))
            gate = F.reshape(gate, F.shape(gate) + (1, 1))
            y = self.se_mul(gate, features)

        if self.down_sample:
            shortcut = self.down_sample_layer(shortcut)

        y = self.add(y, shortcut)
        return self.relu(y)
class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes that the training images are belonging to.
        use_se (bool): enable SE-ResNet50 net. When True, the stem is built from
            three 3x3 convolutions and the last block of layer 3 and layer 4
            is created with `se_block=True`. Default: False.

    Raises:
        ValueError: If `layer_nums`, `in_channels`, `out_channels` or `strides`
            does not contain exactly 4 entries.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes,
                 use_se=False):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
        # strides[0..3] are indexed below; reject a wrong length up front
        # instead of failing with an IndexError half way through construction.
        if len(strides) != 4:
            raise ValueError("the length of strides list must be 4!")

        self.use_se = use_se
        # The SE gate is enabled exactly when the SE variant is requested.
        self.se_block = False
        if self.use_se:
            self.se_block = True

        if self.use_se:
            # SE variant: the stem is a stack of three 3x3 convolutions.
            self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
            self.bn1_0 = _bn(32)
            self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
            self.bn1_1 = _bn(32)
            self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
        else:
            self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _bn(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0],
                                       use_se=self.use_se)
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1],
                                       use_se=self.use_se)
        # Only layer 3 and layer 4 receive the SE gate.
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2],
                                       use_se=self.use_se,
                                       se_block=self.se_block)
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3],
                                       use_se=self.use_se,
                                       se_block=self.se_block)

        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size passed to the first block of the stage.
            use_se (bool): enable SE-ResNet50 net. Default: False.
            se_block(bool): use se block in SE-ResNet50 net; only the last
                block of the stage is created with it. Default: False.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        if se_block and layer_num <= 1:
            # Single-block stage: the first block is also the last one, so it
            # carries the SE gate itself. Appending a separate SE block here
            # would yield two blocks where one was requested.
            only_block = block(in_channel, out_channel, stride=stride, use_se=use_se, se_block=se_block)
            return nn.SequentialCell([only_block])

        layers = [block(in_channel, out_channel, stride=stride, use_se=use_se)]
        if se_block:
            # Plain blocks in the middle, SE-gated block at the end.
            for _ in range(1, layer_num - 1):
                layers.append(block(out_channel, out_channel, stride=1, use_se=use_se))
            layers.append(block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block))
        else:
            for _ in range(1, layer_num):
                layers.append(block(out_channel, out_channel, stride=1, use_se=use_se))
        return nn.SequentialCell(layers)

    def construct(self, x):
        if self.use_se:
            x = self.conv1_0(x)
            x = self.bn1_0(x)
            x = self.relu(x)
            x = self.conv1_1(x)
            x = self.bn1_1(x)
            x = self.relu(x)
            x = self.conv1_2(x)
        else:
            x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        c1 = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        # Average over axes (2, 3) and flatten before the classifier.
        out = self.mean(c5, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)

        return out
def resnet50(class_num=10):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number. Default: 10.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
    stage_blocks = [3, 4, 6, 3]
    stage_in_channels = [64, 256, 512, 1024]
    stage_out_channels = [256, 512, 1024, 2048]
    stage_strides = [1, 2, 2, 2]
    return ResNet(ResidualBlock, stage_blocks, stage_in_channels, stage_out_channels, stage_strides, class_num)
def se_resnet50(class_num=1001):
    """
    Get SE-ResNet50 neural network.

    Args:
        class_num (int): Class number. Default: 1001.

    Returns:
        Cell, cell instance of SE-ResNet50 neural network.

    Examples:
        >>> net = se_resnet50(1001)
    """
    stage_blocks = [3, 4, 6, 3]
    stage_in_channels = [64, 256, 512, 1024]
    stage_out_channels = [256, 512, 1024, 2048]
    stage_strides = [1, 2, 2, 2]
    return ResNet(ResidualBlock,
                  stage_blocks,
                  stage_in_channels,
                  stage_out_channels,
                  stage_strides,
                  class_num,
                  use_se=True)
def resnet101(class_num=1001):
    """
    Get ResNet101 neural network.

    Args:
        class_num (int): Class number. Default: 1001.

    Returns:
        Cell, cell instance of ResNet101 neural network.

    Examples:
        >>> net = resnet101(1001)
    """
    stage_blocks = [3, 4, 23, 3]
    stage_in_channels = [64, 256, 512, 1024]
    stage_out_channels = [256, 512, 1024, 2048]
    stage_strides = [1, 2, 2, 2]
    return ResNet(ResidualBlock, stage_blocks, stage_in_channels, stage_out_channels, stage_strides, class_num)
Python
1
https://gitee.com/mindspore/mindspore.git
git@gitee.com:mindspore/mindspore.git
mindspore
mindspore
mindspore
r1.1

搜索帮助