diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py
index 2bf6c6cd22028e530cb974f713f114887715f525..e11b38ed07886d17a262429d625074587820a784 100644
--- a/mindspore/_extends/graph_kernel/model/model.py
+++ b/mindspore/_extends/graph_kernel/model/model.py
@@ -177,7 +177,6 @@ class PrimLib:
         'ReduceMax': Prim(REDUCE),
         'ReduceMin': Prim(REDUCE),
         'MakeTuple': Prim(CONTROL),
-        'ControlDepend': Prim(CONTROL),
         'Assign': Prim(ELEMWISE),
         'Tanh': Prim(ELEMWISE),
         'ExpandDims': Prim(RESHAPE),
diff --git a/mindspore/ops/_grad/grad_implementations.py b/mindspore/ops/_grad/grad_implementations.py
index c72b34d199f2adbd90c95c0231dc09f0adc751b2..407caaff758a231218755504978a9d7bebc575f9 100644
--- a/mindspore/ops/_grad/grad_implementations.py
+++ b/mindspore/ops/_grad/grad_implementations.py
@@ -261,12 +261,6 @@ def bprop_bool_and(x, y, out, dout):
     return C.zeros_like(x), C.zeros_like(y)
 
 
-@bprops.register("ControlDepend")
-def bprop_control_depend(x, y, out, dout):
-    """Backpropagator for primitive `Control_depend`."""
-    return C.zeros_like(x), C.zeros_like(y)
-
-
 @bprops.register("Switch")
 def bprop_switch(cond, tb, fb, out, dout):
     """Backpropagator for primitive `switch`."""
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index 2dcd13a5bf6d554fc5fd31400626d2c9cc8cce8a..9645c63a509366c8564f7e2d4f92de684adbc823 100644
--- a/mindspore/ops/functional.py
+++ b/mindspore/ops/functional.py
@@ -42,11 +42,6 @@ shape = P.Shape()
 rank = P.Rank()
 reshape = P.Reshape()
 
-# control_depend: represent dependency between two operators
-def control_depend(src, dst):
-    control_depend_op = P.ControlDepend()
-    return control_depend_op(src, dst)
-
 merge = P.Merge()
 geswitch = P.GeSwitch()
 addn = P.AddN()
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 90764509625fcd0fa26fafc581e73bb329c13acb..594a399b16937dd7b7ceccdf97a667c465a4b6a1 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -40,7 +40,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter,
                        _HostAllGather, _HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print, Assert)
-from .control_ops import ControlDepend, GeSwitch, Merge
+from .control_ops import GeSwitch, Merge
 from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey
 
 from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
@@ -275,7 +275,6 @@ __all__ = [
     'ScalarToArray',
     'ScalarToTensor',
     'TupleToArray',
-    'ControlDepend',
     'GeSwitch',
     'Merge',
     'SameTypeShape',
diff --git a/mindspore/ops/operations/control_ops.py b/mindspore/ops/operations/control_ops.py
index 99306860ed6e85f46aee4e054b0f1fbc90e0f8e8..73cc069b5b707c14c9c7c13749ea37a474fea5d4 100644
--- a/mindspore/ops/operations/control_ops.py
+++ b/mindspore/ops/operations/control_ops.py
@@ -14,76 +14,9 @@
 # ============================================================================
 """control_ops"""
 
-from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
-from ..._checkparam import Rel
+from ..primitive import PrimitiveWithInfer, prim_attr_register
 from ..._checkparam import Validator as validator
 from ...common import dtype as mstype
-from ...common._decorator import deprecated
-
-
-class ControlDepend(Primitive):
-    """
-    Adds control dependency relation between source and destination operations.
-
-    In many cases, we need to control the execution order of operations. ControlDepend is designed for this.
-    ControlDepend will instruct the execution engine to run the operations in a specific order. ControlDepend
-    tells the engine that the destination operations must depend on the source operation which means the source
-    operations must be executed before the destination.
-
-    Note:
-        This operation does not work in `PYNATIVE_MODE`.
-        `ControlDepend` is deprecated from version 1.1 and will be removed in a future version, use `Depend` instead.
-    Args:
-        depend_mode (int): Use 0 for a normal dependency relation and 1 for a user-defined dependency relation.
-            Default: 0.
-
-    Inputs:
-        - **src** (Any) - The source input. It can be a tuple of operations output or a single operation output. We do
-          not concern about the input data, but concern about the operation that generates the input data.
-          If `depend_mode` is 1 and the source input is Parameter, we will try to find the operations that
-          used the parameter as input.
-        - **dst** (Any) - The destination input. It can be a tuple of operations output or a single operation output.
-          We do not concern about the input data, but concern about the operation that generates the input data.
-          If `depend_mode` is 1 and the source input is Parameter, we will try to find the operations that
-          used the parameter as input.
-
-    Outputs:
-        This operation has no actual data output, it will be used to setup the order of relative operations.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> class Net(nn.Cell):
-        ...     def __init__(self):
-        ...         super(Net, self).__init__()
-        ...         self.control_depend = P.ControlDepend()
-        ...         self.softmax = ops.Softmax()
-        ...
-        ...     def construct(self, x, y):
-        ...         mul = x * y
-        ...         softmax = self.softmax(x)
-        ...         ret = self.control_depend(mul, softmax)
-        ...         return ret
-        ...
-        >>> x = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
-        >>> y = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
-        >>> net = Net()
-        >>> output = net(x, y)
-        >>> print(output)
-        [[1. 1. 1. 1. 1.]
-         [1. 1. 1. 1. 1.]
-         [1. 1. 1. 1. 1.]
-         [1. 1. 1. 1. 1.]]
-    """
-    @deprecated("1.1", "Depend")
-    @prim_attr_register
-    def __init__(self, depend_mode=0):
-        """init"""
-        validator.check_int_range(depend_mode, 0, 1, Rel.INC_BOTH, "depend_mode", self.name)
-
-    def __call__(self, src, dst):
-        return src
 
 
 class GeSwitch(PrimitiveWithInfer):
diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py
index 0e488a846df9dab9bdb7d8e46f1b7267eea81acf..9d9f4467e0119cdfa2afb0ba6bcaefa4b2c6ebdc 100644
--- a/mindspore/ops/operations/other_ops.py
+++ b/mindspore/ops/operations/other_ops.py
@@ -420,16 +420,12 @@ class Depend(Primitive):
     Depend is used for processing dependency operations.
 
     In some side-effect scenarios, we need to ensure the execution order of operators.
-    In order to ensure that operator A is executed before operator B, it is recommended
-    to insert the Depend operator between operators A and B.
-
-    Previously, the ControlDepend operator was used to control the execution order.
-    Since the ControlDepend operator is deprecated from version 1.1, it is recommended
-    to use the Depend operator instead. The replacement method is as follows::
+    In order to ensure that operator A is executed before operator B, it is recommended to
+    insert the Depend operator between operators A and B. The usage method is as follows::
 
         a = A(x)                --->        a = A(x)
         b = B(y)                --->        y = Depend(y, a)
-        ControlDepend(a, b)     --->        b = B(y)
+                                --->        b = B(y)
 
     Inputs:
         - **value** (Tensor) - the real value to return for depend operator.
diff --git a/model_zoo/official/nlp/bert/src/adam.py b/model_zoo/official/nlp/bert/src/adam.py
index c7a952e2bb444ee64556a2293acad6f0d869d13f..f04dfb8080e9534af377334118ce47047446753f 100644
--- a/model_zoo/official/nlp/bert/src/adam.py
+++ b/model_zoo/official/nlp/bert/src/adam.py
@@ -115,8 +115,8 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
         op_sqrt = P.Sqrt()
         scatter_add = P.ScatterAdd(use_locking)
 
-        assign_m = F.assign(m, op_mul(beta1, m))
-        assign_v = F.assign(v, op_mul(beta2, v))
+        success = F.depend(success, F.assign(m, op_mul(beta1, m)))
+        success = F.depend(success, F.assign(v, op_mul(beta2, v)))
 
         grad_indices = gradient.indices
         grad_value = gradient.values
@@ -131,27 +131,18 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov,
 
         if use_nesterov:
             m_temp = next_m * _scaler_ten
-            assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
+            F.assign(m, op_mul(beta1, next_m))
             div_value = scatter_add(m,
                                     op_mul(grad_indices, _scaler_one),
                                     op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
             param_update = div_value / (op_sqrt(next_v) + eps)
-
-            m_recover = F.assign(m, m_temp / _scaler_ten)
-
-            F.control_depend(m_temp, assign_m_nesterov)
-            F.control_depend(assign_m_nesterov, div_value)
-            F.control_depend(param_update, m_recover)
+            F.assign(m, m_temp / _scaler_ten)
         else:
             param_update = next_m / (op_sqrt(next_v) + eps)
 
         lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
 
-        next_param = param - lr_t * param_update
-        F.control_depend(assign_m, next_m)
-        F.control_depend(assign_v, next_v)
-
         success = F.depend(success, F.assign(param, next_param))
         success = F.depend(success, F.assign(m, next_m))
         success = F.depend(success, F.assign(v, next_v))
diff --git a/model_zoo/official/nlp/gru/src/gru_for_train.py b/model_zoo/official/nlp/gru/src/gru_for_train.py
index 85eaada706739622f5e9ed9839d0dba69d75b06f..76ff0c70163c9bfe69e44872d54069949b09dae9 100644
--- a/model_zoo/official/nlp/gru/src/gru_for_train.py
+++ b/model_zoo/official/nlp/gru/src/gru_for_train.py
@@ -172,7 +172,6 @@ class GRUTrainOneStepWithLossScaleCell(nn.Cell):
         self.get_status = P.NPUGetFloatStatus()
         self.clear_before_grad = P.NPUClearFloatStatus()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = P.LessEqual()
         self.hyper_map = C.HyperMap()
diff --git a/model_zoo/research/cv/FaceDetection/src/network_define.py b/model_zoo/research/cv/FaceDetection/src/network_define.py
index d6d541768c24324a4831784ffba5fb9ab3192825..8673fb30f7eadd88e0addc2b502424151c169ff7 100644
--- a/model_zoo/research/cv/FaceDetection/src/network_define.py
+++ b/model_zoo/research/cv/FaceDetection/src/network_define.py
@@ -17,7 +17,7 @@
 import numpy as np
 import mindspore.nn as nn
 from mindspore.ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, \
-    LessEqual, ControlDepend
+    LessEqual
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore import Tensor
@@ -25,7 +25,7 @@
 from mindspore.context import ParallelMode
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-from mindspore.common.parameter import ParameterTuple
+from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common import dtype as mstype
 
@@ -69,7 +69,6 @@ class TrainOneStepWithLossScaleCell(nn.Cell):
         self.base = Tensor(1, mstype.float32)
         self.reducer_flag = False
         self.less_equal = LessEqual()
-        self.depend_parameter_use = ControlDepend(depend_mode=1)
         self.allreduce = P.AllReduce()
         self.parallel_mode = _get_parallel_mode()
         self.grad_reducer = None
diff --git a/model_zoo/research/nlp/ternarybert/src/cell_wrapper.py b/model_zoo/research/nlp/ternarybert/src/cell_wrapper.py
index 2c505b9db0401f08666782cf11c471914f5d774b..04906f434c1782b2f97639a7ec7aacf13217f9fa 100644
--- a/model_zoo/research/nlp/ternarybert/src/cell_wrapper.py
+++ b/model_zoo/research/nlp/ternarybert/src/cell_wrapper.py
@@ -341,7 +341,6 @@ class BertTrainWithLossScaleCell(nn.Cell):
         self.get_status = P.NPUGetFloatStatus()
         self.clear_before_grad = P.NPUClearFloatStatus()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = P.LessEqual()
         self.hyper_map = C.HyperMap()
@@ -378,27 +377,22 @@ class BertTrainWithLossScaleCell(nn.Cell):
                   sens=None):
         """Defines the computation performed."""
         weights = self.weights
-        saved = ()
         for i in range(self.length):
-            saved = saved + (F.assign(self.saved_params[i], weights[i]),)
-        assign_embedding = ()
+            F.assign(self.saved_params[i], weights[i])
+
         for i in range(self.quant_embedding_list_length):
             quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
-            assign_embedding = assign_embedding + (F.assign(weights[self.quant_embedding_list[i]], quant_embedding),)
-            F.control_depend(saved, assign_embedding[i])
-        assign_weight = ()
+            F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
+
         for i in range(self.quant_weight_list_length):
             quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
-            assign_weight = assign_weight + (F.assign(weights[self.quant_weight_list[i]], quant_weight),)
-            F.control_depend(saved, assign_weight[i])
-        for i in range(self.quant_embedding_list_length):
-            F.control_depend(assign_embedding[i], input_ids)
-        for i in range(self.quant_weight_list_length):
-            F.control_depend(assign_weight[i], input_ids)
+            F.assign(weights[self.quant_weight_list[i]], quant_weight)
+
         if sens is None:
             scaling_sens = self.loss_scale
         else:
             scaling_sens = sens
+
         # alloc status and clear should be right before grad operation
         init = self.alloc_status()
         self.clear_before_grad(init)
@@ -408,15 +402,15 @@ class BertTrainWithLossScaleCell(nn.Cell):
                                                  label_ids,
                                                  self.cast(scaling_sens,
                                                            mstype.float32))
-        F.control_depend(input_ids, grads)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
         grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
-        restore = ()
+
         for i in range(self.length):
-            restore = restore + (F.assign(weights[i], self.saved_params[i]),)
-            F.control_depend(grads, restore[i])
+            param = F.depend(self.saved_params[i], grads)
+            F.assign(weights[i], param)
+
         self.get_status(init)
         flag_sum = self.reduce_sum(init, (0,))
         if self.is_distributed:
@@ -432,8 +426,6 @@ class BertTrainWithLossScaleCell(nn.Cell):
             succ = False
         else:
             succ = self.optimizer(grads)
-        for i in range(self.length):
-            F.control_depend(restore[i], succ)
         return succ
 
 
@@ -492,38 +484,30 @@ class BertTrainCell(nn.Cell):
                   label_ids):
         """Defines the computation performed."""
         weights = self.weights
-        saved = ()
         for i in range(self.length):
-            saved = saved + (F.assign(self.saved_params[i], weights[i]),)
-        assign_embedding = ()
+            F.assign(self.saved_params[i], weights[i])
+
         for i in range(self.quant_embedding_list_length):
             quant_embedding = self.quantize_embedding(weights[self.quant_embedding_list[i]])
-            assign_embedding = assign_embedding + (F.assign(weights[self.quant_embedding_list[i]], quant_embedding),)
-            F.control_depend(saved, assign_embedding[i])
-        assign_weight = ()
+            F.assign(weights[self.quant_embedding_list[i]], quant_embedding)
+
         for i in range(self.quant_weight_list_length):
             quant_weight = self.quantize_weight(weights[self.quant_weight_list[i]])
-            assign_weight = assign_weight + (F.assign(weights[self.quant_weight_list[i]], quant_weight),)
-            F.control_depend(saved, assign_weight[i])
-        for i in range(self.quant_embedding_list_length):
-            F.control_depend(assign_embedding[i], input_ids)
-        for i in range(self.quant_weight_list_length):
-            F.control_depend(assign_weight[i], input_ids)
+            F.assign(weights[self.quant_weight_list[i]], quant_weight)
+
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
                                                  label_ids,
                                                  self.cast(F.tuple_to_array((self.sens,)),
                                                            mstype.float32))
-        F.control_depend(input_ids, grads)
         # apply grad reducer on grads
         grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(clip_grad, self.clip_type, self.clip_value), grads)
-        restore = ()
+
         for i in range(self.length):
-            restore = restore + (F.assign(weights[i], self.saved_params[i]),)
-            F.control_depend(grads, restore[i])
+            param = F.depend(self.saved_params[i], grads)
+            F.assign(weights[i], param)
+
         succ = self.optimizer(grads)
-        for i in range(self.length):
-            F.control_depend(restore[i], succ)
         return succ
diff --git a/tests/st/auto_monad/test_auto_monad_gpu.py b/tests/st/auto_monad/test_auto_monad_gpu.py
index 40f54a3d8b16613a8405890acf9c8c35c22078cd..685d686128ad0d273d46eb801458555b0645944e 100644
--- a/tests/st/auto_monad/test_auto_monad_gpu.py
+++ b/tests/st/auto_monad/test_auto_monad_gpu.py
@@ -399,7 +399,6 @@ class MixControlNet(Cell):
                            kernel_size=1, stride=1, has_bias=False,
                            weight_init='ones', pad_mode='same')
         self.bn = BatchNorm2d(num_features=in_channel)
-        self.controldepend = P.ControlDepend()
         self.assignadd = P.AssignAdd()
         self.assign = P.Assign()
         self.relu = ReLU()
@@ -428,9 +427,8 @@ class MixControlNet(Cell):
         if x < 20:
             out = self.biasadd(out, self.bias)
             if x % 2 == 0:
+                self.assignadd(self.bias, self.value)
                 out = self.biasadd(out, self.bias)
-                assign = self.assignadd(self.bias, self.value)
-                self.controldepend(assign, out)
             out = self.bn(out)
         else:
             out = self.conv(out)
diff --git a/tests/st/graph_kernel/model/test_graph_parallel.py b/tests/st/graph_kernel/model/test_graph_parallel.py
index 4f2fa89a62a4194a56423cbafc7484cd4703665e..130dbc0fb020b9eab1376f4dc1923690694ef77c 100644
--- a/tests/st/graph_kernel/model/test_graph_parallel.py
+++ b/tests/st/graph_kernel/model/test_graph_parallel.py
@@ -33,14 +33,6 @@ def reduce_graph(shape, reduce_axis):
         gb.emit('ReduceSum', a3, 'C', attrs={'reduce_axis': reduce_axis})
     return gb.get()[0]
 
-def control_graph(shape):
-    gb = model.GraphBuilder()
-    with gb.graph_scope('control') as _:
-        a1 = gb.tensor(shape, 'float32')
-        a2 = gb.emit('Abs', a1)
-        gb.emit('ControlDepend', a2)
-    return gb.get()[0]
-
 def block_fusion(graphs):
     gain = model.parallel_estimate(graphs)
     print("fusion = {}, bottleneck = {}, gain = {}".format(gain.fusion_type, gain.bottleneck, gain.gain))
@@ -51,4 +43,3 @@ if __name__ == "__main__":
     assert block_fusion([reduce_graph([1024, 1024], [1]), injective_graph([24, 1024])])
     assert not block_fusion([reduce_graph([1024, 1024], [1]), injective_graph([50, 1024])])
     assert not block_fusion([reduce_graph([1024, 1024], [0, 1]), injective_graph([1024, 1024])])
-    assert block_fusion([control_graph([20, 128]), injective_graph([40, 1024])])