diff --git a/mindspore/ccsrc/pynative/forward/pyboost/template/pyboost_api_h.tpl b/mindspore/ccsrc/pynative/forward/pyboost/template/pyboost_api_h.tpl index b0de3ad1427a137c33f1a3e7d9cb9b91369f3476..fee6a08796f5210dffb0233a0dbb2a4c6875a14a 100644 --- a/mindspore/ccsrc/pynative/forward/pyboost/template/pyboost_api_h.tpl +++ b/mindspore/ccsrc/pynative/forward/pyboost/template/pyboost_api_h.tpl @@ -15,7 +15,9 @@ */ #include "pynative/utils/pynative_utils.h" +#include "pynative/forward/pyboost/converter.h" #include "pybind11/pybind11.h" +#include "op_def/auto_generate/gen_ops_def.h" #include namespace py = pybind11; diff --git a/mindspore/python/mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py b/mindspore/python/mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py index 32adba66afb68905d14d7a9fda267a3ae793d668..a9f2d4007b94bda75112228eacde7b0303c1dadd 100644 --- a/mindspore/python/mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +++ b/mindspore/python/mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py @@ -84,27 +84,7 @@ class TensorFuncRegCppGenerator(BaseGenerator): 'trace::CapturePy(parse_args.arg_list_, mindspore::prim::kPrim${class_name}, &res);\n' 'return res;\n' ) - ops_data = safe_load_yaml_from_dir(os.path.join(K.WORK_DIR, K.PARALLEL_OP_YAML_PATH)) - self.layout_infer_ops = list(ops_data.keys()) - self.pyboost_with_layout_infer_template = Template( - '${arg_handler_processor}\n' - 'MS_LOG(INFO) << "Call Tensor${class_name} with LayoutInfer";\n' - '// Construct py_args including self at the correct position\n' - 'py::list py_args;\n' - 'for (size_t i = 0; i < parse_args.arg_list_.size(); ++i) {\n' - ' py_args.append(parse_args.arg_list_[i]);\n' - '}\n' - 'auto res = mindspore::pynative::WithLayoutInfer(\n' - ' mindspore::prim::kPrim${class_name},\n' - ' [](const PrimitivePtr &prim, const std::vector &source_types${lambda_params}) {\n' - ' return mindspore::pynative::${pyboost_function}(prim, source_types${lambda_args});\n' - ' },\n' - ' py_args.ptr(),\n' - ' mindspore::prim::kPrim${class_name}, parse_args.src_types_${convert_args_comma}\n' - ');\n' - 'trace::CapturePy(parse_args.arg_list_, mindspore::prim::kPrim${class_name}, &res);\n' - 'return res;\n' - ) + self._init_with_layout_infer() self.callback_python_template = Template( 'py::object self_new = py::reinterpret_borrow(self);\n' @@ -131,6 +111,45 @@ class TensorFuncRegCppGenerator(BaseGenerator): "(PyObject* self, PyObject* py_args, PyObject* py_kwargs);\n" ) + def _init_with_layout_infer(self): + """ + Generates C++ tensor function WithLayoutInfer code for distributed ops. 
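+
+        Two template variants are kept in self.pyboost_with_layout_infer_template:
+        'default', which forwards the converted arguments produced by the parser, and
+        'without_parse', which only assembles the py_args list and is selected by ops
+        whose YAML entry sets template_name (currently Reshape). The
+        WithLayoutInfer${suffix} entry point is chosen through the optional
+        infer_layout_suffix field of the same YAML entry.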
+ """ + + self.layout_infer_ops = safe_load_yaml_from_dir(os.path.join(K.WORK_DIR, K.PARALLEL_OP_YAML_PATH)) + self.pyboost_with_layout_infer_template = dict() + prepare_pyargs_code = ('MS_LOG(INFO) << "Call Tensor${class_name} with LayoutInfer";\n' + '// Construct py_args including self at the correct position\n' + 'py::list py_args;\n' + 'for (size_t i = 0; i < parse_args.arg_list_.size(); ++i) {\n' + ' py_args.append(parse_args.arg_list_[i]);\n' + '}\n') + self.pyboost_with_layout_infer_template['default'] = Template( + '${arg_handler_processor}\n' + + prepare_pyargs_code + + 'auto res = mindspore::pynative::WithLayoutInfer${suffix}(\n' + ' mindspore::prim::kPrim${class_name},\n' + ' [](const PrimitivePtr &prim, const std::vector &source_types${lambda_params}) {\n' + ' return mindspore::pynative::${pyboost_function}(prim, source_types${lambda_args});\n' + ' },\n' + ' py_args.ptr(),\n' + ' mindspore::prim::kPrim${class_name}, parse_args.src_types_${convert_args_comma}\n' + ');\n' + 'trace::CapturePy(parse_args.arg_list_, mindspore::prim::kPrim${class_name}, &res);\n' + 'return res;\n' + ) + self.pyboost_with_layout_infer_template['without_parse'] = Template( + prepare_pyargs_code + + 'auto res = mindspore::pynative::WithLayoutInfer${suffix}(\n' + ' mindspore::prim::kPrim${class_name},\n' + ' [](const PrimitivePtr &prim, const std::vector &source_types${lambda_params}) {\n' + ' return mindspore::pynative::${pyboost_function}(prim, source_types${lambda_args});\n' + ' },\n' + ' py_args.ptr());\n' + 'trace::CapturePy(parse_args.arg_list_, mindspore::prim::kPrimReshape, &res);\n' + 'return res;\n' + ) + def generate(self, work_path, op_protos, func_protos_data, alias_func_mapping): """ Generates C++ header and source files for tensor function registrations. 
@@ -481,7 +500,7 @@ class TensorFuncRegCppGenerator(BaseGenerator): op_pyboost_func_name = op_parser.get_pyboost_func_name() + "_OP" convert_args_str = op_parser.get_convert_args_str(func_proto.op_proto, is_tensor_api=True) self_index = op_parser.get_input_tensor_index(func_proto.op_proto) - if func_proto.op_proto.op_class.name in self.layout_infer_ops: + if func_proto.op_proto.op_class.name in self.layout_infer_ops.keys(): num_args = len(func_proto.op_proto.op_args) lambda_params = "" lambda_args = "" @@ -493,13 +512,16 @@ class TensorFuncRegCppGenerator(BaseGenerator): convert_args_comma = ", " + convert_args_str if convert_args_str else "" - return self.pyboost_with_layout_infer_template.replace( + layout_infer_info = self.layout_infer_ops[func_proto.op_proto.op_class.name] + template_name = layout_infer_info.get('template_name', 'default') + return self.pyboost_with_layout_infer_template[template_name].replace( arg_handler_processor=arg_handler_processor_str, class_name=func_proto.op_proto.op_class.name, pyboost_function=op_pyboost_func_name, lambda_params=lambda_params, lambda_args=lambda_args, - convert_args_comma=convert_args_comma + convert_args_comma=convert_args_comma, + suffix=layout_infer_info.get('infer_layout_suffix', '') ) return self.pyboost_return_template.replace(arg_handler_processor=arg_handler_processor_str, class_name=func_proto.op_proto.op_class.name, diff --git a/mindspore/python/mindspore/ops_generate/common/template.py b/mindspore/python/mindspore/ops_generate/common/template.py index 0e503a08c3e754589ff3527c380bca45f7e9a041..340627ffbd7b351ae1ac66376ae6a491d0309217 100644 --- a/mindspore/python/mindspore/ops_generate/common/template.py +++ b/mindspore/python/mindspore/ops_generate/common/template.py @@ -459,5 +459,11 @@ OP_DEF_INC_HEAD_TEMPLATE = Template( LAYOUT_INFER_DEF_TEMPLATE = Template.load_from_file( os.path.join(K.WORK_DIR, './mindspore/python/mindspore/parallel/spmd/layout_infer_def.tpl')) -PYBOOST_API_BODY_WITH_LAYOUT_CC_TEMPLATE = Template.load_from_file( - os.path.join(K.WORK_DIR, './mindspore/python/mindspore/parallel/spmd/pyboost_api_body_with_layout_cc.tpl')) +PYBOOST_API_BODY_WITH_LAYOUT_CC_TEMPLATE = dict() +PYBOOST_API_BODY_WITH_LAYOUT_CC_TEMPLATE['default'] = Template.load_from_file( + os.path.join(K.WORK_DIR, './mindspore/python/mindspore/parallel/spmd/pyboost_api_body' + '/pyboost_api_body_with_layout_cc.tpl')) + +PYBOOST_API_BODY_WITH_LAYOUT_CC_TEMPLATE['without_parse'] = Template.load_from_file( + os.path.join(K.WORK_DIR, './mindspore/python/mindspore/parallel/spmd/pyboost_api_body' + '/pyboost_api_body_with_layout_without_parse_cc.tpl')) diff --git a/mindspore/python/mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py b/mindspore/python/mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py index 83a9d34fc0263b5ad64994c3d200f166c23fe658..cae745cf7f26a2dc0c42a39296532d95b046253c 100644 --- a/mindspore/python/mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +++ b/mindspore/python/mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py @@ -57,8 +57,7 @@ class PyboostFunctionsGenerator(BaseGenerator): self.OP_DEF_INC_HEAD_TEMPLATE = template.OP_DEF_INC_HEAD_TEMPLATE self.MARK_SIDE_EFFECT_STR = "PyNativeAlgo::PyBoost::MarkSideEffect(PyList_GetItem(args, 0));" self.pyboost_api_body_template = template.PYBOOST_API_BODY_CC_TEMPLATE - ops_data = safe_load_yaml_from_dir(os.path.join(K.WORK_DIR, K.PARALLEL_OP_YAML_PATH)) - self.layout_infer_ops = list(ops_data.keys()) + self.layout_infer_ops = 
safe_load_yaml_from_dir(os.path.join(K.WORK_DIR, K.PARALLEL_OP_YAML_PATH)) def generate(self, work_path, op_protos): """ @@ -136,7 +135,7 @@ class PyboostFunctionsGenerator(BaseGenerator): op_args_str = [op_arg.arg_name for op_arg in op_proto.op_args] side_effect_str = self._generate_mark_side_effect_str(op_proto) - if op_proto.op_class.name in self.layout_infer_ops: + if op_proto.op_class.name in self.layout_infer_ops.keys(): input_args = [arg.arg_name for arg in op_proto.op_args] lambda_params = [] lambda_args = [] @@ -148,7 +147,9 @@ class PyboostFunctionsGenerator(BaseGenerator): lambda_args_str = "".join(lambda_args) forward_args_str = "".join([f", {arg}" for arg in input_args]) - pyboost_api_body_str += template.PYBOOST_API_BODY_WITH_LAYOUT_CC_TEMPLATE.replace( + layout_infer_info = self.layout_infer_ops[op_proto.op_class.name] + template_name = layout_infer_info.get('template_name', 'default') + pyboost_api_body_str += template.PYBOOST_API_BODY_WITH_LAYOUT_CC_TEMPLATE[template_name].replace( func_name=op_pyboost_func_name, op_def_name=op_def_name_str, parser_body=parser_body_str, @@ -156,7 +157,8 @@ class PyboostFunctionsGenerator(BaseGenerator): lambda_params=lambda_params_str, lambda_args=lambda_args_str, forward_args=forward_args_str, - mark_side_effect=side_effect_str + mark_side_effect=side_effect_str, + suffix=layout_infer_info.get('infer_layout_suffix', '') ) else: pyboost_api_body_str += self.pyboost_api_body_template.replace( diff --git a/mindspore/python/mindspore/parallel/spmd/layout_infer_def.tpl b/mindspore/python/mindspore/parallel/spmd/layout_infer_def.tpl index f36b6de24da8b0235a280638b11438db34bf558e..5feac873916e3be0a3de52f65281c3a142d5f40b 100644 --- a/mindspore/python/mindspore/parallel/spmd/layout_infer_def.tpl +++ b/mindspore/python/mindspore/parallel/spmd/layout_infer_def.tpl @@ -73,9 +73,6 @@ PyObject* WithLayoutInfer(const PrimitivePtr &prim, Func &&func, PyObject* py_ar MS_LOG(EXCEPTION) << "Input args is not a list."; } py::list py_args_list = py::cast(py_args); - auto& cache_manager = LayoutCacheManager::GetInstance(); - auto& layout_cache = cache_manager.GetLayoutCache()[prim->name()]; - py::object distribute_op = cache_manager.GetDistributedOp(prim->name()); LayoutCacheKey cache_key; py::list input_layouts; @@ -120,6 +117,8 @@ PyObject* WithLayoutInfer(const PrimitivePtr &prim, Func &&func, PyObject* py_ar if (!contain_parallel_args) { return std::forward(func)(std::forward(args)...); } + auto& cache_manager = LayoutCacheManager::GetInstance(); + auto& layout_cache = cache_manager.GetLayoutCache()[prim->name()]; py::object output_layout; auto it = layout_cache.find(cache_key); @@ -127,6 +126,7 @@ PyObject* WithLayoutInfer(const PrimitivePtr &prim, Func &&func, PyObject* py_ar if (it != layout_cache.end()) { output_layout = it->second; } else { + py::object distribute_op = cache_manager.GetDistributedOp(prim->name()); py::tuple all_args = py::make_tuple(input_layouts, extra_args); output_layout = distribute_op.attr("infer_layout")(*all_args); layout_cache[cache_key] = output_layout; @@ -155,6 +155,173 @@ PyObject* WithLayoutInfer(const PrimitivePtr &prim, Func &&func, PyObject* py_ar obj.attr("_layout") = output_layout; } + return py_output; + } catch (const py::error_already_set &e) { + MS_LOG(ERROR) << "Python exception in layout inference: " << e.what(); + throw; + } catch (const std::exception &e) { + MS_LOG(ERROR) << "Exception in layout inference: " << e.what(); + throw; + } +} + +template +PyObject* WithLayoutInferReshape(const PrimitivePtr &prim, 
Func &&func, PyObject* py_args) { + try { + if (!py::isinstance(py_args)) { + MS_LOG(EXCEPTION) << "Input args is not a list."; + } + py::list py_args_list = py::cast(py_args); + + static Converter converter(&ops::gReshape); + converter.Parse(py_args); + auto source_type = converter.source_type(); + auto input = converter.ToTensor(py_args, kIndex0); + + const auto &input_py = py_args_list[kIndex0]; + if (!py::hasattr(input_py, "_layout")) { + auto shape = converter.ToBasicIntVector(py_args, kIndex1); + return func(prim, source_type, input, shape); + } + + // Input layouts. + LayoutCacheKey cache_key; + py::list input_layouts; + py::object layout = input_py.attr("_layout"); + input_layouts.append(layout); + py::object layout_id = layout.attr("compact_str"); + std::string id_str = py::cast(py::str(layout_id)); + cache_key.layout_ids.emplace_back(id_str); + + // Extra args: shape, input shape. + py::list extra_args; + extra_args.append(py_args_list[kIndex1]); + cache_key.layout_ids.emplace_back(py::str(py_args_list[kIndex1])); + const auto &input_shape = input_py.attr("shape"); + extra_args.append(input_shape); + cache_key.layout_ids.emplace_back(py::str(input_shape)); + + // Run Reshape infer layout or use cache. + auto& cache_manager = LayoutCacheManager::GetInstance(); + auto& layout_cache = cache_manager.GetLayoutCache()[prim->name()]; + py::object infer_output; + auto it = layout_cache.find(cache_key); + if (it != layout_cache.end()) { + infer_output = it->second; + } else { + py::object distribute_op = cache_manager.GetDistributedOp(prim->name()); + py::tuple all_args = py::make_tuple(input_layouts, extra_args); + infer_output = distribute_op.attr("infer_layout")(*all_args); + layout_cache[cache_key] = infer_output; + } + + // Run Reshape op with local destination shape. + py::tuple infer_output_tuple = py::cast(infer_output); + auto local_shape = infer_output_tuple[kIndex1]; + py::list converter_input; + converter_input.append(py::none()); + converter_input.append(local_shape); + auto local_shape_vector = converter.ToBasicIntVector(converter_input.ptr(), kIndex1); + auto py_output = func(prim, source_type, input, local_shape_vector); + + auto obj = py::reinterpret_borrow(py_output); + obj.attr("_layout") = infer_output_tuple[kIndex0]; + return py_output; + } catch (const py::error_already_set &e) { + MS_LOG(ERROR) << "Python exception in layout inference: " << e.what(); + throw; + } catch (const std::exception &e) { + MS_LOG(ERROR) << "Exception in layout inference: " << e.what(); + throw; + } +} + +template +PyObject* WithLayoutInferWithShape(const PrimitivePtr &prim, Func &&func, PyObject* py_args, Args &&... 
args) { + try { + if (!py::isinstance(py_args)) { + MS_LOG(EXCEPTION) << "Input args is not a list."; + } + py::list py_args_list = py::cast(py_args); + + LayoutCacheKey cache_key; + py::list input_layouts; + py::list extra_args; + py::list input_shapes; + bool contain_parallel_args = false; + + // Collect layout and no layout args + for (auto arg : py_args_list) { + if (arg.is_none()) { + input_layouts.append(py::none()); + input_shapes.append(py::none()); + continue; + } + if (!py::hasattr(arg, "_layout")) { + py::object arg_str = py::str(arg); + std::string id_str = py::cast(arg_str); + cache_key.layout_ids.emplace_back(id_str); + extra_args.append(arg); + input_layouts.append(py::none()); + } else { + contain_parallel_args = true; + py::object layout = arg.attr("_layout"); + py::object layout_id = layout.attr("compact_str"); + std::string id_str = py::cast(py::str(layout_id)); + cache_key.layout_ids.push_back(id_str); + input_layouts.append(layout); + } + + if (!py::hasattr(arg, "shape")) { + input_shapes.append(py::none()); + } else { + const auto &input_shape = arg.attr("shape"); + input_shapes.append(input_shape); + cache_key.layout_ids.emplace_back(py::str(input_shape)); + } + } + + if (!contain_parallel_args) { + return std::forward(func)(std::forward(args)...); + } + auto& cache_manager = LayoutCacheManager::GetInstance(); + auto& layout_cache = cache_manager.GetLayoutCache()[prim->name()]; + + py::object output_layout; + auto it = layout_cache.find(cache_key); + + if (it != layout_cache.end()) { + output_layout = it->second; + } else { + extra_args.append(input_shapes); + py::object distribute_op = cache_manager.GetDistributedOp(prim->name()); + py::tuple all_args = py::make_tuple(input_layouts, extra_args); + output_layout = distribute_op.attr("infer_layout")(*all_args); + layout_cache[cache_key] = output_layout; + } + + auto py_output = std::forward(func)(std::forward(args)...); + if (py::isinstance(py_output)) { + py::tuple output_tuple = py::cast(py_output); + if (py::isinstance(output_layout)) { + py::tuple layout_tuple = py::cast(output_layout); + if (output_tuple.size() == layout_tuple.size()) { + for (size_t i = 0; i < output_tuple.size(); ++i) { + output_tuple[i].attr("_layout") = layout_tuple[i]; + } + } else { + MS_LOG(ERROR) << "Output tuple size (" << output_tuple.size() + << ") does not match layout tuple size (" << layout_tuple.size() << ")"; + throw std::runtime_error("Output and layout tuple size mismatch"); + } + } else { + MS_LOG(ERROR) << "Output is a tuple but layout is not"; + throw std::runtime_error("Output is tuple but layout is not"); + } + } else { + auto obj = py::reinterpret_borrow(py_output); + obj.attr("_layout") = output_layout; + } return py_output; } catch (const py::error_already_set &e) { MS_LOG(ERROR) << "Python exception in layout inference: " << e.what(); diff --git a/mindspore/python/mindspore/parallel/spmd/ops/parallel_reshape.py b/mindspore/python/mindspore/parallel/spmd/ops/parallel_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..47da3652d95f96cab99e48b6c0ec02f2082daf5c --- /dev/null +++ b/mindspore/python/mindspore/parallel/spmd/ops/parallel_reshape.py @@ -0,0 +1,157 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Distributed implementation for Reshape operator.
+"""
+
+from mindspore.parallel import Layout
+from mindspore.common.tensor import Tensor
+from .parallel_ops import DistributedOp
+
+
+class ReshapeDistributedOp(DistributedOp):
+    """Distributed implementation for the Reshape operator."""
+
+    def _get_dynamic_shape_info(self, shape):
+        total_size = 1
+        dynamic_axis = -1
+        for axis, s in enumerate(shape):
+            total_size *= s
+            if s < 0:
+                dynamic_axis = axis
+        return total_size < 0, dynamic_axis, total_size
+
+    def _handle_dynamic_shape(self, input_shape, output_shape):
+        """
+        Handle dynamic (-1) dimensions. If only one of the input and output shapes is dynamic, derive the unknown
+        axis from the known total size. If both are dynamic, record the relative multiple between them, e.g.
+        [2, -1, 8], [4, -1, 8] -> [2, -2, 8], [4, -1, 8]
+        """
+        input_shape = list(input_shape)
+        output_shape = list(output_shape)
+        is_input_dynamic, input_dynamic_axis, input_total_size = self._get_dynamic_shape_info(input_shape)
+        is_output_dynamic, output_dynamic_axis, output_total_size = self._get_dynamic_shape_info(output_shape)
+        dynamic_can_shard = False
+        if not is_input_dynamic and not is_output_dynamic:
+            if input_total_size != output_total_size:
+                raise ValueError(f"The total elements number of input shape {input_shape} and output shape "
+                                 f"{output_shape} are different.")
+            return input_shape, output_shape, dynamic_can_shard
+
+        if not is_input_dynamic:
+            accurate_output_shape = output_shape
+            accurate_output_shape[output_dynamic_axis] = -input_total_size // output_total_size
+            return input_shape, accurate_output_shape, dynamic_can_shard
+
+        if not is_output_dynamic:
+            accurate_input_shape = input_shape
+            accurate_input_shape[input_dynamic_axis] = -output_total_size // input_total_size
+            return accurate_input_shape, output_shape, dynamic_can_shard
+
+        if output_total_size >= input_total_size:
+            output_shape[output_dynamic_axis] = -(input_total_size // output_total_size)
+            dynamic_can_shard = True
+        else:
+            input_shape[input_dynamic_axis] = -(output_total_size // input_total_size)
+        return input_shape, output_shape, dynamic_can_shard
+
+    def _merge_unshared_axis(self, global_shape, tensor_map):
+        """
+        Merge axes that are not sharded into the nearest sharded axis above them; any leading unsharded axes are
+        kept as one merged axis, e.g. shape [4, 2, 6, 8] with tensor map [-1, -1, 0, -1] -> merged shape [8, 48]
+        """
+        merged_size = 1
+        merged_shape = []
+        merged_tensor_map = []
+        for axis in range(len(global_shape) - 1, -1, -1):
+            merged_size *= global_shape[axis]
+            if tensor_map[axis] != -1:
+                merged_shape.insert(0, merged_size)
+                merged_tensor_map.insert(0, tensor_map[axis])
+                merged_size = 1
+        if tensor_map[0] == -1:
+            merged_shape.insert(0, merged_size)
+            merged_tensor_map.insert(0, -1)
+        return merged_shape, merged_tensor_map
+
+    def infer_layout(self, input_layouts, extra_args):
+        """
+        Infer the output layout for the Reshape operator.
+
+        For reshape operations, the data slice held on each device after reshape must be the same as the data slice held before reshape.
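+
+        Illustrative example (mirroring the unit tests added in this patch): with device_matrix (2, 4)
+        aliased ("dp", "mp"), an input of global shape (32, 128) sharded on its first axis by "dp", and
+        dst_shape (4, 8, 128), the sharded axis 32 is split into (4, 8); the "dp" shard stays on the new
+        leading axis, so the inferred tensor map is (1, -1, -1) and the local dst shape is [2, 8, 128].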
+
+        Args:
+            input_layouts (tuple): Layouts of the inputs; only the layout of input x is used.
+            extra_args (tuple): (destination shape, original input shape).
+
+        Returns:
+            tuple: (Layout of the output tensor, local destination shape).
+        """
+        x_layout = input_layouts[0]
+        x_dict = x_layout.to_dict()
+
+        if len(extra_args) != 2:
+            raise ValueError("Reshape requires output shape and input shape.")
+        # Check output shape.
+        dst_shape = extra_args[0]
+        if isinstance(dst_shape, Tensor):
+            dst_shape = dst_shape.tolist()
+        if not isinstance(dst_shape, list) and not isinstance(dst_shape, tuple):
+            raise ValueError("Shape should be a tensor or a tuple or a list.")
+
+        input_shape = extra_args[1]
+
+        x_map = x_dict["tensor_map"]
+        x_device_matrix = x_dict["device_matrix"]
+
+        input_shape, dst_shape, dynamic_can_shard = self._handle_dynamic_shape(input_shape, dst_shape)
+        merged_shape, merge_tensor_map = self._merge_unshared_axis(input_shape, x_map)
+
+        output_tensor_map = []
+        cur_axis = len(merged_shape) - 1
+        cur_size = merged_shape[cur_axis]
+        for shape in reversed(dst_shape):
+            if cur_size % shape != 0:
+                raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
+            cur_size = cur_size // shape
+            if cur_size == 1:
+                shard_size = x_device_matrix[-merge_tensor_map[cur_axis] - 1]
+                if shape < 0:
+                    if not dynamic_can_shard:
+                        raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
+                elif shard_size > shape or shape % shard_size != 0:
+                    raise ValueError(f"Can not reshape {input_shape} to {dst_shape} with tensor map {x_map}")
+                output_tensor_map.insert(0, merge_tensor_map[cur_axis])
+                cur_axis -= 1
+                cur_size = merged_shape[cur_axis]
+            else:
+                output_tensor_map.insert(0, -1)
+
+        output_layout = Layout(
+            device_matrix=x_device_matrix,
+            alias_name=x_layout.alias_name,
+            rank_list=x_layout.rank_list
+        )
+        output_map = []
+        local_dst_shape = []
+        for idx, map_id in enumerate(output_tensor_map):
+            if map_id < 0:
+                output_map.append("None")
+                local_dst_shape.append(dst_shape[idx] if dst_shape[idx] > 0 else -1)
+            else:
+                output_map.append(x_dict["alias_name"][-1 - map_id])
+                local_dst_shape.append(dst_shape[idx] // x_device_matrix[-1 - map_id] if dst_shape[idx] > 0 else -1)
+        out_layout = output_layout(*output_map)
+        return out_layout, local_dst_shape
diff --git a/mindspore/python/mindspore/parallel/spmd/ops/parallel_split.py b/mindspore/python/mindspore/parallel/spmd/ops/parallel_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..85d7b091ba494d3a8b2aabce23b34c9aef16a0df
--- /dev/null
+++ b/mindspore/python/mindspore/parallel/spmd/ops/parallel_split.py
@@ -0,0 +1,57 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Distributed implementation for the Split operator.
+""" + +from .parallel_ops import DistributedOp + + +class SplitDistributedOp(DistributedOp): + """Distributed implementation for TopK operator.""" + + def infer_layout(self, input_layouts, extra_args): + """ + Infer output layouts for Split operator. + + Rules: + 1. Shared axis can not be split. + + Args: + input_layouts (Layout): Layout of input tensor + extra_args (list): split size or sections, axis, input shape + + Returns: + tuple: Layouts for output tensors + """ + + input_layout = input_layouts[0] + axis = extra_args[1] + # Check shared axis can not be split. + tensor_map = input_layout.tensor_map + if tensor_map[axis] != -1: + raise ValueError(f"Can not split tensor at sharded axis[{axis}], layout: {input_layout}") + + split_size_or_sections = extra_args[0] + if isinstance(split_size_or_sections, (list, tuple)): + output_num = len(split_size_or_sections) + else: + input_shape = extra_args[2][0] + output_num = input_shape[axis] // split_size_or_sections + if input_shape[axis] % split_size_or_sections != 0: + output_num += 1 + + output_layouts = (input_layout,) * output_num + return output_layouts diff --git a/mindspore/python/mindspore/parallel/spmd/ops/yaml/element_wise_ops.yaml b/mindspore/python/mindspore/parallel/spmd/ops/yaml/element_wise_ops.yaml index f395f10b6854c69a405dec4b687f381934406792..ac860c154939f55e3aa49c9b5f5a0e51691004a6 100644 --- a/mindspore/python/mindspore/parallel/spmd/ops/yaml/element_wise_ops.yaml +++ b/mindspore/python/mindspore/parallel/spmd/ops/yaml/element_wise_ops.yaml @@ -61,4 +61,39 @@ Sigmoid: Silu: dist_op_name: _silu_dist_op distributed_op_class: ElementWiseDistributedOp + distributed_op_file: parallel_elementwise + +Cos: + dist_op_name: _cos_dist_op + distributed_op_class: ElementWiseDistributedOp + distributed_op_file: parallel_elementwise + +Sin: + dist_op_name: _sin_dist_op + distributed_op_class: ElementWiseDistributedOp + distributed_op_file: parallel_elementwise + +Log: + dist_op_name: _log_dist_op + distributed_op_class: ElementWiseDistributedOp + distributed_op_file: parallel_elementwise + +Neg: + dist_op_name: _neg_dist_op + distributed_op_class: ElementWiseDistributedOp + distributed_op_file: parallel_elementwise + +Exp: + dist_op_name: _exp_dist_op + distributed_op_class: ElementWiseDistributedOp + distributed_op_file: parallel_elementwise + +NotEqual: + dist_op_name: _not_equal_dist_op + distributed_op_class: ElementWiseDistributedOp + distributed_op_file: parallel_elementwise + +Cast: + dist_op_name: _cast_dist_op + distributed_op_class: ElementWiseDistributedOp distributed_op_file: parallel_elementwise \ No newline at end of file diff --git a/mindspore/python/mindspore/parallel/spmd/ops/yaml/reshape_ops.yaml b/mindspore/python/mindspore/parallel/spmd/ops/yaml/reshape_ops.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91ffb3037973a09152d0d0a7c9b7ce92774d61d9 --- /dev/null +++ b/mindspore/python/mindspore/parallel/spmd/ops/yaml/reshape_ops.yaml @@ -0,0 +1,6 @@ +Reshape: + dist_op_name: _reshape_dist_op + distributed_op_class: ReshapeDistributedOp + distributed_op_file: parallel_reshape + template_name: without_parse + infer_layout_suffix: Reshape diff --git a/mindspore/python/mindspore/parallel/spmd/ops/yaml/split_ops.yaml b/mindspore/python/mindspore/parallel/spmd/ops/yaml/split_ops.yaml new file mode 100644 index 0000000000000000000000000000000000000000..63e22e9867568b5716e25235116dcff2de715ba8 --- /dev/null +++ b/mindspore/python/mindspore/parallel/spmd/ops/yaml/split_ops.yaml @@ -0,0 +1,16 @@ +Split: + 
dist_op_name: _split_dist_op + distributed_op_class: SplitDistributedOp + distributed_op_file: parallel_split + infer_layout_suffix: WithShape + +SplitTensor: + dist_op_name: _split_dist_op + distributed_op_class: SplitDistributedOp + distributed_op_file: parallel_split + infer_layout_suffix: WithShape + +SplitWithSize: + dist_op_name: _split_dist_op + distributed_op_class: SplitDistributedOp + distributed_op_file: parallel_split diff --git a/mindspore/python/mindspore/parallel/spmd/pyboost_api_body_with_layout_cc.tpl b/mindspore/python/mindspore/parallel/spmd/pyboost_api_body/pyboost_api_body_with_layout_cc.tpl similarity index 94% rename from mindspore/python/mindspore/parallel/spmd/pyboost_api_body_with_layout_cc.tpl rename to mindspore/python/mindspore/parallel/spmd/pyboost_api_body/pyboost_api_body_with_layout_cc.tpl index 59dca2b3fa2065049b91f3d4c2fa969b490006f2..949aff487444c9963ec7382cc5d9decd1f309465 100644 --- a/mindspore/python/mindspore/parallel/spmd/pyboost_api_body_with_layout_cc.tpl +++ b/mindspore/python/mindspore/parallel/spmd/pyboost_api_body/pyboost_api_body_with_layout_cc.tpl @@ -6,7 +6,7 @@ PYNATIVE_EXPORT PyObject* ${func_name}_Base(const PrimitivePtr &prim, PyObject* ${parser_body} auto source_type = converter.source_type(); - return WithLayoutInfer( + return WithLayoutInfer${suffix}( prim, [](const PrimitivePtr &p, const std::vector &st${lambda_params}) { return ${func_name}_OP(p, st${lambda_args}); @@ -20,4 +20,3 @@ PYNATIVE_EXPORT PyObject* ${func_name}_Base(const PrimitivePtr &prim, PyObject* return res.release().ptr(); #endif } - diff --git a/mindspore/python/mindspore/parallel/spmd/pyboost_api_body/pyboost_api_body_with_layout_without_parse_cc.tpl b/mindspore/python/mindspore/parallel/spmd/pyboost_api_body/pyboost_api_body_with_layout_without_parse_cc.tpl new file mode 100644 index 0000000000000000000000000000000000000000..9138ab6497a79e1d7e8761d63495d65ec20c20ba --- /dev/null +++ b/mindspore/python/mindspore/parallel/spmd/pyboost_api_body/pyboost_api_body_with_layout_without_parse_cc.tpl @@ -0,0 +1,15 @@ +PYNATIVE_EXPORT PyObject* ${func_name}_Base(const PrimitivePtr &prim, PyObject* args) { +#ifndef ENABLE_TEST + ${mark_side_effect} + return WithLayoutInfer${suffix}( + prim, + [](const PrimitivePtr &p, const std::vector &st${lambda_params}) { + return ${func_name}_OP(p, st${lambda_args}); + }, + args); +#else + py::object py_args = py::reinterpret_borrow(args); + py::object res = PyNativeAlgo::PyBoost::RunPyFunction(prim, py_args); + return res.release().ptr(); +#endif +} diff --git a/tests/st/auto_parallel/spmd/shard/ops/reshape/parallel_reshape.py b/tests/st/auto_parallel/spmd/shard/ops/reshape/parallel_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..f87a95f9f36ed7343e930ba2fc794178280729b8 --- /dev/null +++ b/tests/st/auto_parallel/spmd/shard/ops/reshape/parallel_reshape.py @@ -0,0 +1,154 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import time +import numpy as np +import mindspore as ms +import mindspore.communication.management as D +from mindspore import nn, Tensor +from mindspore.parallel import Layout +from mindspore.parallel.spmd.hsdp.hsdp import hsdp + +learning_rate = 0.01 +epochs = 2 + + +class SimpleModel(nn.Cell): + def __init__(self, input_size, output_size): + super().__init__() + self.weight = ms.Parameter( + Tensor(np.ones([input_size, output_size]).astype(np.float32)), + name='weight' + ) + + self.relu = ms.mint.nn.ReLU() + + def construct(self, x): + target_shape = x.shape[:-2] + (self.weight.shape[0],) + x = ms.ops.reshape(x, target_shape) + x = ms.mint.matmul(x, self.weight) + x = self.relu(x) + return x + + +def create_dtensor(data, layout): + """create_dtensor""" + tensor = Tensor(data, dtype=ms.float32) + return tensor.local_to_global(layout) + + +def create_tensor(data): + return Tensor(data, dtype=ms.float32) + + +def run_standalone(x, input_size, output_size): + model = SimpleModel(input_size, output_size) + + def forward_fn(data): + logits = model(data) + return logits + + optimizer = nn.Adam(model.trainable_params(), learning_rate=learning_rate) + grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False) + + x = create_tensor(x) + + ret_loss = None + ret_grads = None + for epoch in range(epochs): + start = time.time() + (loss_value, grads) = grad_fn(x) + optimizer(grads) + end = time.time() + ret_loss = loss_value + ret_grads = grads + print(f"[standalone] Epoch: {epoch+1}/{epochs}, Loss: {loss_value}, Time: {end - start}") + + return ret_loss, ret_grads + + +def run_parallel(local_x, local_input_size, local_output_size, x_layout, w_layout, relu_strategy, hsdp_shard_size): + model = SimpleModel(local_input_size, local_output_size) + + model.weight = model.weight.local_to_global(w_layout) + model.shard(in_strategy=(x_layout,)) + model.relu.shard(in_strategy=relu_strategy[0], out_strategy=relu_strategy[1]) + model = hsdp(model, shard_size=hsdp_shard_size) + + def forward_fn(data): + logits = model(data) + return logits + + optimizer = nn.Adam(model.trainable_params(), learning_rate=learning_rate) + grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False) + + x = create_dtensor(local_x, x_layout) + + ret_loss = None + ret_grads = None + for epoch in range(epochs): + start = time.time() + (loss_value, grads) = grad_fn(x) + optimizer(grads) + end = time.time() + ret_loss = loss_value + ret_grads = grads + print(f"[parallel] Epoch: {epoch+1}/{epochs}, Loss: {loss_value}, Time: {end - start}") + + return ret_loss, ret_grads + + +def base_case(dp, mp): + D.init() + + # standalone + input_size = 32 + output_size = 2 + batch_size = 4 + x_last_dim = 4 + + x = np.ones([batch_size, input_size// x_last_dim, x_last_dim]).astype(np.float32) + standalone_loss, standalone_grads = run_standalone(x, input_size, output_size) + + # parallel + hsdp_shard_size = dp + local_batch_size = batch_size // dp + local_input_size = input_size // mp + local_output_size = output_size + local_x = np.ones([local_batch_size, local_input_size // x_last_dim, x_last_dim]).astype(np.float32) + layout = Layout((dp, mp), ("dp", "mp")) + x_layout = layout("dp", "mp", "None") + w_layout = layout("mp", "None") + relu_strategy = ((layout("dp", "None"),), (layout("dp", "None"),)) + parallel_loss, parallel_grads = run_parallel(local_x, local_input_size, local_output_size, x_layout, w_layout, + relu_strategy, 
hsdp_shard_size) + + # compare loss + assert np.allclose(standalone_loss.asnumpy(), parallel_loss.asnumpy(), 0.001, 0.001) + + standalone_grad = standalone_grads[0].asnumpy() + # note: this way of obtaining grad slice is a simplified way, not a strict way + standalone_grad_slice = standalone_grad[:local_input_size, :local_output_size] + parallel_grad = parallel_grads[0].asnumpy() + assert np.allclose(standalone_grad_slice, parallel_grad, 0.001, 0.001) + + +def test_parallel_reshape_0(): + ''' + Feature: Reshape parallel op. + Description: Test Reshape parallel op. + Expectation: Run success. + ''' + base_case(dp=4, mp=2) diff --git a/tests/st/auto_parallel/spmd/shard/ops/reshape/test_parallel_reshape.py b/tests/st/auto_parallel/spmd/shard/ops/reshape/test_parallel_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..2115b37028da67c05f9ae6907233deb86ca607d9 --- /dev/null +++ b/tests/st/auto_parallel/spmd/shard/ops/reshape/test_parallel_reshape.py @@ -0,0 +1,29 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +from tests.mark_utils import arg_mark + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_shard_in_python_pynative(): + ''' + Feature: run shard in python. + Description: Test shard in python. + Expectation: Run success. + ''' + return_code = os.system( + "msrun --worker_num=8 --local_worker_num=8 --master_addr=127.0.0.1 --master_port=10677 --join=True " \ + "pytest -s parallel_reshape.py" + ) + assert return_code == 0 diff --git a/tests/st/auto_parallel/spmd/shard/ops/split/parallel_split.py b/tests/st/auto_parallel/spmd/shard/ops/split/parallel_split.py new file mode 100644 index 0000000000000000000000000000000000000000..7baf19dea1a0ec045e6b54c30fbea4373e4fa626 --- /dev/null +++ b/tests/st/auto_parallel/spmd/shard/ops/split/parallel_split.py @@ -0,0 +1,131 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import time +import numpy as np +import mindspore as ms +import mindspore.communication.management as D +from mindspore import nn, Tensor +from mindspore.parallel import Layout, hsdp, init_parameters +from mindspore.nn.utils import no_init_parameters +from mindspore.common.initializer import initializer + +learning_rate = 0.01 +epochs = 2 + +class SimpleModel(nn.Cell): + def __init__(self, batch_size, input_size, output_size): + super().__init__() + self.batch_size = batch_size + self.weight = ms.Parameter(initializer("ones", [input_size, output_size], ms.float32), name='weight') + self.relu = ms.mint.nn.ReLU() + + def construct(self, x): + x, _, _ = ms.mint.split(x, self.batch_size, 0) + x = ms.mint.matmul(x, self.weight) + x = self.relu(x) + x = ms.mint.sum(x) + return x + + +def create_dtensor(data, layout): + """create_dtensor""" + tensor = Tensor(data, dtype=ms.float32) + return tensor.local_to_global(layout) + + +def run_model(x, model): + def forward_fn(data): + logits = model(data) + return logits + + optimizer = nn.Adam(model.trainable_params(), learning_rate=learning_rate) + grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False) + + ret_loss = None + ret_grads = None + for epoch in range(epochs): + start = time.time() + (loss_value, grads) = grad_fn(x) + optimizer(grads) + end = time.time() + ret_loss = loss_value + ret_grads = grads + print(f"[standalone] Epoch: {epoch+1}/{epochs}, Loss: {loss_value}, Time: {end - start}") + + return ret_loss, ret_grads + + +def base_case(dp, mp, hsdp_shard_size): + D.init() + + # standalone + input_size = 32 + output_size = 2 + batch_size = 4 + + standalone_x = Tensor(np.ones([batch_size * 3, input_size]).astype(np.float32), dtype=ms.float32) + standalone_model = SimpleModel(batch_size, input_size, output_size) + standalone_loss, standalone_grads = run_model(standalone_x, standalone_model) + + # parallel + local_batch_size = batch_size + local_input_size = input_size // mp + local_output_size = output_size + local_x = np.ones([local_batch_size * 3, local_input_size]).astype(np.float32) + layout = Layout((dp, mp), ("dp", "mp")) + x_layout = layout("None", "mp") + w_layout = layout("mp", "None") + out_layout = layout() + relu_strategy = ((layout("dp", "None"),), (layout("dp", "None"),)) + + # step 1: define network with no init parameters + with no_init_parameters(): + model = SimpleModel(batch_size, input_size, output_size) + + # step 2: shard + model.shard(in_strategy=(x_layout,), out_strategy=(out_layout,), parameter_plan={"weight": w_layout}) + model.relu.shard(in_strategy=relu_strategy[0], out_strategy=relu_strategy[1]) + + # step 3: hsdp + model = hsdp(model, shard_size=hsdp_shard_size, threshold=0) + + # step 4: init parameters + model = init_parameters(model) + + x = create_dtensor(local_x, x_layout) + parallel_loss, parallel_grads = run_model(x, model) + + # compare loss + assert np.allclose(standalone_loss.asnumpy(), parallel_loss.asnumpy(), 0.001, 0.001) + + # compare grad + if hsdp_shard_size < 0: + hsdp_shard_size = dp + + standalone_grad = standalone_grads[0].asnumpy() + # note: this way of obtaining grad slice is a simplified way, not a strict way + standalone_grad_slice = standalone_grad[:local_input_size // hsdp_shard_size, :local_output_size] + parallel_grad = parallel_grads[0].asnumpy() + assert np.allclose(standalone_grad_slice, parallel_grad, 0.001, 0.001) + + +def test_parallel_split(): + ''' + Feature: with 
no_init_parameters + cell shard + hsdp + init param + loss repeat + partial. + Description: Test base shard. + Expectation: Run success. + ''' + base_case(dp=4, mp=2, hsdp_shard_size=4) diff --git a/tests/st/auto_parallel/spmd/shard/ops/split/test_parallel_split.py b/tests/st/auto_parallel/spmd/shard/ops/split/test_parallel_split.py new file mode 100644 index 0000000000000000000000000000000000000000..e1af0242e9caf9ad91b0e57c864549d34d80c4e8 --- /dev/null +++ b/tests/st/auto_parallel/spmd/shard/ops/split/test_parallel_split.py @@ -0,0 +1,29 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +from tests.mark_utils import arg_mark + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level1", card_mark="allcards", essential_mark="unessential") +def test_shard_in_python_pynative(): + ''' + Feature: run shard in python. + Description: Test shard in python. + Expectation: Run success. + ''' + return_code = os.system( + "msrun --worker_num=8 --local_worker_num=8 --master_addr=127.0.0.1 --master_port=10677 --join=True " \ + "pytest -s parallel_split.py" + ) + assert return_code == 0 diff --git a/tests/ut/python/parallel/parallel_ops_infer/test_parallel_reshape.py b/tests/ut/python/parallel/parallel_ops_infer/test_parallel_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4252fa2d579720955c080864e30000cce0c68a --- /dev/null +++ b/tests/ut/python/parallel/parallel_ops_infer/test_parallel_reshape.py @@ -0,0 +1,251 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from mindspore.parallel import Layout +from mindspore.parallel.spmd.ops.parallel_ops_register import get_distributed_op + +op = get_distributed_op("Reshape") + + +def test_reshape_layout_not_change_sharded_axis(): + """ + Feature: Reshape do not change sharded axis + Description: Reshape do not change sharded axis + Expectation: Success + """ + base_device_matrix = (2, 4) + base_alias_name = ("dp", "mp") + base_rank_list = list(range(8)) + + x_layout = Layout(base_device_matrix, base_alias_name, base_rank_list) + x_layout = x_layout("dp", "None", "None") + src_shape = (1024, 512, 512) + dst_shape = (1024, 2, 256, 512) + + output_layout, local_dst_shape = op.infer_layout((x_layout,), (dst_shape, src_shape)) + expected_map = (1, -1, -1, -1) # Expected output tensor map + assert output_layout.tensor_map == expected_map, (f"Reshape do not change sharded axis failed. Expected " + f"expected_map {expected_map} bug got {output_layout.tensor_map}") + expected_local_dst_shape = [512, 2, 256, 512] # Expected local dst shape + assert local_dst_shape == expected_local_dst_shape, (f"Reshape do not change sharded axis failed. Expected " + f"expected_local_dst_shape {expected_local_dst_shape} got " + f"{local_dst_shape}") + + +def test_reshape_layout_merge_sharded_axis(): + """ + Feature: Reshape merge shared axis with not shared axis + Description: Reshape merge shared axis with not shared axis + Expectation: Success + """ + base_device_matrix = (2, 4) + base_alias_name = ("dp", "mp") + base_rank_list = list(range(8)) + + x_layout = Layout(base_device_matrix, base_alias_name, base_rank_list) + x_layout = x_layout("dp", "None", "None") + src_shape = (4, 4, 8) + dst_shape = (16, 8) + + output_layout, local_dst_shape = op.infer_layout((x_layout,), (dst_shape, src_shape)) + expected_map = (1, -1) # Expected output tensor map + assert output_layout.tensor_map == expected_map, (f"Reshape do not change sharded axis failed. Expected " + f"expected_map {expected_map} bug got {output_layout.tensor_map}") + expected_local_dst_shape = [8, 8] + assert local_dst_shape == expected_local_dst_shape, (f"Reshape do not change sharded axis failed. Expected " + f"expected_local_dst_shape {expected_local_dst_shape} got " + f"{local_dst_shape}") + + +def test_reshape_layout_split_sharded_axis(): + """ + Feature: Reshape split shared asix + Description: Reshape do not change sharded axis + Expectation: Success + """ + base_device_matrix = (2, 4) + base_alias_name = ("dp", "mp") + base_rank_list = list(range(8)) + + x_layout = Layout(base_device_matrix, base_alias_name, base_rank_list) + x_layout = x_layout("dp", "None", "None") + src_shape = (32, 128) + dst_shape = (4, 8, 128) + + output_layout, local_dst_shape = op.infer_layout((x_layout,), (dst_shape, src_shape)) + expected_map = (1, -1, -1) # Expected output tensor map + assert output_layout.tensor_map == expected_map, (f"Reshape do not change sharded axis failed. Expected " + f"expected_map {expected_map} bug got {output_layout.tensor_map}") + expected_local_dst_shape = [2, 8, 128] + assert local_dst_shape == expected_local_dst_shape, (f"Reshape do not change sharded axis failed. 
Expected " + f"expected_local_dst_shape {expected_local_dst_shape} got " + f"{local_dst_shape}") + + +def test_reshape_layout_multi_axes_shared(): + """ + Feature: Reshape split, merge, resize axes + Description: Reshape split, merge, resize axes + Expectation: Success + """ + base_device_matrix = (2, 2, 2) + base_alias_name = ("dp", "mp", "cp") + base_rank_list = list(range(8)) + + x_layout = Layout(base_device_matrix, base_alias_name, base_rank_list) + x_layout = x_layout("cp", "dp", "None", "mp", "None") + src_shape = (32, 6, 128, 28, 10) + dst_shape = (4, 8, 2, 384, 280) + + output_layout, local_dst_shape = op.infer_layout((x_layout,), (dst_shape, src_shape)) + expected_map = (0, -1, 2, -1, 1) # Expected output tensor map + assert output_layout.tensor_map == expected_map, (f"Reshape do not change sharded axis failed. Expected " + f"expected_map {expected_map} bug got {output_layout.tensor_map}") + expected_local_dst_shape = [2, 8, 1, 384, 140] + assert local_dst_shape == expected_local_dst_shape, (f"Reshape do not change sharded axis failed. Expected" + f" expected_local_dst_shape {expected_local_dst_shape} got" + f" {local_dst_shape}") + + +def test_reshape_layout_can_not_reshape1(): + """ + Feature: Reshape can not be shared + Description: Can not be reshaped + Expectation: Fail + """ + base_device_matrix = (2, 2, 2) + base_alias_name = ("dp", "mp", "cp") + base_rank_list = list(range(8)) + + x_layout = Layout(base_device_matrix, base_alias_name, base_rank_list) + x_layout = x_layout("None", "None", "None", "mp") + src_shape = (4, 8, 4, 12) + dst_shape = (4, 8, 12, 4) + + with pytest.raises(ValueError): + _, _ = op.infer_layout((x_layout,), (dst_shape, src_shape)) + + +def test_reshape_layout_can_not_reshape2(): + """ + Feature: Reshape can not be shared + Description: Can not be reshaped + Expectation: Fail + """ + base_device_matrix = (2, 4) + base_alias_name = ("dp", "mp") + base_rank_list = list(range(8)) + + x_layout = Layout(base_device_matrix, base_alias_name, base_rank_list) + x_layout = x_layout("None", "None", "mp", "None") + src_shape = (4, 8, 12, 7) + dst_shape = (4, 8, 2, 42) + + with pytest.raises(ValueError): + _, _ = op.infer_layout((x_layout,), (dst_shape, src_shape)) + + +def test_reshape_layout_dynamic_shape1(): + """ + Feature: Reshape parallel op with dynamic shape + Description: Reshape parallel op with dynamic shape + Expectation: Success + """ + base_device_matrix = (2, 4) + base_alias_name = ("dp", "mp") + base_rank_list = list(range(8)) + + x_layout = Layout(base_device_matrix, base_alias_name, base_rank_list) + x_layout = x_layout("dp", "None", "None", "None") + src_shape = (1024, -1, 256, 512) + dst_shape = (1024, -1, 512) + + output_layout, local_dst_shape = op.infer_layout((x_layout,), (dst_shape, src_shape)) + expected_map = (1, -1, -1) # Expected output tensor map + assert output_layout.tensor_map == expected_map, (f"Reshape do not change sharded axis failed. Expected " + f"expected_map {expected_map} bug got {output_layout.tensor_map}") + expected_local_dst_shape = [512, -1, 512] # Expected local dst shape + assert local_dst_shape == expected_local_dst_shape, (f"Reshape do not change sharded axis failed. 
Expected " + f"expected_local_dst_shape {expected_local_dst_shape} got " + f"{local_dst_shape}") + + +def test_reshape_layout_dynamic_shape2(): + """ + Feature: Reshape parallel op with dynamic shape + Description: Reshape parallel op with dynamic shape + Expectation: Success + """ + base_device_matrix = (2, 4) + base_alias_name = ("dp", "mp") + base_rank_list = list(range(8)) + + x_layout = Layout(base_device_matrix, base_alias_name, base_rank_list) + x_layout = x_layout("dp", "None", "None") + src_shape = (1024, -1, 512) + dst_shape = (1024, -1, 256, 512) + + output_layout, local_dst_shape = op.infer_layout((x_layout,), (dst_shape, src_shape)) + expected_map = (1, -1, -1, -1) # Expected output tensor map + assert output_layout.tensor_map == expected_map, (f"Reshape do not change sharded axis failed. Expected " + f"expected_map {expected_map} bug got {output_layout.tensor_map}") + expected_local_dst_shape = [512, -1, 256, 512] # Expected local dst shape + assert local_dst_shape == expected_local_dst_shape, (f"Reshape do not change sharded axis failed. Expected " + f"expected_local_dst_shape {expected_local_dst_shape} got " + f"{local_dst_shape}") + + +def test_reshape_layout_dynamic_shape3(): + """ + Feature: Reshape parallel op with dynamic shape + Description: Reshape parallel op with dynamic shape + Expectation: Success + """ + base_device_matrix = (2, 4) + base_alias_name = ("dp", "mp") + base_rank_list = list(range(8)) + + x_layout = Layout(base_device_matrix, base_alias_name, base_rank_list) + x_layout = x_layout("dp", "None", "None") + src_shape = (-1, 256, 512) + dst_shape = (-1, 512) + + output_layout, local_dst_shape = op.infer_layout((x_layout,), (dst_shape, src_shape)) + expected_map = (1, -1) # Expected output tensor map + assert output_layout.tensor_map == expected_map, (f"Reshape do not change sharded axis failed. Expected " + f"expected_map {expected_map} bug got {output_layout.tensor_map}") + expected_local_dst_shape = [-1, 512] # Expected local dst shape + assert local_dst_shape == expected_local_dst_shape, (f"Reshape do not change sharded axis failed. 
Expected "
+                                                         f"expected_local_dst_shape {expected_local_dst_shape} got "
+                                                         f"{local_dst_shape}")
+
+
+def test_reshape_layout_dynamic_shape4():
+    """
+    Feature: Reshape parallel op with dynamic shape
+    Description: Reshape parallel op with dynamic shape
+    Expectation: Success
+    """
+    base_device_matrix = (2, 4)
+    base_alias_name = ("dp", "mp")
+    base_rank_list = list(range(8))
+
+    x_layout = Layout(base_device_matrix, base_alias_name, base_rank_list)
+    x_layout = x_layout("dp", "None", "None")
+    src_shape = (-1, 512)
+    dst_shape = (-1, 256, 512)
+
+    with pytest.raises(ValueError):
+        _, _ = op.infer_layout((x_layout,), (dst_shape, src_shape))
diff --git a/tests/ut/python/parallel/parallel_ops_infer/test_parallel_split.py b/tests/ut/python/parallel/parallel_ops_infer/test_parallel_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eaa5fefec20a62329b8117075150973c0ecc67d
--- /dev/null
+++ b/tests/ut/python/parallel/parallel_ops_infer/test_parallel_split.py
@@ -0,0 +1,90 @@
+import pytest
+from mindspore.parallel.spmd.ops.parallel_split import SplitDistributedOp
+from mindspore.parallel import Layout
+
+# Initialize a SplitDistributedOp instance shared by the tests below.
+split_op = SplitDistributedOp("split")
+
+
+# Helper function that creates a Layout object from a tensor map.
+def create_layout(tensor_map):
+    base_device_matrix = (2, 2, 2)
+    base_alias_name = ("dp", "mp", "cp")
+    base_rank_list = list(range(8))
+    x_layout = Layout(base_device_matrix, base_alias_name, base_rank_list)
+    alias_names = [(base_alias_name[-idx - 1] if idx >= 0 else "None") for idx in tensor_map]
+    x_layout = x_layout(*alias_names)
+    return x_layout
+
+
+def test_infer_layout_normal():
+    """
+    Feature: Split operator layout inference under normal conditions
+    Description: Test normal split where axis is not sharded
+    Expectation: Output layouts are correctly generated with same tensor_map
+    """
+    input_layout = create_layout([1, -1, 0])
+    axis = 1
+    split_size_or_sections = 2
+    input_shape = [4, 6, 8]
+    extra_args = [split_size_or_sections, axis, [input_shape,]]
+
+    output_layouts = split_op.infer_layout([input_layout], extra_args)
+
+    expected_output_num = input_shape[axis] // split_size_or_sections + \
+                          (1 if input_shape[axis] % split_size_or_sections != 0 else 0)
+    assert len(output_layouts) == expected_output_num
+    assert all(layout.tensor_map == input_layout.tensor_map for layout in output_layouts)
+
+
+def test_infer_layout_invalid_axis():
+    """
+    Feature: Split operator layout inference with invalid axis
+    Description: Test when trying to split a sharded axis (which is not allowed)
+    Expectation: ValueError is raised
+    """
+    input_layout = create_layout([1, 0, -1])
+    axis = 0
+    split_size_or_sections = 2
+    input_shape = [4, 6, 8]
+    extra_args = [split_size_or_sections, axis, [input_shape,]]
+
+    with pytest.raises(ValueError):
+        split_op.infer_layout([input_layout], extra_args)
+
+
+def test_infer_layout_with_sections():
+    """
+    Feature: Split operator layout inference with sections list
+    Description: Test split using a list of section sizes
+    Expectation: Output number matches the length of sections list
+    """
+    input_layout = create_layout([-1, 1, -1])
+    axis = 2
+    split_size_or_sections = [2, 3, 3]
+    input_shape = [4, 6, 8]
+    extra_args = [split_size_or_sections, axis, [input_shape,]]
+
+    output_layouts = split_op.infer_layout([input_layout], extra_args)
+
+    assert len(output_layouts) == len(split_size_or_sections)
+    assert all(layout.tensor_map == input_layout.tensor_map for layout in output_layouts)
+
+
+def test_infer_layout_with_remainder():
+    """
+    
Feature: Split operator layout inference with non-divisible size + Description: Test split when input shape is not divisible by split size + Expectation: Output count includes an extra tensor for the remainder + """ + input_layout = create_layout([-1, -1, 0]) + axis = 1 + split_size_or_sections = 3 + input_shape = [5, 7, 9] + extra_args = [split_size_or_sections, axis, [input_shape,]] + + output_layouts = split_op.infer_layout([input_layout], extra_args) + + expected_output_num = input_shape[axis] // split_size_or_sections + 1 + assert len(output_layouts) == expected_output_num + assert all(layout.tensor_map == input_layout.tensor_map for layout in output_layouts)
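
A quick way to sanity-check the new layout inference without launching msrun is to drive the registered distributed op directly. The sketch below is illustrative only and not part of the patch; it looks Reshape up through the registry populated from reshape_ops.yaml and mirrors test_reshape_layout_not_change_sharded_axis above.

from mindspore.parallel import Layout
from mindspore.parallel.spmd.ops.parallel_ops_register import get_distributed_op

# Reshape is resolved through the registry this patch populates from reshape_ops.yaml.
reshape_op = get_distributed_op("Reshape")

layout = Layout((2, 4), ("dp", "mp"), list(range(8)))
x_layout = layout("dp", "None", "None")   # first axis sharded by dp=2
src_shape = (1024, 512, 512)
dst_shape = (1024, 2, 256, 512)

out_layout, local_dst_shape = reshape_op.infer_layout((x_layout,), (dst_shape, src_shape))
# Expected per the unit tests: out_layout.tensor_map == (1, -1, -1, -1)
# and local_dst_shape == [512, 2, 256, 512]; only the dp-sharded axis is divided.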