2.3K Star 8K Fork 4.2K

GVPMindSpore / mindspore

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
squeezenet.py 7.95 KB
一键复制 编辑 原始数据 按行查看 历史
LiangZhibo 提交于 2021-01-26 12:50 . Change TensorAdd to Add
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Squeezenet."""
import mindspore.nn as nn
from mindspore.common import initializer as weight_init
from mindspore.ops import operations as P
class Fire(nn.Cell):
"""
Fire network definition.
"""
def __init__(self, inplanes, squeeze_planes, expand1x1_planes,
expand3x3_planes):
super(Fire, self).__init__()
self.inplanes = inplanes
self.squeeze = nn.Conv2d(inplanes,
squeeze_planes,
kernel_size=1,
has_bias=True)
self.squeeze_activation = nn.ReLU()
self.expand1x1 = nn.Conv2d(squeeze_planes,
expand1x1_planes,
kernel_size=1,
has_bias=True)
self.expand1x1_activation = nn.ReLU()
self.expand3x3 = nn.Conv2d(squeeze_planes,
expand3x3_planes,
kernel_size=3,
pad_mode='same',
has_bias=True)
self.expand3x3_activation = nn.ReLU()
self.concat = P.Concat(axis=1)
def construct(self, x):
x = self.squeeze_activation(self.squeeze(x))
return self.concat((self.expand1x1_activation(self.expand1x1(x)),
self.expand3x3_activation(self.expand3x3(x))))
class SqueezeNet(nn.Cell):
r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
accuracy with 50x fewer parameters and <0.5MB model size"
<https://arxiv.org/abs/1602.07360>`_ paper.
Get SqueezeNet neural network.
Args:
num_classes (int): Class number.
Returns:
Cell, cell instance of SqueezeNet neural network.
Examples:
>>> net = SqueezeNet(10)
"""
def __init__(self, num_classes=10):
super(SqueezeNet, self).__init__()
self.features = nn.SequentialCell([
nn.Conv2d(3,
96,
kernel_size=7,
stride=2,
pad_mode='valid',
has_bias=True),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2),
Fire(96, 16, 64, 64),
Fire(128, 16, 64, 64),
Fire(128, 32, 128, 128),
nn.MaxPool2d(kernel_size=3, stride=2),
Fire(256, 32, 128, 128),
Fire(256, 48, 192, 192),
Fire(384, 48, 192, 192),
Fire(384, 64, 256, 256),
nn.MaxPool2d(kernel_size=3, stride=2),
Fire(512, 64, 256, 256),
])
# Final convolution is initialized differently from the rest
self.final_conv = nn.Conv2d(512,
num_classes,
kernel_size=1,
has_bias=True)
self.dropout = nn.Dropout(keep_prob=0.5)
self.relu = nn.ReLU()
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.custom_init_weight()
def custom_init_weight(self):
"""
Init the weight of Conv2d in the net.
"""
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
if cell is self.final_conv:
cell.weight.set_data(
weight_init.initializer('normal', cell.weight.shape,
cell.weight.dtype))
else:
cell.weight.set_data(
weight_init.initializer('he_uniform',
cell.weight.shape,
cell.weight.dtype))
if cell.bias is not None:
cell.bias.set_data(
weight_init.initializer('zeros', cell.bias.shape,
cell.bias.dtype))
def construct(self, x):
x = self.features(x)
x = self.dropout(x)
x = self.final_conv(x)
x = self.relu(x)
x = self.mean(x, (2, 3))
x = self.flatten(x)
return x
class SqueezeNet_Residual(nn.Cell):
r"""SqueezeNet with simple bypass model architecture from the `"SqueezeNet:
AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size"
<https://arxiv.org/abs/1602.07360>`_ paper.
Get SqueezeNet with simple bypass neural network.
Args:
num_classes (int): Class number.
Returns:
Cell, cell instance of SqueezeNet with simple bypass neural network.
Examples:
>>> net = SqueezeNet_Residual(10)
"""
def __init__(self, num_classes=10):
super(SqueezeNet_Residual, self).__init__()
self.conv1 = nn.Conv2d(3,
96,
kernel_size=7,
stride=2,
pad_mode='valid',
has_bias=True)
self.fire2 = Fire(96, 16, 64, 64)
self.fire3 = Fire(128, 16, 64, 64)
self.fire4 = Fire(128, 32, 128, 128)
self.fire5 = Fire(256, 32, 128, 128)
self.fire6 = Fire(256, 48, 192, 192)
self.fire7 = Fire(384, 48, 192, 192)
self.fire8 = Fire(384, 64, 256, 256)
self.fire9 = Fire(512, 64, 256, 256)
# Final convolution is initialized differently from the rest
self.conv10 = nn.Conv2d(512, num_classes, kernel_size=1, has_bias=True)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2)
self.add = P.Add()
self.dropout = nn.Dropout(keep_prob=0.5)
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.custom_init_weight()
def custom_init_weight(self):
"""
Init the weight of Conv2d in the net.
"""
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
if cell is self.conv10:
cell.weight.set_data(
weight_init.initializer('normal', cell.weight.shape,
cell.weight.dtype))
else:
cell.weight.set_data(
weight_init.initializer('xavier_uniform',
cell.weight.shape,
cell.weight.dtype))
if cell.bias is not None:
cell.bias.set_data(
weight_init.initializer('zeros', cell.bias.shape,
cell.bias.dtype))
def construct(self, x):
"""
Construct squeezenet_residual.
"""
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.fire2(x)
x = self.add(x, self.fire3(x))
x = self.fire4(x)
x = self.max_pool2d(x)
x = self.add(x, self.fire5(x))
x = self.fire6(x)
x = self.add(x, self.fire7(x))
x = self.fire8(x)
x = self.max_pool2d(x)
x = self.add(x, self.fire9(x))
x = self.dropout(x)
x = self.conv10(x)
x = self.relu(x)
x = self.mean(x, (2, 3))
x = self.flatten(x)
return x
Python
1
https://gitee.com/mindspore/mindspore.git
git@gitee.com:mindspore/mindspore.git
mindspore
mindspore
mindspore
r1.1

搜索帮助