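# NOTE: hosting-page banner captured with the file ("Code pull complete; the page will refresh automatically."); not part of the program.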
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Examples of image inversion attack
"""
import numpy as np
import matplotlib.pyplot as plt
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import Tensor, context
from mindspore import nn
from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack
from mindarmour.utils.logger import LogUtil
from examples.common.networks.lenet5.lenet5_net import LeNet5, conv, fc_with_initialize
from examples.common.dataset.data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
LOGGER.set_level('INFO')
TAG = 'InversionAttack'
# pylint: disable=invalid-name
class LeNet5_part(nn.Cell):
    """
    Truncated LeNet5: only the two convolution stages are executed.

    The constructor creates the same layers as a full LeNet5 (two
    convolutions, three fully connected layers, ReLU, max-pooling and
    flatten), but ``construct`` stops after the second pooling step and
    returns that intermediate feature map.
    """

    def __init__(self):
        super().__init__()
        # Convolution stages used by construct().
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        # Fully connected layers are created but never called in construct().
        # NOTE(review): presumably kept so the layer set mirrors the full
        # LeNet5 whose checkpoint is loaded into this cell - confirm.
        self.fc1 = fc_with_initialize(16*5*5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        # Not used in construct() either.
        self.flatten = nn.Flatten()

    def construct(self, x):
        """
        Apply conv -> ReLU -> max-pool twice and return the result.

        Args:
            x (Tensor): Input batch fed to the first convolution.

        Returns:
            Tensor, output of the second max-pooling step.
        """
        stage1 = self.max_pool2d(self.relu(self.conv1(x)))
        stage2 = self.max_pool2d(self.relu(self.conv2(stage1)))
        return stage2
def mnist_inversion_attack(net):
    """
    Image inversion attack based on LeNet5 and MNIST dataset.

    The function:

    1. loads the trained LeNet5 checkpoint into ``net``;
    2. reads MNIST training batches and records the output of ``net`` for
       up to 30 samples as the features to invert;
    3. runs ``ImageInversionAttack`` to reconstruct images from those
       features;
    4. loads a second, separately stored checkpoint into a fresh ``LeNet5``,
       predicts labels for the reconstructed images and logs the evaluation
       values returned by ``ImageInversionAttack.evaluate``;
    5. plots up to 10 original images above their reconstructions.

    Args:
        net (Cell): Network under attack. Its output on an image batch is
            used as the target of the inversion.

    Raises:
        ValueError: If the dataset does not yield any batch, so there is
            nothing to invert.
    """
    # load the trained weights into the network under attack
    ckpt_path = '../../common/networks/lenet5/trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    load_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, load_dict)

    # get original data and their inferred features
    data_list = "../../common/dataset/MNIST/train"
    batch_size = 32
    ds = generate_mnist_dataset(data_list, batch_size)
    batch_num = 1
    sample_num = 30
    true_labels = None
    target_features = None
    original_images = None
    for i, data in enumerate(ds.create_tuple_iterator(output_numpy=True), start=1):
        images = data[0].astype(np.float32)
        # Keep only the first `sample_num` entries of the batch; a later
        # batch (if batch_num > 1) overwrites the earlier one.
        true_labels = data[1][:sample_num]
        target_features = net(Tensor(images)).asnumpy()[:sample_num]
        original_images = images[:sample_num]
        if i >= batch_num:
            break
    if original_images is None:
        # Without this guard the code below would fail with an opaque
        # UnboundLocalError when the dataset is empty.
        raise ValueError('The dataset yielded no batch; nothing to invert.')

    # run attacking
    inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), loss_weights=[1, 0.1, 5])
    inversion_images = inversion_attack.generate(target_features, iters=100)

    # get the predict results of inversion images on a new trained model
    net2 = LeNet5()
    new_ckpt_path = '../../common/networks/lenet5/new_trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    new_load_dict = load_checkpoint(new_ckpt_path)
    load_param_into_net(net2, new_load_dict)
    pred_labels = np.argmax(net2(Tensor(inversion_images).astype(np.float32)).asnumpy(), axis=1)

    # evaluate the quality of inversion images
    avg_l2_dis, avg_ssim, avg_confi = inversion_attack.evaluate(original_images, inversion_images, true_labels, net2)
    LOGGER.info(TAG, 'The average L2 distance between original images and inverted images is: {}'.format(avg_l2_dis))
    LOGGER.info(TAG, 'The average ssim value between original images and inverted images is: {}'.format(avg_ssim))
    LOGGER.info(TAG, 'The average prediction confidence on true labels of inverted images is: {}'.format(avg_confi))
    LOGGER.info(TAG, 'True labels of original images are: %s' % true_labels)
    LOGGER.info(TAG, 'Predicted labels of inverted images are: %s' % pred_labels)

    # plot at most 10 images, never more than were actually collected
    plot_num = min(len(original_images), 10)
    # Select the gray colormap once; it stays the default for every imshow below.
    plt.gray()
    for n in range(1, plot_num + 1):
        # top row: original images
        plt.subplot(2, plot_num, n)
        if n == 1:
            plt.title('Original images', fontsize=16, loc='left')
        plt.imshow(original_images[n - 1].reshape(32, 32))
        # bottom row: images reconstructed by the attack
        plt.subplot(2, plot_num, n + plot_num)
        if n == 1:
            plt.title('Inverted images', fontsize=16, loc='left')
        plt.imshow(inversion_images[n - 1].reshape(32, 32))
    plt.show()
if __name__ == '__main__':
    # device_target can be "CPU", "GPU" or "Ascend"
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    # Attack the complete LeNet5 first, then the truncated LeNet5_part.
    # The truncated network is shallower and can lead to a better attack result.
    # Each network is instantiated right before its own attack run.
    for network_cls in (LeNet5, LeNet5_part):
        mnist_inversion_attack(network_cls())
# NOTE: hosting-page moderation notice captured with the file; not part of the program.
# It stated that some content may be hidden as unsuitable for display and that an appeal can be submitted if the content is compliant.