diff --git a/mindspore/_extends/builtin_operations.py b/mindspore/_extends/builtin_operations.py index 7b5f34942e75f86f79a2d95168952976edfabd54..6bd382c1b672599b2f1231c83c1e40f843b50105 100644 --- a/mindspore/_extends/builtin_operations.py +++ b/mindspore/_extends/builtin_operations.py @@ -113,6 +113,24 @@ def bool_or(x, y): """Implement `bool_or`.""" return x or y +def vm_compare(*args): + """Implement `vm_compare` for tensor.""" + obj_str = args[-1] + if obj_str == "shape": + fn = getattr(args[0].asnumpy(), obj_str) + return fn + if len(args) == 2: + fn = getattr(args[0].asnumpy(), obj_str) + return Tensor(fn()) + if isinstance(args[0], Tensor): + fn = getattr(args[0].asnumpy(), obj_str) + y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1] + else: + obj_str = "__r" + obj_str[2:] + fn = getattr(args[1].asnumpy(), obj_str) + y = args[0] + return Tensor(np.array(fn(y))) + def make_list(*xs): """Implement `make_list`.""" diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.cc b/mindspore/ccsrc/pipeline/parse/data_converter.cc index c36ef2452d42a78857f688c4e61c94c1ab853164..20f7c0c9cedd0ced26f2cc8a341de5ec1b5f4507 100644 --- a/mindspore/ccsrc/pipeline/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/parse/data_converter.cc @@ -41,6 +41,35 @@ using TensorPtr = mindspore::tensor::TensorPtr; using MetaTensor = mindspore::tensor::MetaTensor; using MetaTensorPtr = mindspore::tensor::MetaTensorPtr; +FuncGraphPtr ConvertToBpropCut(const py::object &obj) { + std::vector results = data_converter::GetObjKey(obj); + std::string obj_key = results[0]; + py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME); + + auto bprop_graph = std::make_shared(); + std::vector outputs; + + auto fake_bprop = std::make_shared("bprop_cut", py::object()); + fake_bprop->set_hook(bprop_func); + (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true)); + outputs.push_back(NewValueNode(fake_bprop)); + + py::object code_obj = py::getattr(bprop_func, "__code__"); + size_t inputs_num = py::cast(py::getattr(code_obj, "co_argcount")) - 3; + for (size_t i = 0; i < inputs_num; ++i) { + auto param = bprop_graph->add_parameter(); + outputs.push_back(param); + } + auto p1 = bprop_graph->add_parameter(); + auto p2 = bprop_graph->add_parameter(); + outputs.push_back(p1); + outputs.push_back(p2); + + bprop_graph->set_output(bprop_graph->NewCNode(outputs)); + data_converter::SetObjGraphValue(obj_key, bprop_graph); + return bprop_graph; +} + namespace { bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) { MS_LOG(DEBUG) << "Converting python tuple"; @@ -231,35 +260,6 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) { return true; } -FuncGraphPtr ConvertToBpropCut(py::object obj) { - std::vector results = data_converter::GetObjKey(obj); - std::string obj_key = results[0]; - py::function bprop_func = py::getattr(obj, "bprop"); - - FuncGraphPtr bprop_graph = std::make_shared(); - std::vector outputs; - - auto fake_bprop = std::make_shared("bprop_cut", py::object()); - fake_bprop->set_hook(bprop_func); - (void)fake_bprop->AddAttr("bprop", MakeValue(true)); - outputs.push_back(NewValueNode(fake_bprop)); - - py::object code_obj = py::getattr(bprop_func, "__code__"); - size_t inputs_num = py::cast(py::getattr(code_obj, "co_argcount")) - 3; - for (size_t i = 0; i < inputs_num; ++i) { - auto param = bprop_graph->add_parameter(); - outputs.push_back(param); - } - auto p1 = bprop_graph->add_parameter(); - auto p2 = bprop_graph->add_parameter(); - outputs.push_back(p1); - outputs.push_back(p2); - - bprop_graph->set_output(bprop_graph->NewCNode(outputs)); - data_converter::SetObjGraphValue(obj_key, bprop_graph); - return bprop_graph; -} - bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { FuncGraphPtr func_graph = ConvertToFuncGraph(obj); if (func_graph == nullptr) { @@ -267,7 +267,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { return false; } // if the cell object has specified bprop, it has user-defined bprop function parse and record it - if (py::hasattr(obj, "bprop")) { + if (py::hasattr(obj, CUSTOM_BPROP_NAME)) { FuncGraphPtr bprop_graph = nullptr; bool enable_bprop_debug = py::cast(py::getattr(obj, "bprop_debug")); if (enable_bprop_debug) { @@ -276,7 +276,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); } if (bprop_graph != nullptr) { - (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph))); + (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true); } diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.h b/mindspore/ccsrc/pipeline/parse/data_converter.h index a8918fa60c12d5ca47d7545ba48333e6804a2961..0165b5536318138ce8eb9196935bd7b648c94a87 100644 --- a/mindspore/ccsrc/pipeline/parse/data_converter.h +++ b/mindspore/ccsrc/pipeline/parse/data_converter.h @@ -51,6 +51,7 @@ void ClearObjectCache(); } // namespace data_converter ClassPtr ParseDataClass(const py::object &cls_obj); +FuncGraphPtr ConvertToBpropCut(const py::object &obj); void CleanDataClassToClassMap(); diff --git a/mindspore/ccsrc/pipeline/parse/parse_base.h b/mindspore/ccsrc/pipeline/parse/parse_base.h index a73b3b0c81d604622b1c89e5a287fb5b2baee754..4961ab78c0fa87cce3ce0ffda2df77b36a902ee6 100644 --- a/mindspore/ccsrc/pipeline/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/parse/parse_base.h @@ -109,6 +109,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags"; // define the parse constant const int MAX_COMPARISON_OPS_SUPPORTED = 1; +const char CUSTOM_BPROP_NAME[] = "bprop"; // define the Namespace name const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace diff --git a/mindspore/ccsrc/pynative/base.h b/mindspore/ccsrc/pynative/base.h index 6d7c83ca57d6f7c0cd2c9f9910b820b04ac714f0..4e86f0493f4c5acb9d85cf07f23c428a045960bf 100644 --- a/mindspore/ccsrc/pynative/base.h +++ b/mindspore/ccsrc/pynative/base.h @@ -45,7 +45,7 @@ enum PynativeStatusCode { PYNATIVE_UNKNOWN_STATE = 0XFF }; -enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_INPUT_MASK, PY_ARGS_NUM }; +enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM }; struct OpExecInfo { PrimitivePyPtr py_primitive; diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc index c1cad930c03853074d4f48f703f2f21992eae5d8..30fb4ef1afd881280ff3e2e49f1b65e8c8a02196 100644 --- a/mindspore/ccsrc/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pynative/pynative_execute.cc @@ -110,9 +110,15 @@ py::object GetTupleObj(const py::object &obj) { return obj_tuple; } -void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) { +py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) { auto &py_args = *out_args; + py::tuple input_mask(args.size()); for (size_t i = 0; i < args.size(); ++i) { + if (py::hasattr(args[i], "__parameter__")) { + input_mask[i] = true; + } else { + input_mask[i] = false; + } py_args[i] = GetTupleObj(args[i]); } auto signature = prim->signatures(); @@ -121,7 +127,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple * [](const Signature &sig) { return sig.dtype; }); int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); if (dtypes.size() == 0 || static_cast(dtypes.size()) == empty_dtype_count) { - return; + return input_mask; } std::map> type_indexs; for (size_t i = 0; i < dtypes.size(); ++i) { @@ -160,6 +166,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple * continue; } } + return input_mask; } void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) { @@ -167,7 +174,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn AbstractBasePtrList args_spec_list; for (size_t i = 0; i < size; i++) { ValuePtr input_value = PyAttrValue(py_args[i]); - if (input_value->isa()) { + if (!py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa()) { args_spec_list.emplace_back(abstract::FromValueInside(input_value, true)); } else { args_spec_list.emplace_back(abstract::FromValueInside(input_value, false)); @@ -179,7 +186,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { if (args.size() != PY_ARGS_NUM) { - MS_LOG(ERROR) << "Four args are needed by RunOp"; + MS_LOG(ERROR) << "Three args are needed by RunOp"; return nullptr; } auto op_exec_info = std::make_shared(); @@ -195,14 +202,13 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { size_t input_num = a.size(); op_exec_info->op_inputs = py::tuple(input_num); - ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs); + op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs); // use python infer method if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get()); } op_exec_info->py_primitive = prim; op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); - op_exec_info->inputs_mask = args[PY_INPUT_MASK]; if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) { MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask"; return nullptr; @@ -488,14 +494,14 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn return result; } -AnfNodePtr PynativeExecutor::MakeCNode(const py::args &args, const py::tuple &out) { +AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) { if (!grad_flag_ || graph_info_map_.size() == 0) { return nullptr; } std::vector inputs; - auto prim = py::cast(args[PY_PRIM]); + auto prim = op_exec_info->py_primitive; inputs.push_back(NewValueNode(prim)); - py::tuple op_masks = args[PY_INPUT_MASK]; + py::tuple op_masks = op_exec_info->inputs_mask; py::list op_args = args[PY_INPUTS]; AbstractBasePtrList args_spec_list; for (size_t i = 0; i < op_args.size(); i++) { @@ -584,7 +590,7 @@ py::tuple RunOp(const py::args &args) { return err_ret; } - auto node = PynativeExecutor::GetInstance()->MakeCNode(args, result); + auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result); if (node != nullptr) { node->set_abstract(op_exec_info->abstract); MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString(); @@ -705,7 +711,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c } cell_graph_map_[cell_id] = curr_g_; auto out_id = GetId(out); - if (!graph_info_map_[curr_g_].obj_node_map.count(out_id)) { + if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) { // cell construct return x, y if (py::isinstance(out)) { std::vector args; @@ -727,12 +733,26 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c } } - auto output_node = GetObjNode(out); + AnfNodePtr output_node; + if (graph_info_map_[curr_g_].param_map.count(out_id)) { + output_node = graph_info_map_[curr_g_].param_map[out_id]; + } else { + output_node = GetObjNode(out); + } curr_g_->set_output(output_node); std::vector inputs; inputs.push_back(NewValueNode(curr_g_)); MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString(); resource_->manager()->AddFuncGraph(curr_g_); + // custom bprop debug + if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { + MS_LOG(DEBUG) << "Use cell custom bprop function."; + FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell); + if (bprop_graph != nullptr) { + (void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); + (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_))); + } + } auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_); if (curr_g_ != top_g_) { Popp(); diff --git a/mindspore/ccsrc/pynative/pynative_execute.h b/mindspore/ccsrc/pynative/pynative_execute.h index 19ae9965c847c0c225adb3fda82dde9b7180ad85..a0e8b448f4d6811d93598b030644e72b24b8baa7 100644 --- a/mindspore/ccsrc/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pynative/pynative_execute.h @@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat py::tuple RunOp(const py::args &args); -void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args); +py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args); void ClearPyNativeSession(); @@ -83,7 +83,7 @@ class PynativeExecutor : public std::enable_shared_from_this { void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) { graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index); } - AnfNodePtr MakeCNode(const py::args &args, const py::tuple &out); + AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out); py::object Run(const py::tuple &args, const py::object &phase); void Pushp(); diff --git a/mindspore/common/_register_for_tensor.py b/mindspore/common/_register_for_tensor.py index da183d9549183f6ef8d6032507119c9f3fc9fcab..8ba2ff7cc4ea62621874b9c25e53605e175966e1 100644 --- a/mindspore/common/_register_for_tensor.py +++ b/mindspore/common/_register_for_tensor.py @@ -16,6 +16,7 @@ """Registry the relation.""" from collections import UserDict +from .. import context class Registry(UserDict): @@ -27,9 +28,16 @@ class Registry(UserDict): def get(self, obj_str): """Get the value by str.""" - if isinstance(obj_str, str): + if not isinstance(obj_str, str): + raise TypeError("key for tensor registry must be string.") + if context.get_context("enable_ge"): + def wrap(*args): + new_args = list(args) + new_args.append(obj_str) + return self["vm_compare"](*new_args) + obj = wrap + else: obj = self[obj_str] return obj - tensor_operator_registry = Registry() diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index c806ec00920ca15c2b0652e428d595db07fa69a4..b7452d8165af7b83b4c51889cc9bd2e6a2bd70b6 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -19,7 +19,6 @@ from .._c_expression import Tensor as Tensor_ from .._c_expression import MetaTensor from .._checkparam import check_type, check_typename from . import dtype as mstype -from .. import context from ._register_for_tensor import tensor_operator_registry __all__ = ['Tensor', 'MetaTensor'] @@ -76,17 +75,19 @@ class Tensor(Tensor_): return out def __eq__(self, other): - if not isinstance(other, Tensor): + if not isinstance(other, (int, float, Tensor)): return False - # The GE backend don't support single `Equal` operator execution. # bool type is not supported for `Equal` operator in backend. - if context.get_context("enable_ge") or self.dtype == mstype.bool_ or other.dtype == mstype.bool_: + if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_): return Tensor(np.array(self.asnumpy() == other.asnumpy())) return tensor_operator_registry.get('__eq__')(self, other) def __ne__(self, other): - if not isinstance(other, Tensor): + if not isinstance(other, (int, float, Tensor)): return True + # bool type is not supported for `NotEqual` operator in backend. + if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_): + return Tensor(np.array(self.asnumpy() != other.asnumpy())) return tensor_operator_registry.get('__ne__')(self, other) def __hash__(self): @@ -105,7 +106,7 @@ class Tensor(Tensor_): return out def __radd__(self, other): - out = tensor_operator_registry.get('__add__')(other, self) + out = tensor_operator_registry.get('__add__')(self, other) return out def __imul__(self, other): @@ -113,15 +114,15 @@ class Tensor(Tensor_): return out def __rmul__(self, other): - out = tensor_operator_registry.get('__mul__')(other, self) + out = tensor_operator_registry.get('__mul__')(self, other) return out def __truediv__(self, other): - out = tensor_operator_registry.get('__div__')(self, other) + out = tensor_operator_registry.get('__truediv__')(self, other) return out def __rtruediv__(self, other): - out = tensor_operator_registry.get('__div__')(other, self) + out = tensor_operator_registry.get('__truediv__')(other, self) return out def __sub__(self, other): @@ -160,7 +161,7 @@ class Tensor(Tensor_): return out def __len__(self): - out = tensor_operator_registry.get('__shape__')(self) + out = tensor_operator_registry.get('shape')(self) if not out: return 1 return out[0] diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 0377e9dcc38c9df5d33c2221647ad6cb05e84c13..65c1ce9548637ec84ff88e8aa1b72378a3e64a35 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -819,4 +819,4 @@ class Cell: """ self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")") - self._enable_hook = True + self.enable_hook = True diff --git a/mindspore/nn/layer/container.py b/mindspore/nn/layer/container.py index 392448bb65b3f3e6bd21ab839ac3a61cf3abeede..48871401bf73cba651160c46250a6f7068d7bb6b 100644 --- a/mindspore/nn/layer/container.py +++ b/mindspore/nn/layer/container.py @@ -140,6 +140,11 @@ class SequentialCell(Cell): def __len__(self): return len(self._cells) + def set_grad(self, flag=True): + self.requires_grad = flag + for cell in self._cells.values(): + cell.set_grad(flag) + def construct(self, input_data): for cell in self.cell_list: input_data = cell(input_data) @@ -246,5 +251,10 @@ class CellList(_CellListBase, Cell): self._cells[str(len(self))] = cell return self + def set_grad(self, flag=True): + self.requires_grad = flag + for cell in self._cells.values(): + cell.set_grad(flag) + def construct(self, *inputs): raise NotImplementedError diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index 8b774b80bb0b24db838ee6e97536e37bbe9031c8..63e83a126cf4643b94fa3fc1c895bd52d20a27ea 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -112,7 +112,7 @@ class GradOperation(GradOperation_): grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param) if self.grad_fn is None or self.fn != fn: if self.get_by_list: - if context.get_context("mode") == context.GRAPH_MODE or fn.bprop_debug: + if context.get_context("mode") == context.GRAPH_MODE: @ms_function(obj=fn) def after_grad(*args): return grad_(fn, weights)(*args) diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 4fd7a9d9d272080ebcc54c045e0b55750bd2f379..5637274bfb1c883660538e6ccf0e41cd153a39c7 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -21,6 +21,7 @@ from mindspore.common._register_for_tensor import tensor_operator_registry from .primitive import Primitive from . import operations as P from .operations import _grad_ops +from .._extends import builtin_operations as BP typeof = Primitive('typeof') hastype = Primitive('hastype') @@ -155,7 +156,7 @@ stop_gradient = Primitive("stop_gradient") tensor_operator_registry.register('__add__', tensor_add) tensor_operator_registry.register('__sub__', tensor_sub) tensor_operator_registry.register('__mul__', tensor_mul) -tensor_operator_registry.register('__div__', tensor_div) +tensor_operator_registry.register('__truediv__', tensor_div) #ms cannot support Tensor(True) compare tensor_operator_registry.register('__eq__', equal) tensor_operator_registry.register('__ne__', not_equal) @@ -164,4 +165,6 @@ tensor_operator_registry.register('__lt__', tensor_lt) tensor_operator_registry.register('__le__', tensor_le) tensor_operator_registry.register('__gt__', tensor_gt) tensor_operator_registry.register('__ge__', tensor_ge) -tensor_operator_registry.register('__shape__', shape) +tensor_operator_registry.register('shape', shape) +#support GE backend for no compare operators +tensor_operator_registry.register('vm_compare', BP.vm_compare) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index adc2911742e534add90d5777c9baa716705709aa..201583abe0fe883e4a6bccd4356d5e94797ed6d9 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -933,6 +933,8 @@ class TupleToArray(PrimitiveWithInfer): args = list() if isinstance(x, range): args.append(tuple(x)) + else: + args.append(x) return _run_op(self, self.name, args) diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 8d2a675b90087737e2a809c4478b5bfcb8e61483..61da7587a1f74c2acc0c66fcdf82f00c46009d67 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -341,13 +341,7 @@ def constexpr(fn=None, get_instance=True, name=None): @_wrap_func def _run_op(obj, op_name, args): """Single op execution function supported by ge in PyNative mode.""" - op_mask = [0] * len(args) - op_inputs = [] - for i, arg in enumerate(args): - if hasattr(arg, '__parameter__'): - op_mask[i] = 1 - op_inputs.append(arg) - output = real_run_op(obj, op_name, args, tuple(op_mask)) + output = real_run_op(obj, op_name, args) if not output: raise RuntimeError("Pynative run op %s failed!" % op_name) if len(output) == 1: diff --git a/tests/ut/cpp/pynative/pynative_execute_test.cc b/tests/ut/cpp/pynative/pynative_execute_test.cc index 34184516c2a7090358f63ba9b02637f1737a7887..88f82d996b1cc200fa9a4ca86fcf56c4e1a96da9 100644 --- a/tests/ut/cpp/pynative/pynative_execute_test.cc +++ b/tests/ut/cpp/pynative/pynative_execute_test.cc @@ -63,8 +63,7 @@ OpExecInfoPtr ConstructOpExecInfo() { auto conv_obj = prim::GetPythonOps("conv2d_prim", "gtest_input.pynative"); py::none py_none; - py::tuple op_mask = py::make_tuple(0, 1); - return GenerateOpExecInfo(py::make_tuple(conv_obj, op_name, op_inputs, op_mask)); + return GenerateOpExecInfo(py::make_tuple(conv_obj, op_name, op_inputs)); } TEST_F(TestPynativeExecute, TestRunOpInVM) { @@ -79,7 +78,7 @@ TEST_F(TestPynativeExecute, TestRunOp) { py::none py_none; auto op_exec_info_ptr = ConstructOpExecInfo(); py::tuple outputs = pynative::RunOp(py::make_tuple(op_exec_info_ptr->py_primitive, op_exec_info_ptr->op_name, - op_exec_info_ptr->op_inputs, op_exec_info_ptr->inputs_mask)); + op_exec_info_ptr->op_inputs)); if (outputs.size() == 0) { FAIL(); } else { diff --git a/tests/ut/python/ir/test_tensor.py b/tests/ut/python/ir/test_tensor.py index 43e7d4956a22bc295cb6c39f0369bb1d29f978f1..ff0a5c971f3bd35632bffb232250a0cf26a4ab12 100644 --- a/tests/ut/python/ir/test_tensor.py +++ b/tests/ut/python/ir/test_tensor.py @@ -452,5 +452,5 @@ def test_tensor_operation(): assert np.all(res.asnumpy() == np.ones((3, 3)) * 2) res = 8 / x assert np.all(res.asnumpy() == np.ones((3, 3)) * 2) - with pytest.raises(TypeError): + with pytest.raises(ValueError): res = x * (2, 3) diff --git a/tests/ut/python/pynative_mode/test_hook.py b/tests/ut/python/pynative_mode/test_hook.py index 35062336c4a0d1379ae7af5347469518daef8f49..023f039a97405c71c486f5327afbcb9168392bb7 100644 --- a/tests/ut/python/pynative_mode/test_hook.py +++ b/tests/ut/python/pynative_mode/test_hook.py @@ -8,6 +8,9 @@ from mindspore.nn import WithLossCell, Momentum from mindspore.ops import composite as C context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") +cell_hook_done = False +var_hook_done = False +cell_bprop_done = False def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): @@ -32,15 +35,35 @@ def weight_variable(): def cell_hook_function(cell_id, grad_input, grad_output): print(cell_id) + global cell_hook_done + cell_hook_done = True assert (grad_output[0].asnumpy().shape == (32, 6, 14, 14)) assert (grad_input[0].asnumpy().shape == (32, 16, 10, 10)) def var_hook_function(grad_out): print("grad:", grad_out) + global var_hook_done + var_hook_done = True assert (grad_out[0].asnumpy().shape == (32, 120)) +class Block(nn.Cell): + def __init__(self): + super(Block, self).__init__() + self.relu = nn.ReLU() + + def construct(self, x): + x = self.relu(x) + return x + + def bprop(self, x, out, dout): + global cell_bprop_done + cell_bprop_done = True + grad = out.asnumpy() * dout.asnumpy() + grad = Tensor(grad) + return (grad,) + class LeNet5(nn.Cell): """ Lenet network @@ -59,6 +82,7 @@ class LeNet5(nn.Cell): self.conv1 = conv(1, 6, 5) self.conv2 = conv(6, 16, 5) self.conv2.register_backward_hook(cell_hook_function) + self.block = Block() self.fc1 = fc_with_initialize(16 * 5 * 5, 120) self.fc2 = fc_with_initialize(120, 84) self.fc3 = fc_with_initialize(84, self.num_class) @@ -72,7 +96,7 @@ class LeNet5(nn.Cell): x = self.relu(x) x = self.max_pool2d(x) x = self.conv2(x) - x = self.relu(x) + x = self.block(x) x = self.max_pool2d(x) x = self.reshape(x, (self.batch_size, -1)) x = self.fc1(x) @@ -110,6 +134,9 @@ def test_hook(): loss_output = criterion(output, label) grads = train_network(input_data, label) success = optimizer(grads) + assert cell_hook_done + assert var_hook_done + assert cell_bprop_done print(loss_output.asnumpy().shape)