256 Star 931 Fork 105

MindSpore / mindarmour

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
model_coverage_metrics.py 30.16 KB
一键复制 编辑 原始数据 按行查看 历史
jxl 提交于 2023-10-25 11:31 . fix a bug of summary
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model-Test Coverage Metrics.
"""
from abc import abstractmethod
from collections import defaultdict
import math
import time
import numpy as np
from mindspore import Tensor
from mindspore import Model
from mindspore.train.summary.summary_record import _get_summary_tensor_data
from mindarmour.utils._check_param import check_model, check_numpy_param, check_int_positive, \
check_param_type, check_value_positive
from mindarmour.utils.logger import LogUtil
# Logger shared by every class in this module, obtained through LogUtil.get_instance().
LOGGER = LogUtil.get_instance()
# Tag string passed as the first argument to LOGGER calls in this module.
TAG = 'CoverageMetrics'
class CoverageMetrics:
    """
    The abstract base class for neuron coverage classes calculating coverage metrics.

    Each neuron output of a network has an output range after training (called the original range here), and a
    test dataset is used to estimate the accuracy of the trained network. However, the distribution of neurons'
    outputs differs between test datasets. Therefore, similar to function fuzzing, model fuzzing means testing
    those neurons' outputs and estimating the proportion of the original range that has emerged with test datasets.

    Reference: `DeepGauge: Multi-Granularity Testing Criteria for Deep Learning Systems
    <https://arxiv.org/abs/1803.07519>`_

    Args:
        model (Model): The pre-trained model which is waiting for testing.
        incremental (bool): Whether metrics are calculated in an incremental way. Default: ``False``.
        batch_size (int): The number of samples in a fuzz test batch. Default: ``32``.
    """

    def __init__(self, model, incremental=False, batch_size=32):
        self._model = check_model('model', model, Model)
        self.incremental = check_param_type('incremental', incremental, bool)
        self.batch_size = check_int_positive('batch_size', batch_size)
        # Maps layer name -> boolean array of activation flags; filled by the get_metrics() implementations.
        self._activate_table = defaultdict(list)

    @abstractmethod
    def get_metrics(self, dataset):
        """
        Calculate coverage metrics of given dataset.

        Args:
            dataset (numpy.ndarray): Dataset used to calculate coverage metrics.

        Raises:
            NotImplementedError: It is an abstract method.
        """
        msg = 'The function get_metrics() is an abstract method in class `CoverageMetrics`, and should be' \
              ' implemented in child class.'
        LOGGER.error(TAG, msg)
        raise NotImplementedError(msg)

    def _init_neuron_activate_table(self, data):
        """
        Initialise the activate table of each neuron in the model with format:
        {'layer1': [n1, n2, n3, ..., nn], 'layer2': [n1, n2, n3, ..., nn], ...}

        Args:
            data (numpy.ndarray): Data used for initialising the activate table.

        Returns:
            dict, the activate table: one all-False boolean array per summarised layer, whose length is the
            size of axis 1 of that layer's output.

        Raises:
            ValueError: If no summary tensor data is available after the forward pass.
        """
        self._model.predict(Tensor(data))
        # Short pause before reading the summary data.
        # NOTE(review): presumably gives the summary mechanism time to record the tensors - confirm.
        time.sleep(0.01)
        layer_out = _get_summary_tensor_data()
        if not layer_out:
            msg = 'User must use TensorSummary() operation to specify the middle layer of the model participating in ' \
                  'the coverage calculation.'
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        activate_table = defaultdict()
        for layer, value in layer_out.items():
            # The builtin `bool` is used because the `np.bool` alias was removed from NumPy.
            activate_table[layer] = np.zeros(value.shape[1], dtype=bool)
        return activate_table

    @staticmethod
    def _reduce_to_neurons(value):
        """Average a layer output over every axis after the second one, leaving a 2-D array untouched."""
        if value.ndim > 2:
            value = np.mean(value, axis=tuple(range(2, value.ndim)))
        return value

    @staticmethod
    def _merge_bounds(upper_bounds, lower_bounds, layer, value):
        """Fold the per-neuron max/min of one batch (taken along axis 0) into the running bounds of `layer`."""
        batch_max = np.max(value, axis=0)
        batch_min = np.min(value, axis=0)
        # Test key membership instead of the truthiness of the stored values: bounds that happen to be
        # all zero are still valid bounds and must not be overwritten by a later batch.
        if layer in upper_bounds:
            upper_bounds[layer] = np.maximum(upper_bounds[layer], batch_max)
            lower_bounds[layer] = np.minimum(lower_bounds[layer], batch_min)
        else:
            upper_bounds[layer] = batch_max
            lower_bounds[layer] = batch_min

    def _get_bounds(self, train_dataset):
        """
        Compute the lower and upper boundaries of neurons' outputs over a training dataset.

        The dataset is fed to the model in chunks of `self.batch_size`; after each forward pass the summary
        tensors are read back and merged into the running bounds.

        Args:
            train_dataset (numpy.ndarray): Training dataset used to determine the neurons' output boundaries.

        Returns:
            - dict, upper bounds of neurons' outputs, one numpy.ndarray per layer name.
            - dict, lower bounds of neurons' outputs, one numpy.ndarray per layer name.
        """
        upper_bounds = defaultdict(list)
        lower_bounds = defaultdict(list)
        batches = math.ceil(train_dataset.shape[0] / self.batch_size)
        for batch_index in range(batches):
            inputs = train_dataset[batch_index * self.batch_size: (batch_index + 1) * self.batch_size]
            self._model.predict(Tensor(inputs))
            time.sleep(0.01)
            layer_out = _get_summary_tensor_data()
            for layer, tensor in layer_out.items():
                value = self._reduce_to_neurons(tensor.asnumpy())
                self._merge_bounds(upper_bounds, lower_bounds, layer, value)
        return upper_bounds, lower_bounds

    def _activate_rate(self):
        """
        Calculate the activate rate of neurons.

        Returns:
            float, the number of True entries in the activate table divided by the number of neurons
            (first-axis length of every layer's array). ``0.0`` when the table holds no neurons.
        """
        total_neurons = 0
        activated_neurons = 0
        for flags in self._activate_table.values():
            activated_neurons += np.sum(flags)
            total_neurons += len(flags)
        if total_neurons == 0:
            # Nothing has been recorded yet; avoid a ZeroDivisionError.
            return 0.0
        return activated_neurons / total_neurons
class NeuronCoverage(CoverageMetrics):
    """
    Calculate the neuron activation coverage. A neuron counts as activated when its output is greater than the
    threshold. Neuron coverage equals the proportion of activated neurons to total neurons in the network.

    Args:
        model (Model): The pre-trained model which is waiting for testing.
        threshold (float): Threshold used to determine whether a neuron is activated. Default: ``0.1``.
        incremental (bool): Whether metrics are calculated in an incremental way. Default: ``False``.
        batch_size (int): The number of samples in a fuzz test batch. Default: ``32``.
    """

    def __init__(self, model, threshold=0.1, incremental=False, batch_size=32):
        super(NeuronCoverage, self).__init__(model, incremental, batch_size)
        checked_threshold = check_param_type('threshold', threshold, float)
        self.threshold = check_value_positive('threshold', checked_threshold)

    def get_metrics(self, dataset):
        """
        Get the metric of neuron coverage: the proportion of activated neurons to total neurons in the network.

        Args:
            dataset (numpy.ndarray): Dataset used to calculate coverage metrics.

        Returns:
            float, the metric of 'neuron coverage'.

        Examples:
            >>> import numpy as np
            >>> import mindspore.nn as nn
            >>> from mindspore.common.initializer import TruncatedNormal
            >>> from mindspore.ops import operations as P
            >>> from mindspore.train import Model
            >>> from mindspore.ops import TensorSummary
            >>> from mindarmour.fuzz_testing import NeuronCoverage
            >>> class Net(nn.Cell):
            ...     def __init__(self):
            ...         super(Net, self).__init__()
            ...         self.conv1 = nn.Conv2d(1, 6, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
            ...         self.conv2 = nn.Conv2d(6, 16, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
            ...         self.fc1 = nn.Dense(16 * 5 * 5, 120, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.fc2 = nn.Dense(120, 84, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.fc3 = nn.Dense(84, 10, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.relu = nn.ReLU()
            ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
            ...         self.reshape = P.Reshape()
            ...         self.summary = TensorSummary()
            ...     def construct(self, x):
            ...         x = self.conv1(x)
            ...         x = self.relu(x)
            ...         self.summary('conv1', x)
            ...         x = self.max_pool2d(x)
            ...         x = self.conv2(x)
            ...         x = self.relu(x)
            ...         self.summary('conv2', x)
            ...         x = self.max_pool2d(x)
            ...         x = self.reshape(x, (-1, 16 * 5 * 5))
            ...         x = self.fc1(x)
            ...         x = self.relu(x)
            ...         self.summary('fc1', x)
            ...         x = self.fc2(x)
            ...         x = self.relu(x)
            ...         self.summary('fc2', x)
            ...         x = self.fc3(x)
            ...         self.summary('fc3', x)
            ...         return x
            >>> net = Net()
            >>> model = Model(net)
            >>> batch_size = 8
            >>> num_classes = 10
            >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
            >>> test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
            >>> nc = NeuronCoverage(model, threshold=0.1)
            >>> nc_metrics = nc.get_metrics(test_images)
        """
        dataset = check_numpy_param('dataset', dataset)
        num_batches = math.ceil(dataset.shape[0] / self.batch_size)
        # Start from a fresh table unless running incrementally with an already populated table.
        if not (self.incremental and self._activate_table):
            self._activate_table = self._init_neuron_activate_table(dataset[0:1])
        step = self.batch_size
        for start in range(0, num_batches * step, step):
            batch = dataset[start: start + step]
            self._model.predict(Tensor(batch))
            time.sleep(0.01)
            summaries = _get_summary_tensor_data()
            for layer, tensor in summaries.items():
                output = tensor.asnumpy()
                # Average over every axis after the second one (no-op for 2-D outputs).
                trailing_axes = tuple(range(2, output.ndim))
                output = np.mean(output, axis=trailing_axes)
                # A neuron is activated if any sample in the batch pushes it above the threshold.
                fired = np.any(output > self.threshold, axis=0)
                self._activate_table[layer] = np.logical_or(self._activate_table[layer], fired)
        return self._activate_rate()
class TopKNeuronCoverage(CoverageMetrics):
    """
    Calculate the top k activated neurons coverage. A neuron counts as activated when its output is among the
    top k largest values of its hidden layer. Top k neuron coverage equals the proportion of activated neurons
    to total neurons in the network.

    Args:
        model (Model): The pre-trained model which is waiting for testing.
        top_k (int): A neuron is activated when its output is among the top k largest values of its hidden layer.
            If a layer has fewer than `top_k` neurons, all of its neurons are considered. Default: ``3``.
        incremental (bool): Whether metrics are calculated in an incremental way. Default: ``False``.
        batch_size (int): The number of samples in a fuzz test batch. Default: ``32``.
    """

    def __init__(self, model, top_k=3, incremental=False, batch_size=32):
        super(TopKNeuronCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size)
        self.top_k = check_int_positive('top_k', top_k)

    def get_metrics(self, dataset):
        """
        Get the metric of Top K activated neuron coverage.

        Args:
            dataset (numpy.ndarray): Dataset used to calculate coverage metrics.

        Returns:
            float, the metrics of 'top k neuron coverage'.

        Examples:
            >>> import numpy as np
            >>> import mindspore.nn as nn
            >>> from mindspore.common.initializer import TruncatedNormal
            >>> from mindspore.ops import operations as P
            >>> from mindspore.train import Model
            >>> from mindspore.ops import TensorSummary
            >>> from mindarmour.fuzz_testing import TopKNeuronCoverage
            >>> class Net(nn.Cell):
            ...     def __init__(self):
            ...         super(Net, self).__init__()
            ...         self.conv1 = nn.Conv2d(1, 6, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
            ...         self.conv2 = nn.Conv2d(6, 16, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
            ...         self.fc1 = nn.Dense(16 * 5 * 5, 120, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.fc2 = nn.Dense(120, 84, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.fc3 = nn.Dense(84, 10, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.relu = nn.ReLU()
            ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
            ...         self.reshape = P.Reshape()
            ...         self.summary = TensorSummary()
            ...     def construct(self, x):
            ...         x = self.conv1(x)
            ...         x = self.relu(x)
            ...         self.summary('conv1', x)
            ...         x = self.max_pool2d(x)
            ...         x = self.conv2(x)
            ...         x = self.relu(x)
            ...         self.summary('conv2', x)
            ...         x = self.max_pool2d(x)
            ...         x = self.reshape(x, (-1, 16 * 5 * 5))
            ...         x = self.fc1(x)
            ...         x = self.relu(x)
            ...         self.summary('fc1', x)
            ...         x = self.fc2(x)
            ...         x = self.relu(x)
            ...         self.summary('fc2', x)
            ...         x = self.fc3(x)
            ...         self.summary('fc3', x)
            ...         return x
            >>> net = Net()
            >>> model = Model(net)
            >>> batch_size = 8
            >>> num_classes = 10
            >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
            >>> test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
            >>> tknc = TopKNeuronCoverage(model, top_k=3)
            >>> metrics = tknc.get_metrics(test_images)
        """
        dataset = check_numpy_param('dataset', dataset)
        batches = math.ceil(dataset.shape[0] / self.batch_size)
        if not self.incremental or not self._activate_table:
            self._activate_table = self._init_neuron_activate_table(dataset[0:1])
        for batch_index in range(batches):
            inputs = dataset[batch_index * self.batch_size: (batch_index + 1) * self.batch_size]
            self._model.predict(Tensor(inputs))
            time.sleep(0.01)
            layer_out = _get_summary_tensor_data()
            for layer, tensor in layer_out.items():
                value = tensor.asnumpy()
                if value.ndim > 2:
                    # Average over every axis after the second one so one value per neuron remains.
                    value = np.mean(value, axis=tuple(range(2, value.ndim)))
                # Clamp k so that a layer with fewer than top_k neurons does not raise an IndexError.
                k = min(self.top_k, value.shape[1])
                # k-th largest output of every sample, kept as a column for broadcasting.
                kth_largest = np.sort(value, axis=-1)[:, -k].reshape(value.shape[0], 1)
                # A neuron is activated if it reaches the k-th largest value for at least one sample.
                # Direct comparison avoids the NaN that `inf - inf` would produce.
                activated = np.any(value >= kth_largest, axis=0)
                self._activate_table[layer] = np.logical_or(self._activate_table[layer], activated)
        return self._activate_rate()
class SuperNeuronActivateCoverage(CoverageMetrics):
    """
    Get the metric of 'super neuron activation coverage'. :math:`SNAC = |UpperCornerNeuron|/|N|`. SNAC refers to
    the proportion of neurons whose output value on the test set exceeds the upper bound of the corresponding
    neuron's output value on the training set.

    Args:
        model (Model): The pre-trained model which is waiting for testing.
        train_dataset (numpy.ndarray): Training dataset used to determine the neurons' output boundaries.
        incremental (bool): Whether metrics are calculated in an incremental way. Default: ``False``.
        batch_size (int): The number of samples in a fuzz test batch. Default: ``32``.
    """

    def __init__(self, model, train_dataset, incremental=False, batch_size=32):
        super(SuperNeuronActivateCoverage, self).__init__(model, incremental=incremental, batch_size=batch_size)
        train_dataset = check_numpy_param('train_dataset', train_dataset)
        # Per-layer output bounds observed on the training data.
        bounds = self._get_bounds(train_dataset=train_dataset)
        self.upper_bounds, self.lower_bounds = bounds

    def get_metrics(self, dataset):
        """
        Get the metric of 'super neuron activation coverage'.

        Args:
            dataset (numpy.ndarray): Dataset used to calculate coverage metrics.

        Returns:
            float, the metric of 'super neuron activation coverage'.

        Examples:
            >>> import numpy as np
            >>> import mindspore.nn as nn
            >>> from mindspore.common.initializer import TruncatedNormal
            >>> from mindspore.ops import operations as P
            >>> from mindspore.train import Model
            >>> from mindspore.ops import TensorSummary
            >>> from mindarmour.fuzz_testing import SuperNeuronActivateCoverage
            >>> class Net(nn.Cell):
            ...     def __init__(self):
            ...         super(Net, self).__init__()
            ...         self.conv1 = nn.Conv2d(1, 6, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
            ...         self.conv2 = nn.Conv2d(6, 16, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
            ...         self.fc1 = nn.Dense(16 * 5 * 5, 120, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.fc2 = nn.Dense(120, 84, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.fc3 = nn.Dense(84, 10, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.relu = nn.ReLU()
            ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
            ...         self.reshape = P.Reshape()
            ...         self.summary = TensorSummary()
            ...     def construct(self, x):
            ...         x = self.conv1(x)
            ...         x = self.relu(x)
            ...         self.summary('conv1', x)
            ...         x = self.max_pool2d(x)
            ...         x = self.conv2(x)
            ...         x = self.relu(x)
            ...         self.summary('conv2', x)
            ...         x = self.max_pool2d(x)
            ...         x = self.reshape(x, (-1, 16 * 5 * 5))
            ...         x = self.fc1(x)
            ...         x = self.relu(x)
            ...         self.summary('fc1', x)
            ...         x = self.fc2(x)
            ...         x = self.relu(x)
            ...         self.summary('fc2', x)
            ...         x = self.fc3(x)
            ...         self.summary('fc3', x)
            ...         return x
            >>> net = Net()
            >>> model = Model(net)
            >>> batch_size = 8
            >>> num_classes = 10
            >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
            >>> test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
            >>> snac = SuperNeuronActivateCoverage(model, train_images)
            >>> metrics = snac.get_metrics(test_images)
        """
        dataset = check_numpy_param('dataset', dataset)
        # Start from a fresh table unless running incrementally with an already populated table.
        if not (self.incremental and self._activate_table):
            self._activate_table = self._init_neuron_activate_table(dataset[0:1])
        step = self.batch_size
        num_batches = math.ceil(dataset.shape[0] / step)
        for start in range(0, num_batches * step, step):
            batch = dataset[start: start + step]
            self._model.predict(Tensor(batch))
            time.sleep(0.01)
            summaries = _get_summary_tensor_data()
            for layer, tensor in summaries.items():
                output = tensor.asnumpy()
                if output.ndim > 2:
                    # Average over every axis after the second one.
                    output = np.mean(output, axis=tuple(range(2, output.ndim)))
                # A neuron is covered once any sample exceeds its training-time upper bound.
                above_upper = np.any(output > self.upper_bounds[layer], axis=0)
                self._activate_table[layer] = np.logical_or(self._activate_table[layer], above_upper)
        return self._activate_rate()
class NeuronBoundsCoverage(SuperNeuronActivateCoverage):
    """
    Get the metric of 'neuron boundary coverage' :math:`NBC = (|UpperCornerNeuron| + |LowerCornerNeuron|)/(2*|N|)`,
    where :math:`|N|` is the number of neurons. NBC refers to the proportion of neuron corner regions that are
    reached, i.e. neurons whose output value on the test dataset exceeds the upper bound or falls below the lower
    bound of the corresponding neuron's output value on the training dataset. Upper and lower corners are counted
    separately, as in the formula.

    Args:
        model (Model): The pre-trained model which is waiting for testing.
        train_dataset (numpy.ndarray): Training dataset used to determine the neurons' output boundaries.
        incremental (bool): Whether metrics are calculated in an incremental way. Default: ``False``.
        batch_size (int): The number of samples in a fuzz test batch. Default: ``32``.
    """

    def __init__(self, model, train_dataset, incremental=False, batch_size=32):
        super(NeuronBoundsCoverage, self).__init__(model, train_dataset, incremental=incremental, batch_size=batch_size)

    def _init_corner_table(self, data):
        """
        Build the corner table: for every summarised layer a boolean array of shape (neurons, 2), where
        column 0 records 'upper bound exceeded' and column 1 records 'lower bound undershot'.

        Args:
            data (numpy.ndarray): Data used for one forward pass to discover the summarised layers.

        Returns:
            dict, the all-False corner table keyed by layer name.

        Raises:
            ValueError: If no summary tensor data is available after the forward pass.
        """
        self._model.predict(Tensor(data))
        time.sleep(0.01)
        layer_out = _get_summary_tensor_data()
        if not layer_out:
            msg = 'User must use TensorSummary() operation to specify the middle layer of the model participating in ' \
                  'the coverage calculation.'
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        return {layer: np.zeros((value.shape[1], 2), dtype=bool) for layer, value in layer_out.items()}

    def get_metrics(self, dataset):
        """
        Get the metric of 'neuron boundary coverage'.

        Args:
            dataset (numpy.ndarray): Dataset used to calculate coverage metrics.

        Returns:
            float, the metric of 'neuron boundary coverage'.

        Examples:
            >>> import numpy as np
            >>> import mindspore.nn as nn
            >>> from mindspore.common.initializer import TruncatedNormal
            >>> from mindspore.ops import operations as P
            >>> from mindspore.train import Model
            >>> from mindspore.ops import TensorSummary
            >>> from mindarmour.fuzz_testing import NeuronBoundsCoverage
            >>> class Net(nn.Cell):
            ...     def __init__(self):
            ...         super(Net, self).__init__()
            ...         self.conv1 = nn.Conv2d(1, 6, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
            ...         self.conv2 = nn.Conv2d(6, 16, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
            ...         self.fc1 = nn.Dense(16 * 5 * 5, 120, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.fc2 = nn.Dense(120, 84, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.fc3 = nn.Dense(84, 10, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.relu = nn.ReLU()
            ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
            ...         self.reshape = P.Reshape()
            ...         self.summary = TensorSummary()
            ...     def construct(self, x):
            ...         x = self.conv1(x)
            ...         x = self.relu(x)
            ...         self.summary('conv1', x)
            ...         x = self.max_pool2d(x)
            ...         x = self.conv2(x)
            ...         x = self.relu(x)
            ...         self.summary('conv2', x)
            ...         x = self.max_pool2d(x)
            ...         x = self.reshape(x, (-1, 16 * 5 * 5))
            ...         x = self.fc1(x)
            ...         x = self.relu(x)
            ...         self.summary('fc1', x)
            ...         x = self.fc2(x)
            ...         x = self.relu(x)
            ...         self.summary('fc2', x)
            ...         x = self.fc3(x)
            ...         self.summary('fc3', x)
            ...         return x
            >>> net = Net()
            >>> model = Model(net)
            >>> batch_size = 8
            >>> num_classes = 10
            >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
            >>> test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
            >>> nbc = NeuronBoundsCoverage(model, train_images)
            >>> metrics = nbc.get_metrics(test_images)
        """
        dataset = check_numpy_param('dataset', dataset)
        if not self.incremental or not self._activate_table:
            self._activate_table = self._init_corner_table(dataset[0:1])
        batches = math.ceil(dataset.shape[0] / self.batch_size)
        for batch_index in range(batches):
            inputs = dataset[batch_index * self.batch_size: (batch_index + 1) * self.batch_size]
            self._model.predict(Tensor(inputs))
            time.sleep(0.01)
            layer_out = _get_summary_tensor_data()
            for layer, tensor in layer_out.items():
                value = tensor.asnumpy()
                if value.ndim > 2:
                    # Average over every axis after the second one so one value per neuron remains.
                    value = np.mean(value, axis=tuple(range(2, value.ndim)))
                corners = self._activate_table[layer]
                # Track the two corner regions separately so that the metric matches
                # (|UpperCornerNeuron| + |LowerCornerNeuron|) / (2 * |N|) rather than |Upper OR Lower| / |N|.
                corners[:, 0] |= np.any(value > self.upper_bounds[layer], axis=0)
                corners[:, 1] |= np.any(value < self.lower_bounds[layer], axis=0)
        # _activate_rate() divides the number of True cells by |N|; each neuron owns two cells, hence the / 2.
        return self._activate_rate() / 2
class KMultisectionNeuronCoverage(SuperNeuronActivateCoverage):
    """
    Get the metric of 'k-multisection neuron coverage'. KMNC measures how thoroughly the given set of test inputs
    covers the range of neurons' output values derived from the training dataset.

    Args:
        model (Model): The pre-trained model which is waiting for testing.
        train_dataset (numpy.ndarray): Training dataset used to determine the neurons' output boundaries.
        segmented_num (int): The number of segmented sections of neurons' output intervals. Default: ``100``.
        incremental (bool): Whether metrics are calculated in an incremental way. Default: ``False``.
        batch_size (int): The number of samples in a fuzz test batch. Default: ``32``.
    """

    def __init__(self, model, train_dataset, segmented_num=100, incremental=False, batch_size=32):
        super(KMultisectionNeuronCoverage, self).__init__(model, train_dataset, incremental=incremental,
                                                          batch_size=batch_size)
        self.segmented_num = check_int_positive('segmented_num', segmented_num)
        # Width of one section for every neuron: (upper - lower) / segmented_num, keyed by layer name.
        self.intervals = defaultdict(list)
        for layer, upper in self.upper_bounds.items():
            self.intervals[layer] = (upper - self.lower_bounds[layer]) / self.segmented_num

    def _init_k_multisection_table(self, data):
        """
        Initialise the section table: for every summarised layer an all-False boolean array of shape
        (neurons, segmented_num).

        Args:
            data (numpy.ndarray): Data used for one forward pass to discover the summarised layers.

        Returns:
            dict, the section table keyed by layer name.

        Raises:
            ValueError: If no summary tensor data is available after the forward pass.
        """
        self._model.predict(Tensor(data))
        time.sleep(0.01)
        layer_out = _get_summary_tensor_data()
        if not layer_out:
            msg = 'User must use TensorSummary() operation to specify the middle layer of the model participating in ' \
                  'the coverage calculation.'
            LOGGER.error(TAG, msg)
            raise ValueError(msg)
        activate_section_table = defaultdict()
        for layer, value in layer_out.items():
            # The builtin `bool` is used because the `np.bool` alias was removed from NumPy.
            activate_section_table[layer] = np.zeros((value.shape[1], self.segmented_num), dtype=bool)
        return activate_section_table

    def get_metrics(self, dataset):
        """
        Get the metric of 'k-multisection neuron coverage'.

        Args:
            dataset (numpy.ndarray): Dataset used to calculate coverage metrics.

        Returns:
            float, the metric of 'k-multisection neuron coverage'.

        Examples:
            >>> import numpy as np
            >>> import mindspore.nn as nn
            >>> from mindspore.common.initializer import TruncatedNormal
            >>> from mindspore.ops import operations as P
            >>> from mindspore.train import Model
            >>> from mindspore.ops import TensorSummary
            >>> from mindarmour.fuzz_testing import KMultisectionNeuronCoverage
            >>> class Net(nn.Cell):
            ...     def __init__(self):
            ...         super(Net, self).__init__()
            ...         self.conv1 = nn.Conv2d(1, 6, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
            ...         self.conv2 = nn.Conv2d(6, 16, 5, padding=0, weight_init=TruncatedNormal(0.02), pad_mode="valid")
            ...         self.fc1 = nn.Dense(16 * 5 * 5, 120, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.fc2 = nn.Dense(120, 84, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.fc3 = nn.Dense(84, 10, TruncatedNormal(0.02), TruncatedNormal(0.02))
            ...         self.relu = nn.ReLU()
            ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
            ...         self.reshape = P.Reshape()
            ...         self.summary = TensorSummary()
            ...     def construct(self, x):
            ...         x = self.conv1(x)
            ...         x = self.relu(x)
            ...         self.summary('conv1', x)
            ...         x = self.max_pool2d(x)
            ...         x = self.conv2(x)
            ...         x = self.relu(x)
            ...         self.summary('conv2', x)
            ...         x = self.max_pool2d(x)
            ...         x = self.reshape(x, (-1, 16 * 5 * 5))
            ...         x = self.fc1(x)
            ...         x = self.relu(x)
            ...         self.summary('fc1', x)
            ...         x = self.fc2(x)
            ...         x = self.relu(x)
            ...         self.summary('fc2', x)
            ...         x = self.fc3(x)
            ...         self.summary('fc3', x)
            ...         return x
            >>> net = Net()
            >>> model = Model(net)
            >>> batch_size = 8
            >>> num_classes = 10
            >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
            >>> test_images = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
            >>> kmnc = KMultisectionNeuronCoverage(model, train_images, segmented_num=100)
            >>> metrics = kmnc.get_metrics(test_images)
        """
        dataset = check_numpy_param('dataset', dataset)
        if not self.incremental or not self._activate_table:
            self._activate_table = self._init_k_multisection_table(dataset[0:1])
        batches = math.ceil(dataset.shape[0] / self.batch_size)
        for batch_index in range(batches):
            inputs = dataset[batch_index * self.batch_size: (batch_index + 1) * self.batch_size]
            self._model.predict(Tensor(inputs))
            time.sleep(0.01)
            layer_out = _get_summary_tensor_data()
            for layer, tensor in layer_out.items():
                value = tensor.asnumpy()
                if value.ndim > 2:
                    # Average over every axis after the second one so one value per neuron remains.
                    value = np.mean(value, axis=tuple(range(2, value.ndim)))
                # A neuron whose training bounds coincide has a zero-width interval; the division then yields
                # inf/NaN. Those entries are masked out below instead of being cast to int, whose result for
                # non-finite input is platform dependent.
                with np.errstate(divide='ignore', invalid='ignore'):
                    sections = np.floor((value - self.lower_bounds[layer]) / self.intervals[layer])
                    in_range = np.isfinite(sections) & (sections >= 0) & (sections < self.segmented_num)
                # np.nonzero and boolean indexing both walk the array in row-major order, so the neuron
                # indices and the section indices line up element by element.
                _, neuron_index = np.nonzero(in_range)
                section_index = sections[in_range].astype(int)
                self._activate_table[layer][neuron_index, section_index] = True
        # _activate_rate() divides the number of covered sections by the number of neurons; dividing by the
        # section count turns that into the fraction of all (neuron, section) cells that were hit.
        return self._activate_rate() / self.segmented_num
Python
1
https://gitee.com/mindspore/mindarmour.git
git@gitee.com:mindspore/mindarmour.git
mindspore
mindarmour
mindarmour
master

搜索帮助