diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index b69b83f257607b8a40a339f6c762666292187bab..fb53fedb01bcd8659d8fbc756c9faf4adc456dda 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -1645,8 +1645,36 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
   return nullptr;
 }
 
+std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node) {
+  FuncGraphManagerPtr manager = node->func_graph()->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  AnfNodeIndexSet node_set = manager->node_users()[node];
+  for (auto &node_pair : node_set) {
+    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
+    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
+      continue;
+    }
+    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(prim_anf_node);
+    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(node_prim);
+    if ((node_prim->name() == DEPEND && node_pair.second != 1) || node_prim->name() == RESHAPE) {
+      continue;
+    }
+    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
+      auto layout = GetInputLayoutFromCNode(node_pair);
+      return std::make_shared<TensorLayout>(layout);
+    }
+  }
+  return nullptr;
+}
+
 std::shared_ptr<TensorLayout> CreateParameterLayout(const AnfNodePtr &node) {
   // Create DataParallel tensor layout for parameter(support WideDeep).
+  auto next_layout = FindParameterNextLayout(node);
+  if (next_layout != nullptr) {
+    return next_layout;
+  }
   CheckGlobalDeviceManager();
   int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
   TensorLayout input_tensor_layout;
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h
index ca049d1704dcbee258458b52bb8e8b91f1cd0ceb..a9a4d941b254ad973341f7544bd82a430cc0ec27 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.h
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h
@@ -156,6 +156,8 @@ using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNode
+std::shared_ptr<TensorLayout> FindParameterNextLayout(const AnfNodePtr &node);
+
 ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)(const CNodePtr &));
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/tests/ut/python/parallel/test_auto_parallel_reshape.py b/tests/ut/python/parallel/test_auto_parallel_reshape.py
index bde208456215564e86f72cffa3794f7d79fb7ab8..a54660bf6cbdc8b784ca11b1666dd96d15594a51 100644
--- a/tests/ut/python/parallel/test_auto_parallel_reshape.py
+++ b/tests/ut/python/parallel/test_auto_parallel_reshape.py
@@ -292,3 +292,25 @@ def test_reshape_auto_6():
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
     net.set_auto_parallel()
     _executor.compile(net, x, y)
+
+def test_reshape_auto_7():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.reshape = P.Reshape()
+            self.mul = P.Mul().set_strategy(((1, 2, 4), (2, 4)))
+            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
+
+        def construct(self, x):
+            weight = self.reshape(self.mul_weight, (1, 128, 96))
+            out = self.mul(weight, self.mul_weight)
+            return out
+
+    size = 8
+    context.set_auto_parallel_context(device_num=size, global_rank=0)
+    x = Tensor(np.ones([128, 28]), dtype=ms.float32)
+
+    net = GradWrap(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    net.set_auto_parallel()
+    _executor.compile(net, x)