diff --git a/docs/api/api_python/mindarmour.adv_robustness.attacks.rst b/docs/api/api_python/mindarmour.adv_robustness.attacks.rst
index cb5bfad455f65a25d3eb29197ad986cb489a24b0..8f1357f12364922921a7c74b096bedf7eb9ee686 100644
--- a/docs/api/api_python/mindarmour.adv_robustness.attacks.rst
+++ b/docs/api/api_python/mindarmour.adv_robustness.attacks.rst
@@ -209,6 +209,36 @@ mindarmour.adv_robustness.attacks
         返回:
             - **numpy.ndarray** - 生成的对抗样本。
 
+.. py:class:: mindarmour.adv_robustness.attacks.AutoProjectedGradientDescent(network, eps=8 / 255, eps_iter=0.1, bounds=(0.0, 1.0), is_targeted=False, nb_iter=10, norm_level='inf', loss_fn=None, eot_iter=1, thr_decr=0.75)
+
+    自动投影梯度下降(Auto Projected Gradient Descent)攻击是基本迭代法的变体,也是PGD方法的升级版。每次迭代之后,扰动被投影在指定半径的p范数球上,并剪切对抗样本的值,使其位于允许的数据范围内;同时引入自适应步长和动量,以加速收敛并提高攻击性能。这是Croce等人提出的用于可靠评估对抗鲁棒性的攻击。
+
+    参考文献:`Croce and Hein, "Reliable evaluation of adversarial robustness with an ensemble of \
+    diverse parameter-free attacks" in ICML, 2020 <https://arxiv.org/abs/2003.01690>`_。
+
+    参数:
+        - **network** (Cell) - 目标模型。
+        - **eps** (float) - 攻击产生的对抗性扰动占数据范围的比例。默认值:``8 / 255``。
+        - **eps_iter** (float) - 攻击产生的单步对抗扰动占数据范围的比例。默认值:``0.1``。
+        - **bounds** (tuple) - 数据的上下界,表示数据范围。以(数据最小值,数据最大值)的形式出现。默认值:``(0.0, 1.0)``。
+        - **is_targeted** (bool) - 如果为 ``True``,则为目标攻击。如果为 ``False``,则为无目标攻击。默认值:``False``。
+        - **nb_iter** (int) - 迭代次数。默认值:``10``。
+        - **norm_level** (Union[int, str, numpy.inf]) - 范数类型。当前实现仅支持 ``'inf'`` 和 ``'L2'``。默认值:``'inf'``。
+        - **loss_fn** (Union[Loss, None]) - 用于优化的损失函数。如果为 ``None``,则使用稀疏的 ``SoftmaxCrossEntropyWithLogits`` 作为损失函数。默认值:``None``。
+        - **eot_iter** (int) - EOT的迭代次数。默认值:``1``。
+        - **thr_decr** (float) - 步长更新的参数。默认值:``0.75``。
+
+    .. py:method:: generate(inputs, labels)
+
+        根据输入样本和原始标签生成对抗样本。通过带有参数norm_level的投影方法归一化扰动。
+
+        参数:
+            - **inputs** (Union[numpy.ndarray, tuple]) - 良性输入样本,用于创建对抗样本。
+            - **labels** (Union[numpy.ndarray, tuple]) - 原始/目标标签。若每个输入有多个标签,则将其包装在元组中。
+
+        返回:
+            - **numpy.ndarray** - 生成的对抗样本。
+
 .. py:class:: mindarmour.adv_robustness.attacks.DiverseInputIterativeMethod(network, eps=0.3, bounds=(0.0, 1.0), is_targeted=False, prob=0.5, loss_fn=None)
 
     多样性输入迭代法(Diverse Input Iterative Method)攻击遵循基本迭代法,并在每次迭代时对输入数据应用随机转换。对输入数据的这种转换可以提高对抗样本的可转移性。
diff --git a/mindarmour/adv_robustness/attacks/__init__.py b/mindarmour/adv_robustness/attacks/__init__.py
index 1b995cb0f939583c7bd39ad416e5eb2bb5662eca..d3816413b571cd50f500f8b858590963d4fc57df 100644
--- a/mindarmour/adv_robustness/attacks/__init__.py
+++ b/mindarmour/adv_robustness/attacks/__init__.py
@@ -18,8 +18,8 @@ in making adversarial examples.
 from .gradient_method import FastGradientMethod, FastGradientSignMethod, RandomFastGradientMethod, \
     RandomFastGradientSignMethod, LeastLikelyClassMethod, RandomLeastLikelyClassMethod
 from .iterative_gradient_method import IterativeGradientMethod, BasicIterativeMethod, MomentumIterativeMethod, \
-    ProjectedGradientDescent, DiverseInputIterativeMethod, MomentumDiverseInputIterativeMethod, \
-    VarianceTuningMomentumIterativeMethod, VarianceTuningNesterovIterativeMethod
+    ProjectedGradientDescent, AutoProjectedGradientDescent, DiverseInputIterativeMethod, \
+    MomentumDiverseInputIterativeMethod, VarianceTuningMomentumIterativeMethod, VarianceTuningNesterovIterativeMethod
 from .deep_fool import DeepFool
 from .jsma import JSMAAttack
 from .carlini_wagner import CarliniWagnerL2Attack
@@ -44,6 +44,7 @@ __all__ = ['FastGradientMethod',
            'VarianceTuningMomentumIterativeMethod',
            'VarianceTuningNesterovIterativeMethod',
            'ProjectedGradientDescent',
+           'AutoProjectedGradientDescent',
            'DiverseInputIterativeMethod',
            'MomentumDiverseInputIterativeMethod',
            'DeepFool',
diff --git a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py
index f1c8420b916dbc66657cc9666a4c813e794421ae..8cdcd8d230be5d8d286e863430f90a4a8d5b29fb 100644
--- a/mindarmour/adv_robustness/attacks/iterative_gradient_method.py
+++ b/mindarmour/adv_robustness/attacks/iterative_gradient_method.py
@@ -18,6 +18,9 @@ import copy
 import numpy as np
 from PIL import Image, ImageOps
+import mindspore.nn as nn
+import mindspore.ops as ops
+import mindspore as ms
 from mindspore.nn import Cell
 
 from mindarmour.utils.logger import LogUtil
@@ -457,7 +460,7 @@ class ProjectedGradientDescent(BasicIterativeMethod):
     def generate(self, inputs, labels):
         """
-        Iteratively generate adversarial examples based on BIM method. The
+        Iteratively generate adversarial examples. The
         perturbation is normalized by projected method with parameter
         norm_level .
 
         Args:
             inputs (Union[numpy.ndarray, tuple]): Benign input samples used as references to
                 create adversarial examples.
             labels (Union[numpy.ndarray, tuple]): Original/target labels. \
                 For each input if it has more than one label, it is wrapped in a tuple.
 
         Returns:
             numpy.ndarray, generated adversarial examples.
@@ -496,6 +499,212 @@ class ProjectedGradientDescent(BasicIterativeMethod):
         return adv_x
 
 
+class AutoProjectedGradientDescent(BasicIterativeMethod):
+    """
+    Auto Projected Gradient Descent (Auto-PGD) is an iterative gradient-based attack. Like PGD,
+    it keeps the adversarial example inside a predefined Lp-ball around the original input, but
+    it maximizes the loss with momentum and an adaptive step-size schedule instead of a fixed
+    step size.
+
+    Reference: `Croce and Hein, "Reliable evaluation of adversarial robustness with an ensemble of \
+    diverse parameter-free attacks" in ICML, 2020 <https://arxiv.org/abs/2003.01690>`_.
+
+    Args:
+        network (Cell): Target model.
+        eps (float): Proportion of adversarial perturbation generated by the
+            attack to data range. Default: ``8 / 255``.
+        eps_iter (float): Proportion of single-step adversarial perturbation
+            generated by the attack to data range. Default: ``0.1``.
+        bounds (tuple): Upper and lower bounds of data, indicating the data range.
+            In form of (clip_min, clip_max). Default: ``(0.0, 1.0)``.
+        is_targeted (bool): If ``True``, targeted attack. If ``False``, untargeted
+            attack. Default: ``False``.
+        nb_iter (int): Number of iterations. Default: ``10``.
+        norm_level (Union[int, str, numpy.inf]): Order of the norm. Only ``'inf'`` and ``'L2'``
+            are currently supported. Default: ``'inf'``.
+        loss_fn (Union[Loss, None]): Loss function for optimization. If ``None``, a sparse
+            ``SoftmaxCrossEntropyWithLogits`` loss is used. Default: ``None``.
+        eot_iter (int): Number of iterations for EOT (Expectation over Transformation). Default: ``1``.
+        thr_decr (float): Threshold ratio used when deciding whether to halve the step size.
+            Default: ``0.75``.
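+
+    Note:
+        Labels passed to `generate` are expected to be class indices rather than one-hot
+        vectors, since the internal bookkeeping uses a sparse ``SoftmaxCrossEntropyWithLogits``
+        loss.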
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> import mindspore.ops as ops
+        >>> from mindarmour.adv_robustness.attacks import AutoProjectedGradientDescent
+        >>> class Net(nn.Cell):
+        ...     def __init__(self):
+        ...         super(Net, self).__init__()
+        ...         self._softmax = ops.Softmax()
+        ...     def construct(self, inputs):
+        ...         out = self._softmax(inputs)
+        ...         return out
+        >>> net = Net()
+        >>> attack = AutoProjectedGradientDescent(net, eps=0.3)
+        >>> inputs = np.asarray([[0.1, 0.2, 0.7]], np.float32)
+        >>> labels = np.asarray([2], np.int32)
+        >>> adv_x = attack.generate(inputs, labels)
+    """
+
+    def __init__(self, network, eps=8 / 255, eps_iter=0.1, bounds=(0.0, 1.0), is_targeted=False, nb_iter=10,
+                 norm_level='inf', loss_fn=None, eot_iter=1, thr_decr=0.75):
+        super(AutoProjectedGradientDescent, self).__init__(network, eps=eps, eps_iter=eps_iter, bounds=bounds,
+                                                           is_targeted=is_targeted, nb_iter=nb_iter, loss_fn=loss_fn)
+        self._norm = norm_level
+        self._steps = nb_iter
+        self._eot_iter = eot_iter
+        self._eps = eps
+        # Fall back to a sparse cross-entropy loss when no loss function is given, so that a
+        # user-supplied `loss_fn` is not silently ignored.
+        if loss_fn is None:
+            loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+        self._loss_fn = loss_fn
+        self._network = network
+        self._norm_level = check_norm_level(norm_level)
+        self._thr_decr = check_value_positive('thr_decr', thr_decr)
+
+    def generate(self, inputs, labels):
+        """
+        Iteratively generate adversarial examples. The perturbation
+        is normalized by projected method with parameter norm_level.
+
+        Args:
+            inputs (Union[numpy.ndarray, tuple]): Benign input samples used as references to
+                create adversarial examples.
+            labels (Union[numpy.ndarray, tuple]): Original/target labels. \
+                For each input if it has more than one label, it is wrapped in a tuple.
+
+        Returns:
+            numpy.ndarray, generated adversarial examples.
+        """
+        inputs = ms.Tensor(inputs)
+        labels = ms.Tensor(labels)
+        x = inputs.copy() if len(inputs.shape) == 4 else inputs.copy().unsqueeze(0)
+        y = labels.copy() if len(labels.shape) == 1 else labels.copy().unsqueeze(0)
+        # Schedule constants: length of the first checkpoint interval, its lower bound and the
+        # amount by which the interval shrinks after every step-size update.
+        steps_ratio, steps_min_ratio, size_decr_ratio = 0.22, 0.06, 0.03
+        steps, steps_min, size_decr = max(int(steps_ratio * self._steps), 1), max(
+            int(steps_min_ratio * self._steps), 1), max(int(size_decr_ratio * self._steps), 1)
+
+        # Random start inside the norm ball around the benign input.
+        if self._norm == 'inf':
+            t = 2 * ms.Tensor(np.random.rand(*x.shape).astype(np.float32)) - 1
+            x_adv = x + self._eps * ms.Tensor(np.ones([x.shape[0], 1, 1, 1]), dtype=ms.float32) * t / (
+                t.reshape([t.shape[0], -1]).abs().max(axis=1, keepdims=True).reshape([-1, 1, 1, 1]))
+        elif self._norm == 'L2':
+            t = ms.Tensor(np.random.rand(*x.shape).astype(np.float32))
+            x_adv = x + self._eps * ms.Tensor(np.ones([x.shape[0], 1, 1, 1]), dtype=ms.float32) * t / (
+                (t ** 2).sum(axis=(1, 2, 3), keepdims=True).sqrt() + 1e-12)
+
+        # Clip to the valid data range (this implementation assumes bounds of (0.0, 1.0)).
+        x_adv = x_adv.clip(0.0, 1.0)
+        x_best = x_adv.copy()
+        x_best_adv = x_adv.copy()
+        loss_steps = np.zeros((self._steps, x.shape[0]))
+        loss_best_steps = np.zeros((self._steps + 1, x.shape[0]))
+        acc_steps = np.zeros_like(loss_best_steps)
+        criterion_indiv = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='none')
+        # Initial gradient, averaged over the EOT iterations.
+        grad = ops.ZerosLike()(x)
+        for _ in range(self._eot_iter):
+            grad += GradWrapWithLoss(WithLossCell(self._network, self._loss_fn))(x_adv, y)
+
+        grad = ops.div(grad, self._eot_iter)
+        grad_best = grad.copy()
+        logits = self._network(x_adv)
+        loss_indiv = criterion_indiv(logits, y)
+        acc = logits.argmax(1) == y
+        acc_steps[0] = acc.asnumpy() + 0
+        loss_best = loss_indiv.copy()
+        # Initial step size: 2 * eps for every sample.
+        step_size = 2.0 * self._eps * ms.Tensor(np.ones([x.shape[0], 1, 1, 1]), dtype=ms.float32)
+        x_adv_old = x_adv.copy()
+        k = steps + 0
+        u = np.arange(x.shape[0])
+        counter = 0
+        # Momentum coefficient (alpha in the Auto-PGD paper).
+        momentum = 0.75
+
+        loss_best_last_check = loss_best.copy()
+        reduced_last_check = np.ones(loss_best.shape, dtype=bool)
+        x_best = self._optimize_attack(x_adv, x_best, x_best_adv, x_adv_old, momentum, counter, k, step_size, grad,
+                                       x, y, logits, acc, acc_steps, loss_indiv, loss_steps, loss_best, grad_best,
+                                       loss_best_steps, reduced_last_check, loss_best_last_check, u,
+                                       size_decr, steps_min)
+
+        return x_best.asnumpy()
+
+    def _optimize_attack(self, x_adv, x_best, x_best_adv, x_adv_old, momentum, counter, k, step_size, grad,
+                         x, y, logits, acc, acc_steps, loss_indiv, loss_steps, loss_best, grad_best, loss_best_steps,
+                         reduced_last_check, loss_best_last_check, u, size_decr, steps_min):
+        """
+        Run the iterative optimization, updating x_adv with momentum over multiple iterations.
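+
+        Every ``k`` iterations the progress of each sample is checked; if its loss oscillates or
+        has stopped improving, the per-sample step size is halved and the search restarts from
+        the best point found so far (see ``_check_oscillation``).
+
+        Returns:
+            Tensor, the adversarial examples that achieved the highest loss during the search.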
+        """
+        criterion_indiv = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='none')
+        for i in range(self._steps):
+            grad2 = x_adv - x_adv_old
+            x_adv_old = x_adv.copy()
+            a = momentum if i > 0 else 1.0
+            if self._norm == 'inf':
+                # Gradient-sign step followed by projection onto the L-inf ball and the data range.
+                x_adv_1 = x_adv + step_size * ops.operations.Sign()(grad)
+                x_adv_1 = ops.clip_by_value(ops.clip_by_value(x_adv_1, x - self._eps, x + self._eps), 0.0, 1.0)
+                x_adv_1 = ops.clip_by_value(
+                    ops.clip_by_value(x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a), x - self._eps,
+                                      x + self._eps), 0.0, 1.0)
+            elif self._norm == 'L2':
+                # Normalized gradient step followed by projection onto the L2 ball and the data range.
+                x_adv_1 = x_adv + step_size * grad / (
+                    ops.square(grad).sum(axis=(1, 2, 3), keepdims=True).sqrt() + 1e-12)
+                x_adv_1 = ops.clip_by_value(x + (x_adv_1 - x) / (
+                    ((x_adv_1 - x) ** 2).sum(axis=(1, 2, 3), keepdims=True).sqrt() + 1e-12) * ops.minimum(
+                        self._eps * ops.ones(x.shape, ms.float32),
+                        ((x_adv_1 - x) ** 2).sum(axis=(1, 2, 3), keepdims=True).sqrt()),
+                    0.0, 1.0)
+                x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a)
+                x_adv_1 = ops.clip_by_value(x + (x_adv_1 - x) / (
+                    ((x_adv_1 - x) ** 2).sum(axis=(1, 2, 3), keepdims=True).sqrt() + 1e-12) * ops.minimum(
+                        self._eps * ops.ones(x.shape, ms.float32),
+                        ((x_adv_1 - x) ** 2).sum(axis=(1, 2, 3), keepdims=True).sqrt() + 1e-12), 0.0, 1.0)
+            x_adv = x_adv_1 + 0.
+            # Gradient of the loss at the new point, averaged over the EOT iterations.
+            grad = ops.zeros_like(x)
+            for _ in range(self._eot_iter):
+                grad += GradWrapWithLoss(WithLossCell(self._network, self._loss_fn))(x_adv, y)
+            grad /= float(self._eot_iter)
+            # Re-evaluate the model on the updated adversarial examples so that the accuracy and
+            # loss bookkeeping below reflects the current iterate.
+            logits = self._network(x_adv)
+            loss_indiv = criterion_indiv(logits, y)
+            pred = logits.argmax(1) == y
+            acc = ops.logical_and(acc, pred)
+            acc_steps[i + 1] = acc.asnumpy() + 0
+            # Keep the adversarial examples of the samples that are already misclassified.
+            zero_indices = ops.nonzero(pred == 0).squeeze()
+            x_best_adv[zero_indices] = x_adv[zero_indices] + ops.zeros_like(x_adv[zero_indices])
+            # Track the points with the highest loss seen so far.
+            y1 = loss_indiv.copy()
+            loss_steps[i] = y1.asnumpy() + 0
+            ind = (y1 >= loss_best).nonzero().squeeze()
+            x_best[ind] = x_adv[ind].copy()
+            grad_best[ind] = grad[ind].copy()
+            loss_best[ind] = y1[ind] + 0
+            loss_best_steps[i + 1] = loss_best.asnumpy() + 0
+            counter += 1
+
+            if counter == k:
+                # Checkpoint: halve the step size and restart from the best point for samples
+                # whose loss oscillates or has stopped improving.
+                fl_oscillation = self._check_oscillation(loss_steps, i, k, k3=self._thr_decr)
+                fl_reduce_no_impr = (~reduced_last_check) * (loss_best_last_check.asnumpy() >= loss_best.asnumpy())
+                fl_oscillation = ~(~fl_oscillation * ~fl_reduce_no_impr)
+                reduced_last_check = np.copy(fl_oscillation)
+                loss_best_last_check = loss_best.copy()
+
+                if np.sum(fl_oscillation) > 0:
+                    step_size_np = step_size.asnumpy()
+                    step_size_np[u[fl_oscillation]] /= 2.0
+                    step_size = ms.Tensor(step_size_np)
+                    fl_oscillation = ms.Tensor(np.where(fl_oscillation)[0])
+                    x_adv[fl_oscillation] = x_best[fl_oscillation].copy()
+                    grad[fl_oscillation] = grad_best[fl_oscillation].copy()
+                counter = 0
+                k = np.maximum(k - size_decr, steps_min)
+        return x_best
+
+    def _check_oscillation(self, x, j, k, k3=0.75):
+        """
+        Check, for each sample, whether the loss has stopped making progress. Over the last `k`
+        steps it counts how often the loss increased; entries where this count is at most
+        ``k * k3`` are flagged ``True``, signalling that the step size should be reduced.
+        """
+        t = np.zeros(x.shape[1])
+        for counter in range(k):
+            t += x[j - counter] > x[j - counter - 1]
+        return t <= k * k3 * np.ones(t.shape)
+
+
 class DiverseInputIterativeMethod(BasicIterativeMethod):
     """
     The Diverse Input Iterative Method attack follows the basic iterative method,