diff --git a/docs/api/api_zh_cn/quantization/slb/mindspore_gs.quantization.SlbQuantAwareTraining.rst b/docs/api/api_zh_cn/quantization/slb/mindspore_gs.quantization.SlbQuantAwareTraining.rst index 078554514ab9ab4d25527b7bec1bc38ff1f1f644..bef8183ee90281a43e75575c6d6c398de7eea3ef 100644 --- a/docs/api/api_zh_cn/quantization/slb/mindspore_gs.quantization.SlbQuantAwareTraining.rst +++ b/docs/api/api_zh_cn/quantization/slb/mindspore_gs.quantization.SlbQuantAwareTraining.rst @@ -8,7 +8,9 @@ mindspore_gs.quantization.SlbQuantAwareTraining 参数: - **config** (dict) - 以字典的形式存放用于量化训练的属性,下面列出了受支持的属性: - - **quant_dtype** (QuantDtype) - 权重量化的数据类型,当前支持1、2、4比特。默认值:QuantDtype.INT1。 + - **quant_dtype** (Union[QuantDtype, list, tuple]) - 用于量化权重和激活的数据类型。第一个元素表示激活,第二个元素表示权重。在实际量化推理场景中需要考虑硬件器件的精度支持。当前权重量化支持1、2、4比特,激活量化支持8比特。默认值:(QuantDtype.INT8, QuantDtype.INT1)。 + - **enable_act_quant** (bool) - 在训练中是否开启激活量化。默认值:False。 + - **enable_bn_calibration** (bool) - 在训练中是否开启BN层矫正功能。默认值:False。 - **epoch_size** (int) - 训练的总epoch数。 - **has_trained_epoch** (int) - 预训练的epoch数。 - **t_start_val** (float) - 温度初始值。默认值:1.0。 @@ -17,9 +19,11 @@ mindspore_gs.quantization.SlbQuantAwareTraining - **t_factor** (float) - 温度变化因子。默认值:1.2。 异常: - - **TypeError** - `quant_dtype` 不是QuantDtype。 - - **TypeError** - `epoch_size` 或 `has_trained_epoch` 不是int。 - - **TypeError** - `t_start_val` 、 `t_start_time`、 `t_end_time` 或 `t_factor` 不是float。 + - **TypeError** - `quant_dtype` 的数据类型不是 `QuantDtype` ,或者 `quant_dtype` 存在不是 `QuantDtype` 的元素。 + - **TypeError** - `enable_act_quant` 或者 `enable_bn_calibration` 的数据类型不是bool。 + - **ValueError** - `quant_dtype` 的长度大于2。 + - **TypeError** - `epoch_size` 或 `has_trained_epoch` 的数据类型不是int。 + - **TypeError** - `t_start_val` 、 `t_start_time`、 `t_end_time` 或 `t_factor` 的数据类型不是float。 - **ValueError** - `epoch_size` 小于等于0。 - **ValueError** - `has_trained_epoch` 小于0。 - **ValueError** - `t_start_val` 或 `t_factor` 小于等于0.0。 @@ -34,8 +38,39 @@ mindspore_gs.quantization.SlbQuantAwareTraining - **weight_quant_dtype** (QuantDtype) - 权重量化的数据类型。默认值:QuantDtype.INT1。 异常: - - **TypeError** - `weight_quant_dtype` 不是QuantDtype。 - - **TypeError** - `weight_quant_dtype` 不是 `QuantDtype.INT1` 、 `QuantDtype.INT2` 和 `QuantDtype.INT4` 中的一种。 + - **TypeError** - `weight_quant_dtype` 的数据类型不是QuantDtype。 + - **ValueError** - `weight_quant_dtype` 不是 `QuantDtype.INT1` 、 `QuantDtype.INT2` 和 `QuantDtype.INT4` 中的一种。 + + .. py:method:: set_act_quant_dtype(act_quant_dtype) + + 设置激活量化的数据类型。 + + 参数: + - **act_quant_dtype** (QuantDtype) - 激活量化的数据类型。默认值:QuantDtype.INT8。 + + 异常: + - **TypeError** - `act_quant_dtype` 的数据类型不是QuantDtype。 + - **ValueError** - `act_quant_dtype` 不是 `QuantDtype.INT8` 。 + + .. py:method:: set_enable_act_quant(enable_act_quant) + + 设置是否开启激活量化。 + + 参数: + - **enable_act_quant** (bool) - 在训练中是否开启激活量化。默认值:False。 + + 异常: + - **TypeError** - `enable_act_quant` 的数据类型不是bool。 + + .. py:method:: set_enable_bn_calibration(enable_bn_calibration) + + 设置是否开启BatchNorm层矫正功能。 + + 参数: + - **enable_bn_calibration** (bool) - 在训练中是否开启BatchNorm层矫正功能。默认值:False。 + + 异常: + - **TypeError** - `enable_bn_calibration` 的数据类型不是bool。 .. py:method:: set_epoch_size(epoch_size) @@ -45,7 +80,7 @@ mindspore_gs.quantization.SlbQuantAwareTraining - **epoch_size** (int) - 训练的总epoch数。 异常: - - **TypeError** - `epoch_size` 不是int。 + - **TypeError** - `epoch_size` 的数据类型不是int。 - **ValueError** - `epoch_size` 小于等于0。 .. 
py:method:: set_has_trained_epoch(has_trained_epoch) @@ -56,7 +91,7 @@ mindspore_gs.quantization.SlbQuantAwareTraining - **has_trained_epoch** (int) - 预训练的epoch数。 异常: - - **TypeError** - `has_trained_epoch` 不是int。 + - **TypeError** - `has_trained_epoch` 的数据类型不是int。 - **ValueError** - `has_trained_epoch` 小于0。 .. py:method:: set_t_start_val(t_start_val) @@ -67,7 +102,7 @@ mindspore_gs.quantization.SlbQuantAwareTraining - **t_start_val** (float) - 温度初始值。默认值:1.0。 异常: - - **TypeError** - `t_start_val` 不是float。 + - **TypeError** - `t_start_val` 的数据类型不是float。 - **ValueError** - `t_start_val` 小于等于0.0。 .. py:method:: set_t_start_time(t_start_time) @@ -78,7 +113,7 @@ mindspore_gs.quantization.SlbQuantAwareTraining - **t_start_time** (float) - 温度开始变化时间。默认值:0.2。 异常: - - **TypeError** - `t_start_time` 不是float。 + - **TypeError** - `t_start_time` 的数据类型不是float。 - **ValueError** - `t_start_time` 小于0.0或大于1.0。 .. py:method:: set_t_end_time(t_end_time) @@ -89,7 +124,7 @@ mindspore_gs.quantization.SlbQuantAwareTraining - **t_end_time** (float) - 温度停止变化时间。默认值:0.6。 异常: - - **TypeError** - `t_end_time` 不是float。 + - **TypeError** - `t_end_time` 的数据类型不是float。 - **ValueError** - `t_end_time` 小于0.0或大于1.0。 .. py:method:: set_t_factor(t_factor) @@ -100,22 +135,24 @@ mindspore_gs.quantization.SlbQuantAwareTraining - **t_factor** (float) - 温度变化因子。默认值:1.2。 异常: - - **TypeError** - `t_factor` 不是float。 + - **TypeError** - `t_factor` 的数据类型不是float。 - **ValueError** - `t_factor` 小于等于0.0。 - .. py:method:: callbacks(model) + .. py:method:: callbacks(model, dataset) 定义SLB量化算法特有的一些callbacks,其中包括用于调节温度因子的callback。 参数: - **model** (Model) - 经过算法修改后的网络构造的mindspore的Model对象。 + - **dataset** (Dataset) - 加载了特定数据集的Dataset对象。 异常: - **RuntimeError** - `epoch_size` 没有初始化。 - **RuntimeError** - `has_trained_epoch` 没有初始化。 - **ValueError** - `epoch_size` 小于等于 `has_trained_epoch` 。 - **ValueError** - `t_end_time` 小于 `t_start_time` 。 - - **TypeError** - `model` 不是Model。 + - **TypeError** - `model` 的数据类型不是mindspore.Model。 + - **TypeError** - `dataset` 的数据类型不是mindspore.dataset.Dataset。 返回: SLB量化算法特有的一些callbacks的列表。 diff --git a/mindspore_gs/quantization/slb/slb_fake_quantizer.py b/mindspore_gs/quantization/slb/slb_fake_quantizer.py index b2a3dbaa760099a406532e215a1ed6f1a6b2aaac..8a377c7756084080ce9f9dc8c7afdb20ae6c4f69 100644 --- a/mindspore_gs/quantization/slb/slb_fake_quantizer.py +++ b/mindspore_gs/quantization/slb/slb_fake_quantizer.py @@ -14,8 +14,11 @@ # ============================================================================ """SlbFakeQuantizer.""" +from functools import partial import numpy as np import mindspore +import mindspore.context as context +from mindspore.ops.operations import _quant_ops as Q from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor from mindspore.ops import operations as P @@ -104,3 +107,65 @@ class SlbFakeQuantizerPerLayer(FakeQuantizer): """Display instance object as string.""" s = 'bit_num={}'.format(self.num_bits) return s + + +class SlbActQuantizer(FakeQuantizer): + """ + Implementation of SlbActQuantizer. + 1. track the min and max values passing through this op + 2. 
run fake quant execution to simulate the quantize loss + """ + + def __init__(self, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, num_bits=8, quant_delay=900): + super(SlbActQuantizer, self).__init__() + self._ema = ema + self._ema_decay = ema_decay + self._symmetric = symmetric + self._num_bits = num_bits + self._quant_delay = quant_delay + self._narrow_range = narrow_range + self._min_max_update_func = Q.MinMaxUpdatePerLayer(ema=self._ema, ema_decay=self._ema_decay) + self._is_ascend = context.get_context("device_target") == "Ascend" + quant_func = Q.FakeQuantPerLayer + self._init_fake_quant_func(quant_func) + self._float_min = Parameter(Tensor(np.array([-6]).astype(np.float32), mindspore.float32), + name="float_min", requires_grad=False) + self._float_max = Parameter(Tensor(np.array([6]).astype(np.float32), mindspore.float32), + name="float_max", requires_grad=False) + + def _init_fake_quant_func(self, quant_func): + """ + Define fake quant function according to device + """ + if self._is_ascend: + self._fake_quant_train = quant_func(num_bits=self._num_bits, + symmetric=self._symmetric, + narrow_range=self._narrow_range, + quant_delay=self._quant_delay) + self._fake_quant_infer = self._fake_quant_train + else: + quant_func = partial(quant_func, + ema=self._ema, + ema_decay=self._ema_decay, + num_bits=self._num_bits, + symmetric=self._symmetric, + narrow_range=self._narrow_range, + quant_delay=self._quant_delay) + self._fake_quant_train = quant_func(training=True) + self._fake_quant_infer = quant_func(training=False) + + def extend_repr(self): + """Display instance object as string.""" + s = 'bit_num={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}, ' \ + 'quant_delay={}'.format(self._num_bits, self._symmetric, self._narrow_range, + self._ema, self._ema_decay, False, self._quant_delay) + return s + + def construct(self, x): + if self.training: + self._float_min, self._float_max = \ + self._min_max_update_func(x, self._float_min, self._float_max) + out = self._fake_quant_train(x, self._float_min, self._float_max) + else: + out = self._fake_quant_infer(x, self._float_min, self._float_max) + return out diff --git a/mindspore_gs/quantization/slb/slb_layer_policy.py b/mindspore_gs/quantization/slb/slb_layer_policy.py index f9dfe462b1ebd5a84ab80444c2195f689c066278..f9e142ebd99cfa643e34f3c0c861457f29eaf3f0 100644 --- a/mindspore_gs/quantization/slb/slb_layer_policy.py +++ b/mindspore_gs/quantization/slb/slb_layer_policy.py @@ -16,13 +16,12 @@ from typing import Optional from functools import partial - from mindspore.nn import Cell from mindspore.nn.layer.quant import QuantConfig as OpQuantConfig from ..layer_policy import LayerPolicy from ..quantize_wrapper_cell import QuantizeWrapperCell from ..fake_quantizer import FakeQuantizer -from .slb_fake_quantizer import SlbFakeQuantizerPerLayer +from .slb_fake_quantizer import SlbFakeQuantizerPerLayer, SlbActQuantizer from .slb_quant import Conv2dSlbQuant from .slb_quant_config import SlbQuantConfig @@ -30,7 +29,7 @@ from .slb_quant_config import SlbQuantConfig class SlbLayerPolicy(LayerPolicy): """ Derived class of LayerPolicy. slb layer policy. - Use slb perlayer fake quantizer as weight fake quantizer. + Use slb perlayer fake quantizer as weight fake quantizer, linear perlayer fake quantizer as act fake quantizer. Supported Config: ``quant_dtype``. 
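For intuition about what the new SlbActQuantizer simulates: MinMaxUpdatePerLayer tracks one running [float_min, float_max] interval for the whole tensor (optionally with EMA), and FakeQuantPerLayer quantizes values onto the integer grid inside that interval and immediately dequantizes them, so training sees the rounding error. The following is a rough NumPy sketch of that quantize-dequantize round trip; the helper name and the asymmetric, full-range scheme are assumptions for illustration, not the MindSpore kernel.

import numpy as np

def fake_quant_per_layer(x, float_min, float_max, num_bits=8, narrow_range=False):
    # Map [float_min, float_max] onto the integer grid, round, clip, and map back.
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    scale = (float_max - float_min) / (quant_max - quant_min)
    zero_point = np.round(quant_min - float_min / scale)
    q = np.clip(np.round(x / scale + zero_point), quant_min, quant_max)
    return ((q - zero_point) * scale).astype(np.float32)

x = np.random.randn(4, 16).astype(np.float32)
y = fake_quant_per_layer(x, x.min(), x.max())
print(np.abs(x - y).max())  # stays within about one quantization step (scale)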
@@ -39,13 +38,21 @@ class SlbLayerPolicy(LayerPolicy): def __init__(self, weight_names: [], act_names: [], config: SlbQuantConfig = SlbQuantConfig()): self._config = config weight_num_bits = config.weight_quant_dtype.num_bits + act_num_bits = config.act_quant_dtype.num_bits if weight_num_bits not in [1, 2, 4]: - raise TypeError("Only support int4|int2|int1 weight quant now!") + raise ValueError("Only support int4|int2|int1 weight quant now!") + if act_num_bits not in [8]: + raise ValueError("Only support int8 activation quant now!") self._weight_quantizer_partial = partial(SlbFakeQuantizerPerLayer, num_bits=weight_num_bits) - self._act_quantizer: Optional[FakeQuantizer] = None - self._input_quantizer: Optional[FakeQuantizer] = None - self._output_quantizer: Optional[FakeQuantizer] = None + if config.enable_act_quant: + self._act_quantizer: Optional[FakeQuantizer] = SlbActQuantizer(num_bits=act_num_bits) + self._input_quantizer: Optional[FakeQuantizer] = SlbActQuantizer(num_bits=act_num_bits) + self._output_quantizer: Optional[FakeQuantizer] = SlbActQuantizer(num_bits=act_num_bits) + else: + self._act_quantizer: Optional[FakeQuantizer] = None + self._input_quantizer: Optional[FakeQuantizer] = None + self._output_quantizer: Optional[FakeQuantizer] = None self._weight_names = weight_names self._act_names = act_names self._input_num = 0 diff --git a/mindspore_gs/quantization/slb/slb_quant_aware_training.py b/mindspore_gs/quantization/slb/slb_quant_aware_training.py index 9224ba708f7012d76bfe3af54e3859e1f7f65998..d8458b0f88576eef8b6f83e66c39bbc05eac3ae2 100644 --- a/mindspore_gs/quantization/slb/slb_quant_aware_training.py +++ b/mindspore_gs/quantization/slb/slb_quant_aware_training.py @@ -12,14 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""Basic implementation of slb quantization method, this algorithm regards the discrete weights -in an arbitrary quantized neural network as searchable variables, and utilize a differential method -to search them accurately. In particular, each weight is represented as a probability distribution -over the discrete value set. The probabilities are optimized during training and the values -with the highest probability are selected to establish the desired quantized network. -See more details in `Searching for Low-Bit Weights in Quantized Neural Networks -<https://arxiv.org/pdf/2009.08695.pdf>`. """ +"""SlbQuantAwareTraining.""" +from mindspore.dataset import Dataset from mindspore import Model from mindspore.nn import Cell from mindspore.train.callback import Callback @@ -32,15 +27,27 @@ from .slb_quant_config import SlbQuantConfig class SlbQuantAwareTraining(QuantizationAwareTraining): """ - Derived class of GoldenStick. SLB(Searching for Low-Bit Weights) QAT-algorithm. + Basic implementation of the SLB quantization method. This algorithm regards the discrete weights + in an arbitrary quantized neural network as searchable variables, and utilizes a differential method + to search them accurately. In particular, each weight is represented as a probability distribution + over the discrete value set. The probabilities are optimized during training and the values + with the highest probability are selected to establish the desired quantized network. + See more details in `Searching for Low-Bit Weights in Quantized Neural Networks + <https://arxiv.org/pdf/2009.08695.pdf>`. Args: config (dict): store attributes for quantization aware training, keys are attribute names, values are attribute values. 
Supported attribute are listed below: - - quant_dtype (QuantDtype): Datatype used to quantize weights, weights quantization - support int4|int2|int1 now. - Default: QuantDtype.INT1. + - quant_dtype (Union[QuantDtype, list, tuple]): Datatype used to quantize weights and activations. The first + element represents activations and the second element represents weights. It is necessary to consider the + precision support of hardware devices in the practical quantization inference scenarios. + Weight quantization supports int4|int2|int1, and activation quantization supports int8 now. + Default: (QuantDtype.INT8, QuantDtype.INT1). + - enable_act_quant (bool): Whether to apply activation quantization while training. + Default: False. + - enable_bn_calibration (bool): Whether to apply batchnorm calibration while training. + Default: False. - epoch_size (int): Total training epochs. - has_trained_epoch (int): The trained epochs. - t_start_val (float): Initial value of temperature hyperparameters. Default: 1. @@ -52,7 +59,9 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): Default: 1.2. Raises: - TypeError: If `quant_dtype` is not `QuantDtype`. + TypeError: If `quant_dtype` is not `QuantDtype`, or any element of `quant_dtype` is not `QuantDtype`. + TypeError: If `enable_act_quant` or `enable_bn_calibration` is not bool. + ValueError: If the length of `quant_dtype` is greater than 2. TypeError: If `epoch_size` or `has_trained_epoch` is not an int. TypeError: If `t_start_val`, `t_start_time`, `t_end_time` or `t_factor` is not float. ValueError: If `epoch_size` is not greater than 0. @@ -87,20 +96,29 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): >>> ## 3.1) set_weight_quant_dtype is used to set the weight quantization bit, and support QuantDtype.INT4, QuantDtype.INT2, >>> ## QuantDtype.INT1 now. >>> slb_quantization.set_weight_quant_dtype(QuantDtype.INT1) - >>> ## 3.2) set_epoch_size is used to set the epoch size of training. + >>> ## 3.2) set_act_quant_dtype is used to set the activation quantization bit, and support QuantDtype.INT8 now. + >>> slb_quantization.set_act_quant_dtype(QuantDtype.INT8) + >>> ## 3.3) set_enable_act_quant is used to set whether to apply activation quantization. + >>> slb_quantization.set_enable_act_quant(True) + >>> ## 3.4) set_enable_bn_calibration is used to set whether to apply batchnorm calibration. + >>> slb_quantization.set_enable_bn_calibration(True) + >>> ## 3.5) set_epoch_size is used to set the epoch size of training. >>> slb_quantization.set_epoch_size(100) - >>> ## 3.3) set_has_trained_epoch is used to set the trained epoch size of training. + >>> ## 3.6) set_has_trained_epoch is used to set the trained epoch size of training. >>> slb_quantization.set_has_trained_epoch(0) - >>> ## 3.4) set_t_start_val is used to set the initial value of temperature hyperparameters. + >>> ## 3.7) set_t_start_val is used to set the initial value of temperature hyperparameters. >>> slb_quantization.set_t_start_val(1.0) - >>> ## 3.5) set_t_start_time is used to set the fraction of epochs after which temperature hyperparameters starting changing. + >>> ## 3.8) set_t_start_time is used to set the fraction of epochs after which temperature hyperparameters starting changing. >>> slb_quantization.set_t_start_time(0.2) - >>> ## 3.6) set_t_end_time is used to set the fraction of epochs after which temperature hyperparameters stopping changing. + >>> ## 3.9) set_t_end_time is used to set the fraction of epochs after which temperature hyperparameters stopping changing. 
>>> slb_quantization.set_t_end_time(0.6) - >>> ## 3.7) set_t_factor is used to set the multiplicative factor of temperature hyperparameters changing. + >>> ## 3.10) set_t_factor is used to set the multiplicative factor of temperature hyperparameters changing. >>> slb_quantization.set_t_factor(1.2) >>> ## 4) Print SLB QAT-Algorithm object and check the config setting result >>> ## Since we set weight_quant_dtype to be QuantDtype.INT1, the value of the attribute weight_quant_dtype is INT1 + >>> ## Since we set act_quant_dtype to be QuantDtype.INT8, the value of the attribute act_quant_dtype is INT8 + >>> ## Since we set enable_act_quant to be True, the value of the attribute enable_act_quant is True + >>> ## Since we set enable_bn_calibration to be True, the value of the attribute enable_bn_calibration is True >>> ## Since we set epoch_size to be 100, the value of the attribute epoch_size is 100 >>> ## Since we set has_trained_epoch to be 0, the value of the attribute has_trained_epoch is 0 >>> ## Since we set t_start_val to be 1.0, the value of the attribute t_start_val is 1.0 @@ -108,7 +126,7 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): >>> ## Since we set t_end_time to be 0.6, the value of the attribute t_end_time is 0.6 >>> ## Since we set t_factor to be 1.2, the value of the attribute t_factor is 1.2 >>> print(slb_quantization) - SlbQuantAwareTraining<weight_quant_dtype=INT1, epoch_size=100, has_trained_epoch=0, t_start_val=1.0, t_start_time=0.2, t_end_time=0.6, t_factor=1.2> + SlbQuantAwareTraining<weight_quant_dtype=INT1, act_quant_dtype=INT8, enable_act_quant=True, enable_bn_calibration=True, epoch_size=100, has_trained_epoch=0, t_start_val=1.0, t_start_time=0.2, t_end_time=0.6, t_factor=1.2> >>> ## 5) Apply SLB QAT-algorithm to origin network >>> net_qat = slb_quantization.apply(net) >>> ## 6) Print network and check the result. Conv2d should be transformed to QuantizeWrapperCells. @@ -126,6 +144,8 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): in_channels=1, out_channels=6, kernel_size=(5, 5), weight_bit_num=1, stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False (fake_quant_weight): SlbFakeQuantizerPerLayer<bit_num=1> > + (_input_quantizer): SlbActQuantizer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900> + (_output_quantizer): SlbActQuantizer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900> > > """ @@ -147,26 +167,73 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): def set_weight_quant_dtype(self, weight_quant_dtype): """ - Set value of weight_quant_dtype of `_config` + Set value of weight_quant_dtype of quantization aware training `config` Args: weight_quant_dtype (QuantDtype): Datatype used to quantize weights. Default: QuantDtype.INT1. Raises: TypeError: If `weight_quant_dtype` is not QuantDtype. - TypeError: Only supported if `weight_quant_dtype` is `QuantDtype.INT1`, `QuantDtype.INT2` + ValueError: Only supported if `weight_quant_dtype` is `QuantDtype.INT1`, `QuantDtype.INT2` or `QuantDtype.INT4` yet. """ - weight_quant_dtype = Validator.check_isinstance("weight quant dtype", weight_quant_dtype, QuantDtype) + if not isinstance(weight_quant_dtype, QuantDtype): + raise TypeError("The parameter `weight quant dtype` must be isinstance of QuantDtype, " + "but got {}.".format(weight_quant_dtype)) if weight_quant_dtype not in [QuantDtype.INT1, QuantDtype.INT2, QuantDtype.INT4]: - raise TypeError("Only supported if `weight_quant_dtype` is `QuantDtype.INT1`, " \ "`QuantDtype.INT2` or `QuantDtype.INT4` yet. " \ "But got {}".format(weight_quant_dtype)) + raise ValueError("Only supported if `weight_quant_dtype` is `QuantDtype.INT1`, " \ "`QuantDtype.INT2` or `QuantDtype.INT4` yet. 
" \ + "But got {}.".format(weight_quant_dtype)) self._config.weight_quant_dtype = weight_quant_dtype + def set_act_quant_dtype(self, act_quant_dtype): + """ + Set value of act_quant_dtype of quantization aware training `config` + + Args: + act_quant_dtype (QuantDtype): Datatype used to quantize activations. Default: QuantDtype.INT8. + + Raises: + TypeError: If `act_quant_dtype` is not QuantDtype. + ValueError: Only supported if `act_quant_dtype` is `QuantDtype.INT8` yet. + """ + if not isinstance(act_quant_dtype, QuantDtype): + raise TypeError("The parameter `act quant dtype` must be isinstance of QuantDtype, " + "but got {}.".format(act_quant_dtype)) + if act_quant_dtype not in [QuantDtype.INT8]: + raise ValueError("Only supported if `act_quant_dtype` is `QuantDtype.INT8` " \ + "yet. But got {}.".format(act_quant_dtype)) + self._config.act_quant_dtype = act_quant_dtype + + def set_enable_act_quant(self, enable_act_quant): + """ + Set value of enable_act_quant of quantization aware training `config` + + Args: + enable_act_quant (bool): Whether apply activation quantization while training, default is False. + + Raises: + TypeError: If `enable_act_quant` is not bool. + """ + enable_act_quant = Validator.check_bool(enable_act_quant, "enable_act_quant", self.__class__.__name__) + self._config.enable_act_quant = enable_act_quant + + def set_enable_bn_calibration(self, enable_bn_calibration): + """ + Set value of enable_bn_calibration of quantization aware training `config` + + Args: + enable_bn_calibration (bool): Whether apply batchnorm calibration while training, default is False. + + Raises: + TypeError: If `enable_bn_calibration` is not bool. + """ + enable_bn_calibration = Validator.check_bool(enable_bn_calibration, "enable_bn_calibration", self.__class__.__name__) + self._config.enable_bn_calibration = enable_bn_calibration + def set_epoch_size(self, epoch_size): """ - Set value of epoch_size of `_config` + Set value of epoch_size of quantization aware training `config` Args: epoch_size (int): the epoch size of training. @@ -180,7 +247,7 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): def set_has_trained_epoch(self, has_trained_epoch): """ - Set value of has_trained_epoch of `_config` + Set value of has_trained_epoch of quantization aware training `config` Args: has_trained_epoch (int): the trained epochs of training. @@ -194,7 +261,7 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): def set_t_start_val(self, t_start_val): """ - Set value of t_start_val of `_config` + Set value of t_start_val of quantization aware training `config` Args: t_start_val (float): Initial value of temperature hyperparameters, default: 1.0. @@ -208,7 +275,7 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): def set_t_start_time(self, t_start_time): """ - Set value of t_start_time of `_config` + Set value of t_start_time of quantization aware training `config` Args: t_start_time (float): Fraction of epochs after which temperature hyperparameters starting changing, default: 0.2. @@ -223,7 +290,7 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): def set_t_end_time(self, t_end_time): """ - Set value of t_end_time of `_config` + Set value of t_end_time of quantization aware training `config` Args: t_end_time (float): Fraction of epochs after which temperature hyperparameters stopping changing, default: 0.6. 
@@ -238,7 +305,7 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): def set_t_factor(self, t_factor): """ - Set value of t_factor of `_config` + Set value of t_factor of quantization aware training `config` Args: t_factor (float): Multiplicative factor of temperature hyperparameters changing, default: 1.2. @@ -250,14 +317,30 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): t_factor = Validator.check_positive_float(t_factor, "t_factor", self.__class__.__name__) self._config.t_factor = t_factor + @staticmethod + def _convert2list(name, value): + if not isinstance(value, list) and not isinstance(value, tuple): + value = [value, value] + elif len(value) == 1: + value = value + value + elif len(value) > 2: + raise ValueError("The length of input `{}` should not be greater than 2.".format(name)) + return value + def _init_net_policy(self, config): return SlbNetPolicy(config) def _create_qconfig_by_dict(self, config: dict): """Create `_config` from a dict""" self._config = SlbQuantConfig() - self.set_weight_quant_dtype(config.get("quant_dtype", QuantDtype.INT1)) + quant_dtype_list = SlbQuantAwareTraining.\ + _convert2list("quant dtype", config.get("quant_dtype", [QuantDtype.INT8, QuantDtype.INT1])) + + self.set_act_quant_dtype(quant_dtype_list[0]) + self.set_weight_quant_dtype(quant_dtype_list[-1]) + self.set_enable_act_quant(config.get("enable_act_quant", False)) + self.set_enable_bn_calibration(config.get("enable_bn_calibration", False)) if "epoch_size" in config: self.set_epoch_size(config["epoch_size"]) if "has_trained_epoch" in config: @@ -267,19 +350,21 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): self.set_t_end_time(config.get("t_end_time", 0.6)) self.set_t_factor(config.get("t_factor", 1.2)) - def callbacks(self, model: Model) -> [Callback]: + def callbacks(self, model: Model, dataset: Dataset) -> [Callback]: """ Define TemperatureScheduler callback for SLB QAT-algorithm. Args: model (Model): Model to be used. + dataset (Dataset): Dataset to be used. Raises: RuntimeError: If `epoch_size` is not initialized! RuntimeError: If `has_trained_epoch` is not initialized! ValueError: If `epoch_size` is not greater than `has_trained_epoch`. ValueError: If `t_end_time` is less than `t_start_time`. - TypeError: If `model` is not Model. + TypeError: If `model` is not mindspore.Model. + TypeError: If `dataset` is not mindspore.dataset.Dataset. Returns: List of instance of Callbacks. 
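The `quant_dtype` entry accepted by `_create_qconfig_by_dict` is normalized by the `_convert2list` helper added above: a bare QuantDtype is duplicated into both slots, a one-element sequence is doubled, a two-element sequence is read as (activation dtype, weight dtype), and anything longer raises ValueError. A quick demonstration, calling the static helper directly purely for illustration (imports mirror the test file):

from mindspore_gs.quantization.constant import QuantDtype
from mindspore_gs.quantization.slb import SlbQuantAwareTraining as SlbQAT

print(SlbQAT._convert2list("quant dtype", (QuantDtype.INT8, QuantDtype.INT4)))  # passes through: activation first, weight second
print(SlbQAT._convert2list("quant dtype", [QuantDtype.INT8]))                   # one element is doubled
print(SlbQAT._convert2list("quant dtype", QuantDtype.INT8))                     # a bare QuantDtype fills both slots

One consequence worth noting: a scalar weight dtype such as QuantDtype.INT1 is no longer a valid shorthand, because it would be duplicated into the activation slot and rejected by set_act_quant_dtype.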
@@ -295,12 +380,22 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): if self._config.t_end_time < self._config.t_start_time: raise ValueError("The `t_end_time` should not be less than `t_start_time`.") - model = Validator.check_isinstance("model", model, Model) + if not isinstance(model, Model): + raise TypeError(f'The parameter `model` must be isinstance of mindspore.Model, ' + f'but got {model}.') + + if not isinstance(dataset, Dataset): + raise TypeError(f'The parameter `dataset` must be isinstance of mindspore.dataset.Dataset, ' + f'but got {dataset}.') cb = [] cb.append(TemperatureScheduler(model, self._config.epoch_size, self._config.has_trained_epoch, self._config.t_start_val, self._config.t_start_time, self._config.t_end_time, self._config.t_factor)) + + if self._config.enable_bn_calibration: + cb.append(BNCalibrationCallback(model, dataset, self._config.epoch_size, + self._config.has_trained_epoch, self._config.t_start_time, False)) return cb def apply(self, network: Cell) -> Cell: @@ -325,11 +420,12 @@ class SlbQuantAwareTraining(QuantizationAwareTraining): def __repr__(self): """Display instance object as string.""" - s = 'SlbQuantAwareTraining<weight_quant_dtype={}, epoch_size={}, has_trained_epoch={}, t_start_val={}, t_start_time={}, t_end_time={}, t_factor={}>'.format(self._config.weight_quant_dtype, self._config.epoch_size, - self._config.has_trained_epoch, self._config.t_start_val, - self._config.t_start_time, self._config.t_end_time, - self._config.t_factor) + s = 'SlbQuantAwareTraining<weight_quant_dtype={}, act_quant_dtype={}, enable_act_quant={}, enable_bn_calibration={}, epoch_size={}, has_trained_epoch={}, t_start_val={}, t_start_time={}, t_end_time={}, t_factor={}>'.format(self._config.weight_quant_dtype, self._config.act_quant_dtype, + self._config.enable_act_quant, self._config.enable_bn_calibration, + self._config.epoch_size, self._config.has_trained_epoch, self._config.t_start_val, + self._config.t_start_time, self._config.t_end_time, self._config.t_factor) return s @@ -366,3 +462,29 @@ class TemperatureScheduler(Callback): cell.set_temperature(t) if epoch >= t_end_epoch: cell.set_temperature_end_flag() + + +class BNCalibrationCallback(Callback): + '''Update discrete state statistics in BN layers.''' + def __init__(self, model, train_set, epoch_size=100, has_trained_epoch=0, + t_start_time=0.2, dataset_sink_mode=False): + self.dataset_sink_mode = dataset_sink_mode + self.model = model + self.train_set = train_set + self.epochs = epoch_size + self.has_trained_epoch = has_trained_epoch + self.t_start_time = t_start_time + + def epoch_end(self, run_context): + """ + Epoch_end. 
+ """ + cb_params = run_context.original_args() + epoch = cb_params.cur_epoch_num + self.has_trained_epoch + t_start_epoch = int(self.epochs*self.t_start_time) + if epoch > t_start_epoch: + # make BN update for train and BNCalibration + for _, cell in self.model.train_network.cells_and_names(): + if cell.cls_name == 'BatchNorm2d': + cell.use_batch_statistics = True + self.model.eval(self.train_set, dataset_sink_mode=self.dataset_sink_mode) diff --git a/mindspore_gs/quantization/slb/slb_quant_config.py b/mindspore_gs/quantization/slb/slb_quant_config.py index 060fac9a9ed66b6e3f93ce65aebbf8ab71e59675..f27f68cbb221ffe3a39d23084838e94c2c855cb7 100644 --- a/mindspore_gs/quantization/slb/slb_quant_config.py +++ b/mindspore_gs/quantization/slb/slb_quant_config.py @@ -22,7 +22,10 @@ class SlbQuantConfig: See more details in slb_quant_aware_training.py """ def __init__(self): + self.act_quant_dtype = QuantDtype.INT8 self.weight_quant_dtype = QuantDtype.INT1 + self.enable_act_quant = False + self.enable_bn_calibration = False self.epoch_size = -1 self.has_trained_epoch = -1 self.t_start_val = 1.0 diff --git a/tests/st/quantization/test_slb_qat.py b/tests/st/quantization/test_slb_qat.py index 694c26c592b1f9344ca3827d64f5d2e90b1d45d9..73476e5d89e320b64d95de441fb805efe10ec6f7 100644 --- a/tests/st/quantization/test_slb_qat.py +++ b/tests/st/quantization/test_slb_qat.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""test qat.""" +"""test slb qat.""" import os import sys @@ -22,7 +22,7 @@ import pytest import numpy as np import mindspore from mindspore import nn, context -from mindspore.train import Model +from mindspore import Model from mindspore.nn.metrics import Accuracy from mindspore_gs.quantization.slb import SlbQuantAwareTraining as SlbQAT from mindspore_gs.quantization.constant import QuantDtype @@ -49,8 +49,9 @@ class NetToQuant(nn.Cell): @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -@pytest.mark.parametrize("quant_bit", ["W4", "W2", "W1"]) -def test_set_config(quant_bit): +@pytest.mark.parametrize("quant_bit", ["W4", "W2", "W1", "W4A8", "W2A8", "W1A8"]) +@pytest.mark.parametrize("enable_bn_calibration", [True, False]) +def test_set_config(quant_bit, enable_bn_calibration): """ Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set functions. Description: Apply SlbQuantAwareTraining on lenet. 
@@ -59,12 +60,18 @@ def test_set_config(quant_bit): network = NetToQuant() qat = SlbQAT() - if quant_bit == "W4": + if "W4" in quant_bit: qat.set_weight_quant_dtype(QuantDtype.INT4) - elif quant_bit == "W2": + elif "W2" in quant_bit: qat.set_weight_quant_dtype(QuantDtype.INT2) - elif quant_bit == "W1": + elif "W1" in quant_bit: qat.set_weight_quant_dtype(QuantDtype.INT1) + if "A8" in quant_bit: + qat.set_act_quant_dtype(QuantDtype.INT8) + qat.set_enable_act_quant(True) + else: + qat.set_enable_act_quant(False) + qat.set_enable_bn_calibration(enable_bn_calibration) qat.set_epoch_size(100) qat.set_has_trained_epoch(0) qat.set_t_start_val(1.0) @@ -80,6 +87,7 @@ def test_set_config(quant_bit): conv_handler = conv_quant._handler weight_fake_quant: SlbFakeQuantizerPerLayer = conv_handler.fake_quant_weight assert isinstance(weight_fake_quant, SlbFakeQuantizerPerLayer) + assert qat._config.enable_bn_calibration == enable_bn_calibration assert qat._config.epoch_size == 100 assert qat._config.has_trained_epoch == 0 assert qat._config.t_start_val == 1.0 @@ -93,15 +101,15 @@ def test_set_config(quant_bit): @pytest.mark.env_onecard def test_set_weight_quant_dtype_type(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_weight_quant_dtype(). + Feature: set_weight_quant_dtype api of SLB. Description: Feed int type `weight_quant_dtype` into set_weight_quant_dtype() functional interface. - Expectation: Except ValueError. + Expectation: Except TypeError. """ qat = SlbQAT() try: qat.set_weight_quant_dtype(weight_quant_dtype=3) - except ValueError: + except TypeError: return assert False @@ -111,25 +119,92 @@ def test_set_weight_quant_dtype_type(): @pytest.mark.env_onecard def test_set_weight_quant_dtype_range(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_weight_quant_dtype(). + Feature: set_weight_quant_dtype api of SLB. Description: Feed QuantDtype type `weight_quant_dtype` into set_weight_quant_dtype() functional interface. - Expectation: Except TypeError. + Expectation: Except ValueError. """ qat = SlbQAT() try: qat.set_weight_quant_dtype(weight_quant_dtype=QuantDtype.INT8) + except ValueError: + return + assert False + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_set_act_quant_dtype_type(): + """ + Feature: set_act_quant_dtype api of SLB. + Description: Feed int type `act_quant_dtype` into set_act_quant_dtype() functional interface. + Expectation: Except TypeError. + """ + + qat = SlbQAT() + try: + qat.set_act_quant_dtype(act_quant_dtype=3) except TypeError: return assert False +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_set_act_quant_dtype_range(): + """ + Feature: set_act_quant_dtype api of SLB. + Description: Feed QuantDtype type `act_quant_dtype` into set_act_quant_dtype() functional interface. + Expectation: Except ValueError. + """ + + qat = SlbQAT() + try: + qat.set_act_quant_dtype(act_quant_dtype=QuantDtype.INT1) + except ValueError: + return + assert False + + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_set_enable_act_quant(): + """ + Feature: set_enable_act_quant api of SlbQAT. + Description: Check default value of enable_act_quant and value after called set_enable_act_quant. + Expectation: Config success. 
+ """ + qat = SlbQAT() + assert not qat._config.enable_act_quant + qat.set_enable_act_quant(True) + assert qat._config.enable_act_quant + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_set_enable_bn_calibration(): + """ + Feature: set_enable_bn_calibration api of SlbQAT. + Description: Check default value of enable_bn_calibration and value after called set_enable_bn_calibration. + Expectation: Config success. + """ + qat = SlbQAT() + assert not qat._config.enable_bn_calibration + qat.set_enable_bn_calibration(True) + assert qat._config.enable_bn_calibration + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_set_epoch_size_type(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_epoch_size(). + Feature: set_epoch_size api of SlbQAT. Description: Feed float type `epoch_size` into set_epoch_size() functional interface. Expectation: Except TypeError. """ @@ -147,7 +222,7 @@ def test_set_epoch_size_type(): @pytest.mark.env_onecard def test_set_epoch_size_range(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_epoch_size(). + Feature: set_epoch_size api of SlbQAT. Description: Feed int type `epoch_size` into set_epoch_size() functional interface. Expectation: Except ValueError. """ @@ -165,7 +240,7 @@ def test_set_epoch_size_range(): @pytest.mark.env_onecard def test_set_has_trained_epoch_type(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_has_trained_epoch(). + Feature: set_has_trained_epoch api of SlbQAT. Description: Feed float type `has_trained_epoch` into set_has_trained_epoch() functional interface. Expectation: Except TypeError. """ @@ -183,7 +258,7 @@ def test_set_has_trained_epoch_type(): @pytest.mark.env_onecard def test_set_has_trained_epoch_range(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_has_trained_epoch(). + Feature: set_has_trained_epoch api of SlbQAT. Description: Feed int type `has_trained_epoch` into set_has_trained_epoch() functional interface. Expectation: Except ValueError. """ @@ -201,7 +276,7 @@ def test_set_has_trained_epoch_range(): @pytest.mark.env_onecard def test_set_t_start_val_type(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_t_start_val(). + Feature: set_t_start_val api of SlbQAT. Description: Feed int type `t_start_val` into set_t_start_val() functional interface. Expectation: Except TypeError. """ @@ -219,7 +294,7 @@ def test_set_t_start_val_type(): @pytest.mark.env_onecard def test_set_t_start_val_range(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_t_start_val(). + Feature: set_t_start_val api of SlbQAT. Description: Feed float type `t_start_val` into set_t_start_val() functional interface. Expectation: Except ValueError. """ @@ -237,7 +312,7 @@ def test_set_t_start_val_range(): @pytest.mark.env_onecard def test_set_t_start_time_type(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_t_start_time(). + Feature: set_t_start_time api of SlbQAT. Description: Feed int type `t_start_time` into set_t_start_time() functional interface. Expectation: Except TypeError. """ @@ -255,7 +330,7 @@ def test_set_t_start_time_type(): @pytest.mark.env_onecard def test_set_t_start_time_range(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_t_start_time(). + Feature: set_t_start_time api of SlbQAT. 
Description: Feed float type `t_start_time` into set_t_start_time() functional interface. Expectation: Except ValueError. """ @@ -273,7 +348,7 @@ def test_set_t_start_time_range(): @pytest.mark.env_onecard def test_set_t_end_time_type(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_t_end_time(). + Feature: set_t_end_time api of SlbQAT. Description: Feed int type `t_end_time` into set_t_end_time() functional interface. Expectation: Except TypeError. """ @@ -291,7 +366,7 @@ def test_set_t_end_time_type(): @pytest.mark.env_onecard def test_set_t_end_time_range(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_t_end_time(). + Feature: set_t_end_time api of SlbQAT. Description: Feed float type `t_end_time` into set_t_end_time() functional interface. Expectation: Except ValueError. """ @@ -309,7 +384,7 @@ def test_set_t_end_time_range(): @pytest.mark.env_onecard def test_set_t_factor_type(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_t_factor(). + Feature: set_t_factor api of SlbQAT. Description: Feed int type `t_factor` into set_t_factor() functional interface. Expectation: Except TypeError. """ @@ -327,7 +402,7 @@ def test_set_t_factor_type(): @pytest.mark.env_onecard def test_set_t_factor_range(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm set function set_t_factor(). + Feature: set_t_factor api of SlbQAT. Description: Feed float type `t_factor` into set_t_factor() functional interface. Expectation: Except ValueError. """ @@ -341,36 +416,46 @@ def test_set_t_factor_range(): @pytest.mark.level0 -@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_callbacks_epoch_initial(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm function callbacks(). + Feature: callbacks api of SlbQAT. Description: Not feed `epoch_size` and `has_trained_epoch`. Expectation: Except RuntimeError. """ + from lenet.src.dataset import create_dataset as create_mnist_ds + context.set_context(mode=context.GRAPH_MODE) + data_path = "/home/workspace/mindspore_dataset/mnist/train" + ds_train = create_mnist_ds(data_path, 32, 1) + network = NetToQuant() qat = SlbQAT() new_network = qat.apply(network) model = Model(new_network) try: - qat.callbacks(model=model) + qat.callbacks(model=model, dataset=ds_train) except RuntimeError: return assert False @pytest.mark.level0 -@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_callbacks_epoch_range_compare(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm function callbacks(). + Feature: callbacks api of SlbQAT. Description: Feed incorrect `epoch_size` and `has_trained_epoch`. Expectation: Except ValueError. 
""" + from lenet.src.dataset import create_dataset as create_mnist_ds + context.set_context(mode=context.GRAPH_MODE) + data_path = "/home/workspace/mindspore_dataset/mnist/train" + ds_train = create_mnist_ds(data_path, 32, 1) + network = NetToQuant() qat = SlbQAT() new_network = qat.apply(network) @@ -378,22 +463,27 @@ def test_callbacks_epoch_range_compare(): try: qat.set_epoch_size(epoch_size=100) qat.set_has_trained_epoch(has_trained_epoch=120) - qat.callbacks(model=model) + qat.callbacks(model=model, dataset=ds_train) except ValueError: return assert False @pytest.mark.level0 -@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_callbacks_time_range_compare(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm function callbacks(). + Feature: callbacks api of SlbQAT. Description: Feed incorrect `t_start_time` and `t_end_time`. Expectation: Except ValueError. """ + from lenet.src.dataset import create_dataset as create_mnist_ds + context.set_context(mode=context.GRAPH_MODE) + data_path = "/home/workspace/mindspore_dataset/mnist/train" + ds_train = create_mnist_ds(data_path, 32, 1) + network = NetToQuant() qat = SlbQAT() new_network = qat.apply(network) @@ -403,30 +493,60 @@ def test_callbacks_time_range_compare(): qat.set_has_trained_epoch(has_trained_epoch=0) qat.set_t_start_time(t_start_time=0.7) qat.set_t_end_time(t_end_time=0.4) - qat.callbacks(model=model) + qat.callbacks(model=model, dataset=ds_train) except ValueError: return assert False @pytest.mark.level0 -@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_callbacks_model_type(): """ - Feature: SLB(Searching for Low-Bit Weights) QAT-algorithm function callbacks(). + Feature: callbacks api of SlbQAT. Description: Feed int type `model` into callbacks() functional interface. - Expectation: Except ValueError. + Expectation: Except TypeError. """ + from lenet.src.dataset import create_dataset as create_mnist_ds + context.set_context(mode=context.GRAPH_MODE) + data_path = "/home/workspace/mindspore_dataset/mnist/train" + ds_train = create_mnist_ds(data_path, 32, 1) + qat = SlbQAT() try: qat.set_epoch_size(epoch_size=100) qat.set_has_trained_epoch(has_trained_epoch=0) qat.set_t_start_time(t_start_time=0.2) qat.set_t_end_time(t_end_time=0.6) - qat.callbacks(model=10) - except ValueError: + qat.callbacks(model=10, dataset=ds_train) + except TypeError: + return + assert False + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_callbacks_dataset_type(): + """ + Feature: callbacks api of SlbQAT. + Description: Feed int type `dataset` into callbacks() functional interface. + Expectation: Except TypeError. 
+ """ + + network = NetToQuant() + qat = SlbQAT() + new_network = qat.apply(network) + model = Model(new_network) + try: + qat.set_epoch_size(epoch_size=100) + qat.set_has_trained_epoch(has_trained_epoch=0) + qat.set_t_start_time(t_start_time=0.2) + qat.set_t_end_time(t_end_time=0.6) + qat.callbacks(model=model, dataset=5) + except TypeError: return assert False @@ -434,8 +554,9 @@ def test_callbacks_model_type(): @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -@pytest.mark.parametrize("quant_bit", ["W4", "W2", "W1"]) -def test_lenet(quant_bit): +@pytest.mark.parametrize("quant_bit", ["W4", "W2", "W1", "W4A8", "W2A8", "W1A8"]) +@pytest.mark.parametrize("enable_bn_calibration", [True, False]) +def test_lenet(quant_bit, enable_bn_calibration): """ Feature: slb quantization algorithm. Description: Apply slb qat on lenet. @@ -445,17 +566,36 @@ def test_lenet(quant_bit): from lenet.src.lenet import LeNet5 network = LeNet5(10) if quant_bit == "W4": - qat = SlbQAT({"quant_dtype": QuantDtype.INT4, "epoch_size": 100, + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT4], "enable_act_quant": False, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, "has_trained_epoch": 0, "t_start_val": 1.0, - "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 1.2}) + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) elif quant_bit == "W2": - qat = SlbQAT({"quant_dtype": QuantDtype.INT2, "epoch_size": 100, + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT2], "enable_act_quant": False, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, "has_trained_epoch": 0, "t_start_val": 1.0, - "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 1.2}) + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) elif quant_bit == "W1": - qat = SlbQAT({"quant_dtype": QuantDtype.INT1, "epoch_size": 100, + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT1], "enable_act_quant": False, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, + "has_trained_epoch": 0, "t_start_val": 1.0, + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) + elif quant_bit == "W4A8": + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT4], "enable_act_quant": True, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, "has_trained_epoch": 0, "t_start_val": 1.0, - "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 1.2}) + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) + elif quant_bit == "W2A8": + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT2], "enable_act_quant": True, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, + "has_trained_epoch": 0, "t_start_val": 1.0, + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) + elif quant_bit == "W1A8": + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT1], "enable_act_quant": True, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, + "has_trained_epoch": 0, "t_start_val": 1.0, + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) + new_network = qat.apply(network) cells: OrderedDict = new_network.name_cells() assert cells.get("Conv2dSlbQuant", None) is not None @@ -464,26 +604,102 @@ def test_lenet(quant_bit): conv_handler = conv_quant._handler weight_fake_quant: SlbFakeQuantizerPerLayer = conv_handler.fake_quant_weight assert isinstance(weight_fake_quant, SlbFakeQuantizerPerLayer) - assert qat._config.epoch_size == 100 + assert qat._config.enable_bn_calibration == 
enable_bn_calibration + assert qat._config.epoch_size == 10 assert qat._config.has_trained_epoch == 0 assert qat._config.t_start_val == 1.0 assert qat._config.t_start_time == 0.2 assert qat._config.t_end_time == 0.6 - assert qat._config.t_factor == 1.2 + assert qat._config.t_factor == 3.2 +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize("quant_bit", ["W4", "W2", "W1", "W4A8", "W2A8", "W1A8"]) +@pytest.mark.parametrize("enable_bn_calibration", [True]) +@pytest.mark.parametrize("run_mode", [context.GRAPH_MODE]) +def test_lenet_accuracy_bnon(quant_bit, enable_bn_calibration, run_mode): + """ + Feature: test accuracy of slb qat work on lenet5. + Description: Apply slb qat on lenet5 and test accuracy. + Expectation: accuracy is larger than 0.95. + """ + + from lenet.src.lenet import LeNet5 + from lenet.src.dataset import create_dataset as create_mnist_ds + context.set_context(mode=run_mode) + mnist_path = os.getenv("DATASET_PATH", "/home/workspace/mindspore_dataset/mnist") + data_path = os.path.join(mnist_path, "train") + ds_train = create_mnist_ds(data_path, 32, 1) + network = LeNet5(10) + + # convert network to quantization aware network + if quant_bit == "W4": + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT4], "enable_act_quant": False, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, + "has_trained_epoch": 0, "t_start_val": 1.0, + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) + elif quant_bit == "W2": + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT2], "enable_act_quant": False, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, + "has_trained_epoch": 0, "t_start_val": 1.0, + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) + elif quant_bit == "W1": + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT1], "enable_act_quant": False, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, + "has_trained_epoch": 0, "t_start_val": 1.0, + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) + elif quant_bit == "W4A8": + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT4], "enable_act_quant": True, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, + "has_trained_epoch": 0, "t_start_val": 1.0, + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) + elif quant_bit == "W2A8": + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT2], "enable_act_quant": True, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, + "has_trained_epoch": 0, "t_start_val": 1.0, + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) + elif quant_bit == "W1A8": + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT1], "enable_act_quant": True, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, + "has_trained_epoch": 0, "t_start_val": 1.0, + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) + new_network = qat.apply(network) + + # define network loss + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + # define network optimization + net_opt = nn.Momentum(new_network.trainable_params(), 0.01, 0.9) + + # define model + model = Model(new_network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + + print("============== Starting Training ==============") + model.train(10, ds_train, callbacks=qat.callbacks(model, ds_train)) + print("============== End Training ==============") + + ds_eval = 
create_mnist_ds(os.path.join(mnist_path, "test"), 32, 1) + + print("============== Starting Testing ==============") + acc = model.eval(ds_eval) + print("============== {} ==============".format(acc)) + assert acc['Accuracy'] > 0.95 + + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -@pytest.mark.parametrize("quant_bit", ["W4", "W2", "W1"]) +@pytest.mark.parametrize("quant_bit", ["W4", "W2", "W1", "W4A8", "W2A8", "W1A8"]) +@pytest.mark.parametrize("enable_bn_calibration", [False]) @pytest.mark.parametrize("run_mode", [context.GRAPH_MODE]) -def test_lenet_accuracy(quant_bit, run_mode): +def test_lenet_accuracy_bnoff(quant_bit, enable_bn_calibration, run_mode): """ Feature: test accuracy of slb qat work on lenet5. Description: Apply slb qat on lenet5 and test accuracy. - Expectation: accuracy is larger than 0.98. + Expectation: accuracy is larger than 0.95. """ from lenet.src.lenet import LeNet5 @@ -496,15 +712,33 @@ def test_lenet_accuracy(quant_bit, run_mode): # convert network to quantization aware network if quant_bit == "W4": - qat = SlbQAT({"quant_dtype": QuantDtype.INT4, "epoch_size": 10, + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT4], "enable_act_quant": False, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, "has_trained_epoch": 0, "t_start_val": 1.0, "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) elif quant_bit == "W2": - qat = SlbQAT({"quant_dtype": QuantDtype.INT2, "epoch_size": 10, + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT2], "enable_act_quant": False, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, "has_trained_epoch": 0, "t_start_val": 1.0, "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) elif quant_bit == "W1": - qat = SlbQAT({"quant_dtype": QuantDtype.INT1, "epoch_size": 10, + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT1], "enable_act_quant": False, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, + "has_trained_epoch": 0, "t_start_val": 1.0, + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) + elif quant_bit == "W4A8": + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT4], "enable_act_quant": True, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, + "has_trained_epoch": 0, "t_start_val": 1.0, + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) + elif quant_bit == "W2A8": + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT2], "enable_act_quant": True, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, + "has_trained_epoch": 0, "t_start_val": 1.0, + "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) + elif quant_bit == "W1A8": + qat = SlbQAT({"quant_dtype": [QuantDtype.INT8, QuantDtype.INT1], "enable_act_quant": True, + "enable_bn_calibration": enable_bn_calibration, "epoch_size": 10, "has_trained_epoch": 0, "t_start_val": 1.0, "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}) new_network = qat.apply(network) @@ -518,7 +752,7 @@ def test_lenet_accuracy(quant_bit, run_mode): model = Model(new_network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) print("============== Starting Training ==============") - model.train(10, ds_train, callbacks=qat.callbacks(model)) + model.train(10, ds_train, callbacks=qat.callbacks(model, ds_train)) print("============== End Training ==============") ds_eval = create_mnist_ds(os.path.join(mnist_path, "test"), 32, 1) @@ -533,9 +767,10 @@ def test_lenet_accuracy(quant_bit, 
run_mode): @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -@pytest.mark.parametrize("quant_bit", ["W4", "W2", "W1"]) +@pytest.mark.parametrize("quant_bit", ["W4", "W2", "W1", "W4A8", "W2A8", "W1A8"]) +@pytest.mark.parametrize("enable_bn_calibration", [True, False]) @pytest.mark.parametrize("run_mode", [context.GRAPH_MODE, context.PYNATIVE_MODE]) -def test_resnet(quant_bit, run_mode): +def test_resnet(quant_bit, enable_bn_calibration, run_mode): """ Feature: slb quantization algorithm. Description: Apply slb qat on resnet. @@ -550,13 +785,18 @@ def test_resnet(quant_bit, run_mode): network = resnet18(10) qat = SlbQAT() - - if quant_bit == "W4": + if "W4" in quant_bit: qat.set_weight_quant_dtype(QuantDtype.INT4) - elif quant_bit == "W2": + elif "W2" in quant_bit: qat.set_weight_quant_dtype(QuantDtype.INT2) - elif quant_bit == "W1": + elif "W1" in quant_bit: qat.set_weight_quant_dtype(QuantDtype.INT1) + if "A8" in quant_bit: + qat.set_act_quant_dtype(QuantDtype.INT8) + qat.set_enable_act_quant(True) + else: + qat.set_enable_act_quant(False) + qat.set_enable_bn_calibration(enable_bn_calibration) qat.set_epoch_size(100) qat.set_has_trained_epoch(0) qat.set_t_start_val(1.0) @@ -572,6 +812,7 @@ def test_resnet(quant_bit, run_mode): conv_handler = conv_quant._handler weight_fake_quant: SlbFakeQuantizerPerLayer = conv_handler.fake_quant_weight assert isinstance(weight_fake_quant, SlbFakeQuantizerPerLayer) + assert qat._config.enable_bn_calibration == enable_bn_calibration assert qat._config.epoch_size == 100 assert qat._config.has_trained_epoch == 0 assert qat._config.t_start_val == 1.0 @@ -582,7 +823,7 @@ def test_resnet(quant_bit, run_mode): -def _create_resnet_accuracy_model(quant_bit, run_mode=context.GRAPH_MODE): +def _create_resnet_accuracy_model(quant_bit, enable_bn_calibration, run_mode=context.GRAPH_MODE): """ Create model lr dataset for resnet slbqat accuracy test. Merge into test_resnet_accuracy after pynative bug is fixed. 
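The parametrization strings used throughout these tests encode a configuration as a weight bit-width tag ("W1", "W2", "W4") optionally followed by "A8" for int8 activation quantization. A hypothetical helper that would collapse the repeated if/elif chains into one table (not part of this patch, shown only to document the convention):

from mindspore_gs.quantization.constant import QuantDtype

_WEIGHT_DTYPES = {"W1": QuantDtype.INT1, "W2": QuantDtype.INT2, "W4": QuantDtype.INT4}

def config_from_tag(tag, enable_bn_calibration, epoch_size=10):
    # "W2A8" -> int2 weights with int8 activation fake-quantization enabled.
    weight_dtype = next(dtype for key, dtype in _WEIGHT_DTYPES.items() if key in tag)
    return {"quant_dtype": [QuantDtype.INT8, weight_dtype],
            "enable_act_quant": "A8" in tag,
            "enable_bn_calibration": enable_bn_calibration,
            "epoch_size": epoch_size, "has_trained_epoch": 0, "t_start_val": 1.0,
            "t_start_time": 0.2, "t_end_time": 0.6, "t_factor": 3.2}

qat = SlbQAT(config_from_tag("W2A8", enable_bn_calibration=True))  # SlbQAT as imported at the top of this file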
@@ -679,12 +920,18 @@ def _create_resnet_accuracy_model(quant_bit, run_mode=context.GRAPH_MODE): # apply golden-stick algo qat = SlbQAT() - if quant_bit == "W4": + if "W4" in quant_bit: qat.set_weight_quant_dtype(QuantDtype.INT4) - elif quant_bit == "W2": + elif "W2" in quant_bit: qat.set_weight_quant_dtype(QuantDtype.INT2) - elif quant_bit == "W1": + elif "W1" in quant_bit: qat.set_weight_quant_dtype(QuantDtype.INT1) + if "A8" in quant_bit: + qat.set_act_quant_dtype(QuantDtype.INT8) + qat.set_enable_act_quant(True) + else: + qat.set_enable_act_quant(False) + qat.set_enable_bn_calibration(enable_bn_calibration) qat.set_epoch_size(100) qat.set_has_trained_epoch(0) qat.set_t_start_val(1.0) @@ -705,16 +952,55 @@ def _create_resnet_accuracy_model(quant_bit, run_mode=context.GRAPH_MODE): metrics = {"acc"} metrics.clear() - model = mindspore.Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics, - keep_batchnorm_fp32=False) + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics, + keep_batchnorm_fp32=False) return model, lr, dataset, qat @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -@pytest.mark.parametrize("quant_bit", ["W4", "W2", "W1"]) -def test_resnet_accuracy_graph(quant_bit): +@pytest.mark.parametrize("quant_bit", ["W4", "W2", "W1", "W4A8", "W2A8", "W1A8"]) +@pytest.mark.parametrize("enable_bn_calibration", [True]) +def test_resnet_accuracy_graph_bnon(quant_bit, enable_bn_calibration): + """ + Feature: slb quantization algorithm. + Description: Apply slb qat on resnet and test accuracy + Expectation: Loss of first epoch is smaller than 2.5. + """ + + sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '../')) + from loss_monitor import LossMonitor + + step_threshold = 20 + target = "GPU" + epoch_size = 1 + + mindspore.context.set_context(mode=context.GRAPH_MODE, device_target=target) + model, lr, dataset, qat = _create_resnet_accuracy_model(quant_bit, enable_bn_calibration, context.GRAPH_MODE) + + # define callbacks + monitor = LossMonitor(lr_init=lr.asnumpy(), step_threshold=step_threshold) + callbacks = [monitor] + qat.callbacks(model, dataset) + # train model + dataset_sink_mode = target != "CPU" + print("============== Starting Training ==============") + model.train(epoch_size, dataset, callbacks=callbacks, sink_size=dataset.get_dataset_size(), + dataset_sink_mode=dataset_sink_mode) + print("============== End Training ==============") + expect_avg_step_loss = 2.5 + avg_step_loss = np.mean(np.array(monitor.losses)) + print("average step loss:{}".format(avg_step_loss)) + assert avg_step_loss <= expect_avg_step_loss + + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +@pytest.mark.parametrize("quant_bit", ["W4", "W2", "W1", "W4A8", "W2A8", "W1A8"]) +@pytest.mark.parametrize("enable_bn_calibration", [False]) +def test_resnet_accuracy_graph_bnoff(quant_bit, enable_bn_calibration): """ Feature: slb quantization algorithm. 
Description: Apply slb qat on resnet and test accuracy @@ -729,11 +1015,11 @@ def test_resnet_accuracy_graph(quant_bit): epoch_size = 1 mindspore.context.set_context(mode=context.GRAPH_MODE, device_target=target) - model, lr, dataset, qat = _create_resnet_accuracy_model(quant_bit, context.GRAPH_MODE) + model, lr, dataset, qat = _create_resnet_accuracy_model(quant_bit, enable_bn_calibration, context.GRAPH_MODE) # define callbacks monitor = LossMonitor(lr_init=lr.asnumpy(), step_threshold=step_threshold) - callbacks = [monitor] + qat.callbacks(model) + callbacks = [monitor] + qat.callbacks(model, dataset) # train model dataset_sink_mode = target != "CPU" print("============== Starting Training ==============") @@ -746,7 +1032,7 @@ def test_resnet_accuracy_graph(quant_bit): assert avg_step_loss <= expect_avg_step_loss -def test_resnet_accuracy_pynative(quant_bit): +def test_resnet_accuracy_pynative(quant_bit, enable_bn_calibration): """ Feature: Simulated quantization algorithm. Description: Apply simulated_quantization on resnet and test accuracy @@ -760,10 +1046,10 @@ def test_resnet_accuracy_pynative(quant_bit): epoch_size = 1 mindspore.context.set_context(mode=context.PYNATIVE_MODE, device_target=target) - model, lr, dataset, qat = _create_resnet_accuracy_model(quant_bit, context.PYNATIVE_MODE) + model, lr, dataset, qat = _create_resnet_accuracy_model(quant_bit, enable_bn_calibration, context.PYNATIVE_MODE) # define callbacks monitor = LossMonitor(lr_init=lr.asnumpy(), step_threshold=step_threshold) - callbacks = [monitor] + qat.callbacks(model) + callbacks = [monitor] + qat.callbacks(model, dataset) # train model dataset_sink_mode = target != "CPU" print("============== Starting Training ==============")