106 Star 843 Fork 1.4K

MindSpore/models

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
resnet.py 21.90 KB
一键复制 编辑 原始数据 按行查看 历史
zhaoting 提交于 2022-11-17 14:18 . move official models
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet."""
import math
import numpy as np
from scipy.stats import truncnorm
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from src.model_utils.config import config
def conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
    """Create a truncated-normal initial weight for a square-kernel convolution.

    The standard deviation is derived from the fan-in
    (``in_channel * kernel_size * kernel_size``). Samples are drawn from
    ``scipy.stats.truncnorm(-2, 2, ...)``.

    Args:
        in_channel (int): Number of input channels.
        out_channel (int): Number of output channels.
        kernel_size (int): Height and width of the (square) kernel.

    Returns:
        Tensor, float32 weight of shape
        ``(out_channel, in_channel, kernel_size, kernel_size)``.
    """
    weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
    fan_in = in_channel * kernel_size * kernel_size
    base_std = (1.0 / max(1., fan_in)) ** 0.5
    if config.net_name == "resnet152":
        # resnet152 uses the raw fan-in based deviation.
        stddev = base_std
    else:
        # NOTE(review): the divisor presumably compensates for the variance
        # lost by truncating the normal distribution -- confirm.
        stddev = base_std / .87962566103423978
    sample_count = out_channel * in_channel * kernel_size * kernel_size
    samples = truncnorm(-2, 2, loc=0, scale=stddev).rvs(sample_count)
    return Tensor(np.reshape(samples, weight_shape), dtype=mstype.float32)
def _weight_variable(shape, factor=0.01):
    """Return a Tensor of standard-normal float32 samples scaled by ``factor``.

    Args:
        shape (tuple): Shape of the weight to create.
        factor (float): Multiplier applied to the samples. Default: 0.01.

    Returns:
        Tensor, randomly initialised weight with the given shape.
    """
    gaussian = np.random.randn(*shape).astype(np.float32)
    return Tensor(gaussian * factor)
def calculate_gain(nonlinearity, param=None):
    """Return the gain used to scale the init std for a given nonlinearity.

    Args:
        nonlinearity (str): Name of the nonlinearity. Linear/convolution
            names and ``'sigmoid'`` give 1, ``'tanh'`` gives 5/3, ``'relu'``
            gives sqrt(2) and ``'leaky_relu'`` gives
            ``sqrt(2 / (1 + negative_slope ** 2))``.
        param (int, float, optional): Negative slope for ``'leaky_relu'``.
            ``None`` selects 0.01. Booleans are rejected.

    Returns:
        int or float, the gain value.

    Raises:
        ValueError: If ``nonlinearity`` is not supported, or if ``param`` is
            not a valid number for ``'leaky_relu'``.
    """
    unit_gain_fns = ('linear', 'conv1d', 'conv2d', 'conv3d',
                     'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d',
                     'sigmoid')
    if nonlinearity in unit_gain_fns:
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity != 'leaky_relu':
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

    if param is None:
        neg_slope = 0.01
    elif isinstance(param, float) or (isinstance(param, int) and not isinstance(param, bool)):
        # bool is a subclass of int, so it has to be excluded explicitly.
        neg_slope = param
    else:
        raise ValueError("neg_slope {} not a valid number".format(param))
    return math.sqrt(2.0 / (1 + neg_slope ** 2))
def _calculate_fan_in_and_fan_out(tensor):
"""_calculate_fan_in_and_fan_out"""
dimensions = len(tensor)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
if dimensions == 2: # Linear
fan_in = tensor[1]
fan_out = tensor[0]
else:
num_input_fmaps = tensor[1]
num_output_fmaps = tensor[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = tensor[2] * tensor[3]
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def _calculate_correct_fan(tensor, mode):
    """Select fan-in or fan-out of a weight shape according to ``mode``.

    Args:
        tensor (tuple or list of int): Weight shape.
        mode (str): ``'fan_in'`` or ``'fan_out'`` (case-insensitive).

    Returns:
        int, the requested fan value.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Unsupported mode {}, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        return fan_in
    return fan_out
def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    """Sample a Kaiming-style normal weight as a float32 numpy array.

    The standard deviation is ``gain / sqrt(fan)``, where ``gain`` comes from
    ``calculate_gain`` and ``fan`` from ``_calculate_correct_fan``.

    Args:
        inputs_shape (tuple): Shape of the weight to create.
        a (int, float): Negative slope forwarded to ``calculate_gain``. Default: 0.
        mode (str): ``'fan_in'`` or ``'fan_out'``. Default: ``'fan_in'``.
        nonlinearity (str): Nonlinearity name. Default: ``'leaky_relu'``.

    Returns:
        numpy.ndarray, float32 array of shape ``inputs_shape``.
    """
    # The fan is resolved first so an invalid mode is reported before an
    # invalid nonlinearity.
    fan = _calculate_correct_fan(inputs_shape, mode)
    std = calculate_gain(nonlinearity, a) / math.sqrt(fan)
    samples = np.random.normal(loc=0, scale=std, size=inputs_shape)
    return samples.astype(np.float32)
def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
    """Sample a Kaiming-style uniform weight as a float32 numpy array.

    Values are drawn uniformly from ``[-bound, bound)`` with
    ``bound = sqrt(3) * gain / sqrt(fan)``.

    Args:
        inputs_shape (tuple): Shape of the weight to create.
        a (int, float): Negative slope forwarded to ``calculate_gain``. Default: 0.
        mode (str): ``'fan_in'`` or ``'fan_out'``. Default: ``'fan_in'``.
        nonlinearity (str): Nonlinearity name. Default: ``'leaky_relu'``.

    Returns:
        numpy.ndarray, float32 array of shape ``inputs_shape``.
    """
    fan = _calculate_correct_fan(inputs_shape, mode)
    std = calculate_gain(nonlinearity, a) / math.sqrt(fan)
    # Convert the target standard deviation into symmetric uniform limits.
    bound = math.sqrt(3.0) * std
    samples = np.random.uniform(low=-bound, high=bound, size=inputs_shape)
    return samples.astype(np.float32)
def _conv3x3(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    """Build a 3x3 ``nn.Conv2d`` with an explicitly initialised weight.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Convolution stride. Default: 1.
        use_se (bool): Use the variance-scaling initializer. Default: False.
        res_base (bool): Use explicit padding of 1 (``pad_mode='pad'``)
            instead of ``pad_mode='same'``. Default: False.

    Returns:
        nn.Conv2d, the configured convolution cell.
    """
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
    else:
        weight_shape = (out_channel, in_channel, 3, 3)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
        # weight_shape only exists on this branch, so the resnet152 override
        # is applied here.
        if config.net_name == "resnet152":
            weight = _weight_variable(weight_shape)

    padding, pad_mode = (1, 'pad') if res_base else (0, 'same')
    return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
                     padding=padding, pad_mode=pad_mode, weight_init=weight)
def _conv1x1(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    """Build a 1x1 ``nn.Conv2d`` with an explicitly initialised weight.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Convolution stride. Default: 1.
        use_se (bool): Use the variance-scaling initializer. Default: False.
        res_base (bool): Use ``pad_mode='pad'`` (with zero padding) instead
            of ``pad_mode='same'``. Default: False.

    Returns:
        nn.Conv2d, the configured convolution cell.
    """
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
    else:
        weight_shape = (out_channel, in_channel, 1, 1)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
        # weight_shape only exists on this branch, so the resnet152 override
        # is applied here.
        if config.net_name == "resnet152":
            weight = _weight_variable(weight_shape)

    # Padding is 0 in both variants; only the pad mode differs.
    pad_mode = 'pad' if res_base else 'same'
    return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
                     padding=0, pad_mode=pad_mode, weight_init=weight)
def _conv7x7(in_channel, out_channel, stride=1, use_se=False, res_base=False):
    """Build a 7x7 ``nn.Conv2d`` with an explicitly initialised weight.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Convolution stride. Default: 1.
        use_se (bool): Use the variance-scaling initializer. Default: False.
        res_base (bool): Use explicit padding of 3 (``pad_mode='pad'``)
            instead of ``pad_mode='same'``. Default: False.

    Returns:
        nn.Conv2d, the configured convolution cell.
    """
    if use_se:
        weight = conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
    else:
        weight_shape = (out_channel, in_channel, 7, 7)
        weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
        # weight_shape only exists on this branch, so the resnet152 override
        # is applied here.
        if config.net_name == "resnet152":
            weight = _weight_variable(weight_shape)

    padding, pad_mode = (3, 'pad') if res_base else (0, 'same')
    return nn.Conv2d(in_channel, out_channel,
                     kernel_size=7, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=weight)
def _bn(channel, res_base=False):
    """Build an ``nn.BatchNorm2d`` with gamma initialised to 1.

    Args:
        channel (int): Number of feature channels.
        res_base (bool): If True use ``eps=1e-5, momentum=0.1``; otherwise
            ``eps=1e-4, momentum=0.9``. Default: False.

    Returns:
        nn.BatchNorm2d, the configured batch-norm cell.
    """
    eps, momentum = (1e-5, 0.1) if res_base else (1e-4, 0.9)
    return nn.BatchNorm2d(channel, eps=eps, momentum=momentum,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
def _bn_last(channel):
    """Build an ``nn.BatchNorm2d`` whose gamma is initialised to 0.

    Args:
        channel (int): Number of feature channels.

    Returns:
        nn.BatchNorm2d, batch-norm cell with ``eps=1e-4``, ``momentum=0.9``
        and ``gamma_init=0``.
    """
    zero_gamma_bn = nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
                                   gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
    return zero_gamma_bn
def _fc(in_channel, out_channel, use_se=False):
    """Build an ``nn.Dense`` layer with an explicitly initialised weight.

    Args:
        in_channel (int): Number of input features.
        out_channel (int): Number of output features.
        use_se (bool): If True draw the weight from a normal distribution
            with standard deviation 0.01; otherwise use ``kaiming_uniform``
            (replaced by ``_weight_variable`` for resnet152). Default: False.

    Returns:
        nn.Dense, dense cell with bias initialised to 0.
    """
    if use_se:
        flat = np.random.normal(loc=0, scale=0.01, size=out_channel * in_channel)
        weight = Tensor(np.reshape(flat, (out_channel, in_channel)), dtype=mstype.float32)
    else:
        weight_shape = (out_channel, in_channel)
        weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
        # weight_shape only exists on this branch, so the resnet152 override
        # is applied here.
        if config.net_name == "resnet152":
            weight = _weight_variable(weight_shape)
    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
class ResidualBlock(nn.Cell):
    """
    ResNet V1 bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions).

    The first two convolutions operate on ``out_channel // expansion``
    channels. A projection shortcut is added whenever the stride is not 1 or
    the input and output channel counts differ.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride of the 3x3 stage and of the shortcut. When
            ``use_se`` is True and the stride is not 1, down-sampling is done
            by 2x2 max pooling with stride 2 instead of a strided
            convolution. Default: 1.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block(bool): Use se block in SE-ResNet50 net. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
    # Ratio between the block output channels and the bottleneck channels.
    expansion = 4

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1,
                 use_se=False, se_block=False):
        super().__init__()
        self.stride = stride
        self.use_se = use_se
        self.se_block = se_block
        channel = out_channel // self.expansion

        # Stage 1: 1x1 reduction.
        self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
        self.bn1 = _bn(channel)

        # Stage 2: 3x3 convolution; the SE variant down-samples by pooling.
        if self.use_se and self.stride != 1:
            self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True),
                                         _bn(channel),
                                         nn.ReLU(),
                                         nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
        else:
            self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
            self.bn2 = _bn(channel)

        # Stage 3: 1x1 expansion.
        self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
        self.bn3 = _bn(out_channel)
        if config.optimizer == "Thor" or config.net_name == "resnet152":
            # These configurations replace the last BN with a zero-gamma one.
            self.bn3 = _bn_last(out_channel)

        if self.se_block:
            se_channel = int(out_channel / 4)
            self.se_global_pool = ops.ReduceMean(keep_dims=False)
            self.se_dense_0 = _fc(out_channel, se_channel, use_se=self.use_se)
            self.se_dense_1 = _fc(se_channel, out_channel, use_se=self.use_se)
            self.se_sigmoid = nn.Sigmoid()
            self.se_mul = ops.Mul()
        self.relu = nn.ReLU()

        self.down_sample = False
        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            if self.use_se and stride != 1:
                # SE variant: pool first, then project with a stride-1 conv.
                shortcut_cells = [nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
                                  _conv1x1(in_channel, out_channel, 1, use_se=self.use_se),
                                  _bn(out_channel)]
            else:
                shortcut_cells = [_conv1x1(in_channel, out_channel, stride, use_se=self.use_se),
                                  _bn(out_channel)]
            self.down_sample_layer = nn.SequentialCell(shortcut_cells)

    def construct(self, x):
        """Apply the bottleneck branch, optional SE gating, and the shortcut."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        if self.use_se and self.stride != 1:
            out = self.e2(out)
        else:
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.se_block:
            features = out
            # Average over axes (2, 3), pass through two dense layers, then
            # re-scale the features with the resulting sigmoid gate.
            gate = self.se_global_pool(out, (2, 3))
            gate = self.se_dense_0(gate)
            gate = self.relu(gate)
            gate = self.se_dense_1(gate)
            gate = self.se_sigmoid(gate)
            gate = ops.reshape(gate, ops.shape(gate) + (1, 1))
            out = self.se_mul(gate, features)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = out + identity
        out = self.relu(out)
        return out
class ResidualBlockBase(nn.Cell):
    """
    ResNet V1 basic residual block (two 3x3 convolutions).

    A projection shortcut (1x1 convolution + batch norm) is added whenever
    the stride is not 1 or the input and output channel counts differ.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        use_se (bool): Enable SE-ResNet50 net; only forwarded to the shortcut
            convolution here. Default: False.
        se_block(bool): Accepted for interface compatibility; not used by
            this block. Default: False.
        res_base (bool): Enable parameter setting of resnet18. Default: True.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlockBase(3, 256, stride=2)
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 stride=1,
                 use_se=False,
                 se_block=False,
                 res_base=True):
        super().__init__()
        self.res_base = res_base

        self.conv1 = _conv3x3(in_channel, out_channel, stride=stride, res_base=self.res_base)
        self.bn1d = _bn(out_channel)
        self.conv2 = _conv3x3(out_channel, out_channel, stride=1, res_base=self.res_base)
        self.bn2d = _bn(out_channel)
        self.relu = nn.ReLU()

        self.down_sample = False
        if stride != 1 or in_channel != out_channel:
            self.down_sample = True
        self.down_sample_layer = None

        if self.down_sample:
            shortcut_cells = [_conv1x1(in_channel, out_channel, stride,
                                       use_se=use_se, res_base=self.res_base),
                              _bn(out_channel, res_base)]
            self.down_sample_layer = nn.SequentialCell(shortcut_cells)

    def construct(self, x):
        """Apply the two-convolution branch and add the shortcut."""
        identity = x

        out = self.conv1(x)
        out = self.bn1d(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2d(out)

        if self.down_sample:
            identity = self.down_sample_layer(identity)

        out = out + identity
        out = self.relu(out)
        return out
class ResNet(nn.Cell):
    """
    ResNet architecture.

    The network is a stem (one 7x7 convolution, or three 3x3 convolutions
    when ``use_se`` is True), a max pooling layer, four stages of residual
    blocks, a mean over axes (2, 3), and a dense classifier.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers.
        in_channels (list): Input channel in each layer.
        out_channels (list): Output channel in each layer.
        strides (list): Stride size in each layer.
        num_classes (int): The number of classes that the training images are belonging to.
        use_se (bool): Enable SE-ResNet50 net. Default: False.
        se_block(bool): Use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
        res_base (bool): Enable parameter setting of resnet18. Default: False.

    Returns:
        Tensor, output tensor.

    Raises:
        ValueError: If ``layer_nums``, ``in_channels`` or ``out_channels``
            does not have exactly 4 entries.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
        >>>        [64, 256, 512, 1024],
        >>>        [256, 512, 1024, 2048],
        >>>        [1, 2, 2, 2],
        >>>        10)
    """

    def __init__(self,
                 block,
                 layer_nums,
                 in_channels,
                 out_channels,
                 strides,
                 num_classes,
                 use_se=False,
                 res_base=False):
        super().__init__()

        # NOTE(review): `strides` is indexed up to [3] below but its length
        # is not validated here.
        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
        self.use_se = use_se
        self.res_base = res_base
        # SE blocks are enabled exactly when the SE variant is requested.
        self.se_block = False
        if self.use_se:
            self.se_block = True

        # Stem.
        if self.use_se:
            self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
            self.bn1_0 = _bn(32)
            self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
            self.bn1_1 = _bn(32)
            self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
        else:
            self.conv1 = _conv7x7(3, 64, stride=2, res_base=self.res_base)
        self.bn1 = _bn(64, self.res_base)
        self.relu = ops.ReLU()

        if self.res_base:
            # Pad the last two axes by 1 on each side explicitly, then pool
            # in 'valid' mode.
            self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)))
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="valid")
        else:
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        # Four residual stages; only the last two may carry an SE block.
        self.layer1 = self._make_layer(block, layer_nums[0],
                                       in_channel=in_channels[0], out_channel=out_channels[0],
                                       stride=strides[0], use_se=self.use_se)
        self.layer2 = self._make_layer(block, layer_nums[1],
                                       in_channel=in_channels[1], out_channel=out_channels[1],
                                       stride=strides[1], use_se=self.use_se)
        self.layer3 = self._make_layer(block, layer_nums[2],
                                       in_channel=in_channels[2], out_channel=out_channels[2],
                                       stride=strides[2], use_se=self.use_se,
                                       se_block=self.se_block)
        self.layer4 = self._make_layer(block, layer_nums[3],
                                       in_channel=in_channels[3], out_channel=out_channels[3],
                                       stride=strides[3], use_se=self.use_se,
                                       se_block=self.se_block)

        # Classification head.
        self.mean = ops.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)

    def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
        """
        Make stage network of ResNet.

        The first block maps ``in_channel`` to ``out_channel`` with the given
        stride; all following blocks keep ``out_channel`` with stride 1. When
        ``se_block`` is True the final block is created with ``se_block``
        enabled.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first block of the stage.
            use_se (bool): Enable SE-ResNet50 net. Default: False.
            se_block(bool): Use se block in SE-ResNet50 net. Default: False.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
        layers = [block(in_channel, out_channel, stride=stride, use_se=use_se)]

        # NOTE(review): with se_block=True and layer_num == 1 the stage still
        # gets two blocks (the first block plus the SE block).
        plain_count = layer_num - 2 if se_block else layer_num - 1
        for _ in range(plain_count):
            layers.append(block(out_channel, out_channel, stride=1, use_se=use_se))
        if se_block:
            layers.append(block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block))
        return nn.SequentialCell(layers)

    def construct(self, x):
        """Run the stem, the four residual stages, and the classifier head."""
        if self.use_se:
            x = self.conv1_0(x)
            x = self.bn1_0(x)
            x = self.relu(x)
            x = self.conv1_1(x)
            x = self.bn1_1(x)
            x = self.relu(x)
            x = self.conv1_2(x)
        else:
            x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        if self.res_base:
            x = self.pad(x)
        pooled = self.maxpool(x)

        stage1 = self.layer1(pooled)
        stage2 = self.layer2(stage1)
        stage3 = self.layer3(stage2)
        stage4 = self.layer4(stage3)

        out = self.mean(stage4, (2, 3))
        out = self.flatten(out)
        out = self.end_point(out)
        return out
def resnet18(class_num=10):
    """
    Get ResNet18 neural network.

    Built from ``ResidualBlockBase`` with 2 blocks per stage and
    ``res_base=True``.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet18 neural network.

    Examples:
        >>> net = resnet18(10)
    """
    return ResNet(block=ResidualBlockBase,
                  layer_nums=[2, 2, 2, 2],
                  in_channels=[64, 64, 128, 256],
                  out_channels=[64, 128, 256, 512],
                  strides=[1, 2, 2, 2],
                  num_classes=class_num,
                  res_base=True)
def resnet34(class_num=10):
    """
    Get ResNet34 neural network.

    Built from ``ResidualBlockBase`` with [3, 4, 6, 3] blocks per stage and
    ``res_base=True``.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet34 neural network.

    Examples:
        >>> net = resnet34(10)
    """
    return ResNet(block=ResidualBlockBase,
                  layer_nums=[3, 4, 6, 3],
                  in_channels=[64, 64, 128, 256],
                  out_channels=[64, 128, 256, 512],
                  strides=[1, 2, 2, 2],
                  num_classes=class_num,
                  res_base=True)
def resnet50(class_num=10):
    """
    Get ResNet50 neural network.

    Built from ``ResidualBlock`` with [3, 4, 6, 3] blocks per stage.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
    return ResNet(block=ResidualBlock,
                  layer_nums=[3, 4, 6, 3],
                  in_channels=[64, 256, 512, 1024],
                  out_channels=[256, 512, 1024, 2048],
                  strides=[1, 2, 2, 2],
                  num_classes=class_num)
def se_resnet50(class_num=1001):
    """
    Get SE-ResNet50 neural network.

    Same stage layout as ResNet50, created with ``use_se=True``.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of SE-ResNet50 neural network.

    Examples:
        >>> net = se_resnet50(1001)
    """
    return ResNet(block=ResidualBlock,
                  layer_nums=[3, 4, 6, 3],
                  in_channels=[64, 256, 512, 1024],
                  out_channels=[256, 512, 1024, 2048],
                  strides=[1, 2, 2, 2],
                  num_classes=class_num,
                  use_se=True)
def resnet101(class_num=1001):
    """
    Get ResNet101 neural network.

    Built from ``ResidualBlock`` with [3, 4, 23, 3] blocks per stage.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet101 neural network.

    Examples:
        >>> net = resnet101(1001)
    """
    return ResNet(block=ResidualBlock,
                  layer_nums=[3, 4, 23, 3],
                  in_channels=[64, 256, 512, 1024],
                  out_channels=[256, 512, 1024, 2048],
                  strides=[1, 2, 2, 2],
                  num_classes=class_num)
def resnet152(class_num=1001):
    """
    Get ResNet152 neural network.

    Built from ``ResidualBlock`` with [3, 8, 36, 3] blocks per stage.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet152 neural network.

    Examples:
        >>> net = resnet152(1001)
    """
    return ResNet(block=ResidualBlock,
                  layer_nums=[3, 8, 36, 3],
                  in_channels=[64, 256, 512, 1024],
                  out_channels=[256, 512, 1024, 2048],
                  strides=[1, 2, 2, 2],
                  num_classes=class_num)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mindspore/models.git
git@gitee.com:mindspore/models.git
mindspore
models
models
r2.0.0-alpha

搜索帮助