diff --git a/hyper_parallel/__init__.py b/hyper_parallel/__init__.py index 808a271bf94ce723d3ab76431c24bc91b0fbef5f..55883b58ecbc041c5e0ccac05b8edb7641dc36f0 100644 --- a/hyper_parallel/__init__.py +++ b/hyper_parallel/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ "hyper parallel interface" -from hyper_parallel.platform import set_platform, get_platform +from hyper_parallel.platform import get_platform from hyper_parallel.core.hsdp.hsdp_api import hsdp, hsdp_wait_grad_handle, HSDPCell from hyper_parallel.core.layout import Layout from hyper_parallel.core.dtensor import DTensor, SkipDTensorDispatch @@ -23,5 +23,5 @@ from hyper_parallel.core.tensor_parallel.shard import parallelize_value_and_grad from hyper_parallel.core.tensor_parallel.local_func import custom_shard -__all__ = ["set_platform", "get_platform", "hsdp", "hsdp_wait_grad_handle", "HSDPCell", "Layout", "DTensor", +__all__ = ["get_platform", "hsdp", "hsdp_wait_grad_handle", "HSDPCell", "Layout", "DTensor", "init_parameters", "shard", "custom_shard", "parallelize_value_and_grad", "SkipDTensorDispatch"] diff --git a/hyper_parallel/core/hsdp/hsdp_api.py b/hyper_parallel/core/hsdp/hsdp_api.py index 8789dcba539c7e2ba11499987a1c4cd21e0d8775..9480c2a3c8dd4980d598fcd8c7254aac2d1c4512 100644 --- a/hyper_parallel/core/hsdp/hsdp_api.py +++ b/hyper_parallel/core/hsdp/hsdp_api.py @@ -14,7 +14,10 @@ # ============================================================================ """hybrid shard data parallel interface""" from typing import Optional, Any -from hyper_parallel.core.hsdp.hsdp_utils import PlatformType, OptimizerLevel +from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel +from hyper_parallel.platform.platform import PlatformType +from hyper_parallel.platform import get_platform +platform = get_platform() origin_class_to_extend_class = {} optimizer_level_map = { @@ -22,7 +25,7 @@ optimizer_level_map = { 
"level2": OptimizerLevel.SHARD_OPT_GRAD, "level3": OptimizerLevel.SHARD_OPT_GRAD_PARAM, } -current_platform = None + class HSDPCell: """ @@ -31,23 +34,17 @@ class HSDPCell: Supported Platforms: ``MindSpore`` ``torch`` """ + # pylint: disable=C0415 def hsdp_init(self, platform_type, cell, shard_size, threshold, optimizer_level, enable_grad_accumulation, grad_scale, reduce_dtype, comm_async, comm_fusion, bucket_size): """init hsdp scheduler.""" scheduler_class = None - self.platform_type = platform_type - global current_platform - if self.platform_type == PlatformType.MINDSPORE: - from hyper_parallel.platform.mindspore.platform import MindSporePlatform + if platform_type == PlatformType.MINDSPORE: from hyper_parallel.platform.mindspore.hsdp.scheduler import MindSporeHSDPScheduler - current_platform = MindSporePlatform() scheduler_class = MindSporeHSDPScheduler else: - from hyper_parallel.platform.torch.platform import TorchPlatform from hyper_parallel.platform.torch.hsdp.scheduler import TorchHSDPScheduler - current_platform = TorchPlatform() scheduler_class = TorchHSDPScheduler - self.hsdp_scheduler = scheduler_class(cell, shard_size, @@ -117,9 +114,23 @@ def _extend_cell_with_hsdp_interface(cell): origin_class_to_extend_class[origin_class] = extend_class cell.__class__ = extend_class -def _check_hsdp_input_valid(platform_type, shard_size, threshold, optimizer_level, enable_grad_accumulation, grad_scale, - reduce_dtype, comm_async, comm_fusion, bucket_size): +# pylint: disable=C0415 +def _check_cell_valid(platform_type, cell): + """check cell valid""" + if platform_type == PlatformType.MINDSPORE: + from mindspore.nn.cell import Cell + if not isinstance(cell, Cell): + raise ValueError(f"cell's type must be nn.cell but got {type(cell)}.") + else: + from torch.nn import Module + if not isinstance(cell, Module): + raise ValueError(f"cell's type must be nn.Module but got {type(cell)}.") + +# pylint: disable=C0415 +def _check_hsdp_input_valid(platform_type, cell, 
shard_size, threshold, optimizer_level, enable_grad_accumulation, + grad_scale, reduce_dtype, comm_async, comm_fusion, bucket_size): """check hsdp input valid""" + _check_cell_valid(platform_type, cell) if not isinstance(shard_size, int) or (shard_size <= 0 and shard_size != -1): raise ValueError(f"shard_size must be a positive integer, but got {shard_size}.") if not isinstance(threshold, int) or threshold < 0: @@ -212,19 +223,11 @@ def hsdp( ValueError: If `comm_fusion` is not bool. ValueError: If the `bucket_size` is not a positive integer or -1. """ - platform_type = PlatformType.MINDSPORE - try: - from mindspore.nn.cell import Cell - if not isinstance(cell, Cell): - raise ValueError(f"cell's type must be nn.Module but got {type(cell)}.") - except ImportError: - from torch.nn import Module as Cell - platform_type = PlatformType.PYTORCH - if not isinstance(cell, Cell): - raise ValueError(f"cell's type must be nn.Module but got {type(cell)}.") + platform_type = platform.platform_type _check_hsdp_input_valid( platform_type, + cell, shard_size, threshold, optimizer_level, @@ -254,7 +257,6 @@ def hsdp( def hsdp_wait_grad_handle(): """wait for hsdp gradient handle to be completed""" - global current_platform - if current_platform is None: + if platform is None: return - current_platform.wait_grad_handle() + platform.wait_grad_handle() diff --git a/hyper_parallel/core/hsdp/hsdp_param.py b/hyper_parallel/core/hsdp/hsdp_param.py index 4bca12ae606a9aff99f25df03e6c23a580122e69..6b41934a95c1b4143bf36c284c7e037ffd6f4fb6 100644 --- a/hyper_parallel/core/hsdp/hsdp_param.py +++ b/hyper_parallel/core/hsdp/hsdp_param.py @@ -14,6 +14,7 @@ # ============================================================================ """HSDP parameter""" import functools +from hyper_parallel.core.dtensor import DTensor from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel, GroupInfo @@ -51,15 +52,128 @@ class HSDPParam: def _init_rank_info(self): """init parameter rank info""" - pass + 
self.rank_id = self.platform.get_rank() + self.hsdp_rank = self.rank_id + self.local_rank = self.rank_id + self.tp_rank = 0 + if not isinstance(self.param, DTensor) or self.param.layout is None: + self.rank_size = self.platform.get_world_size() + return + + if len(self.param.layout.rank_list) == 1: + self.rank_size = 1 + return + + try: + self.local_rank = self.param.layout.rank_list.index(self.rank_id) + except ValueError as e: + raise ValueError(f"HSDP invalid rank {self.rank_id} with rank list {self.param.layout.rank_list}.") from e + + tensor_map = self.param.layout.tensor_map + sharded_axis_set = set() + for axis in tensor_map: + if isinstance(axis, int) and axis != -1: + sharded_axis_set.add(axis) + continue + if isinstance(axis, tuple): + for item in axis: + sharded_axis_set.add(item) + self.sharded_axis_set = sharded_axis_set + self.rank_size = 1 + self.unsharded_reverse_axis_list = [] + self.global_rank_stride_list = [] + self.hsdp_rank_stride_list = [] + self.tp_rank_stride_list = [] + device_dims = len(self.param.layout.device_matrix) + stride = 1 + hsdp_stride = 1 + tp_stride = 1 + for axis in range(device_dims): + r_axis = device_dims - 1 - axis + self.global_rank_stride_list.append(stride) + self.hsdp_rank_stride_list.append(hsdp_stride) + self.tp_rank_stride_list.append(tp_stride) + stride = stride * self.param.layout.device_matrix[r_axis] + if axis in self.sharded_axis_set: + tp_stride = tp_stride * self.param.layout.device_matrix[r_axis] + continue + + hsdp_stride = hsdp_stride * self.param.layout.device_matrix[r_axis] + self.unsharded_reverse_axis_list.append(r_axis) + self.rank_size = self.rank_size * self.param.layout.device_matrix[r_axis] + self.global_rank_stride_list.reverse() + self.hsdp_rank_stride_list.reverse() + self.tp_rank_stride_list.reverse() + self.unsharded_reverse_axis_list.reverse() + + rank_indices = [] + index = self.local_rank + for stride in self.global_rank_stride_list: + rank_indices.append(index // stride) + index = index 
% stride + self.rank_indices = rank_indices + hsdp_rank = 0 + for axis in self.unsharded_reverse_axis_list: + hsdp_rank = hsdp_rank + rank_indices[axis] * self.hsdp_rank_stride_list[axis] + self.hsdp_rank = hsdp_rank + tp_rank = 0 + for axis in range(device_dims): + if axis in self.sharded_axis_set: + r_axis = device_dims - 1 - axis + tp_rank = tp_rank + rank_indices[r_axis] * self.tp_rank_stride_list[r_axis] + self.tp_rank = tp_rank + + def _hsdp_rank_to_global_rank(self, hsdp_rank_list): + """transform from hsdp rank to global rank""" + rank_list = [] + for hsdp_rank in hsdp_rank_list: + local_index = hsdp_rank + local_indices_dict = {} + for axis in self.unsharded_reverse_axis_list: + stride = self.hsdp_rank_stride_list[axis] + local_indices_dict[axis] = local_index // stride + local_index = local_index % stride + global_rank = 0 + for axis, index in enumerate(self.rank_indices): + index = local_indices_dict.get(axis, index) + global_rank = global_rank + index * self.global_rank_stride_list[axis] + if self.param.layout is not None: + if global_rank >= len(self.param.layout.rank_list): + raise ValueError(f"HSDP invalid index {global_rank} with " + f"rank list len {len(self.param.layout.rank_list)}.") + global_rank = self.param.layout.rank_list[global_rank] + rank_list.append(global_rank) + return rank_list + + def _get_op_rank_list(self): + """get data parallel rank list""" + if isinstance(self.param, DTensor): + rank_base = self.hsdp_rank // self.shard_size * self.shard_size + hsdp_rank_list = [i + rank_base for i in range(self.shard_size)] + return self._hsdp_rank_to_global_rank(hsdp_rank_list) + rank_base = self.local_rank // self.shard_size * self.shard_size + rank_list = [i + rank_base for i in range(self.shard_size)] + return rank_list + + def _get_dp_rank_list(self): + """get optimizer parallel rank list""" + if isinstance(self.param, DTensor): + rank_stride = self.shard_size + rank_base = self.hsdp_rank % rank_stride + hsdp_rank_list = [i * rank_stride + 
rank_base for i in range(self.dp_size)] + return self._hsdp_rank_to_global_rank(hsdp_rank_list) + rank_stride = self.shard_size + rank_base = self.local_rank % rank_stride + rank_list = [i * rank_stride + rank_base for i in range(self.dp_size)] + return rank_list def _init_sharded_param(self): """add and init sharded param""" - pass + raise NotImplementedError("HSDP param subclasses must implement _init_sharded_param") def _init_unsharded_param(self): """add and init unshared param""" - pass + raise NotImplementedError("HSDP param subclasses must implement _init_unsharded_param") def _get_unsharded_param_data(self, async_op): """get unsharded param data with async comm""" @@ -103,19 +217,6 @@ class HSDPParam: if rank_gcd % self.shard_size != 0: self.shard_size = 1 self.param.hsdp_effective_shard_size = self.shard_size - - def _get_op_rank_list(self): - """get data parallel rank list""" - rank_base = self.local_rank // self.shard_size * self.shard_size - rank_list = [i + rank_base for i in range(self.shard_size)] - return rank_list - - def _get_dp_rank_list(self): - """get optimizer parallel rank list""" - rank_stride = self.shard_size - rank_base = self.local_rank % rank_stride - rank_list = [i * rank_stride + rank_base for i in range(self.dp_size)] - return rank_list def _create_sharded_dp_group(self): """create communication group for sharded parameter""" diff --git a/hyper_parallel/core/hsdp/hsdp_utils.py b/hyper_parallel/core/hsdp/hsdp_utils.py index 62c44b2abde2590bd7411870abd60e9d0e1e8d48..cf5c4f4fb36af11873750952defc8c7f8714424b 100644 --- a/hyper_parallel/core/hsdp/hsdp_utils.py +++ b/hyper_parallel/core/hsdp/hsdp_utils.py @@ -15,12 +15,6 @@ """HSDP optimizer shared level""" from enum import auto, Enum -class PlatformType(Enum): - """ - PlatformType - """ - MINDSPORE = auto() - PYTORCH = auto() class OptimizerLevel(Enum): """ @@ -41,9 +35,9 @@ class GroupInfo: GroupInfo """ def __init__(self, group_name, group, rank_size): - self.group_name = group_name - 
self.group = group - self.rank_size = rank_size + self.group_name = group_name + self.group = group + self.rank_size = rank_size class HSDPConfig: """HSDP config""" @@ -71,7 +65,7 @@ class HSDPConfig: self.shard_level = shard_level self.use_eager_hook = use_eager_hook self.reduce_dtype = reduce_dtype - self.comm_async = comm_fusion + self.comm_async = comm_async self.comm_fusion = comm_fusion self.bucket_size = bucket_size self.grad_fusion = comm_fusion and bucket_size != 0 diff --git a/hyper_parallel/core/tensor_parallel/_op_dispatch.py b/hyper_parallel/core/tensor_parallel/_op_dispatch.py index 12573d71dda1321bb7d7c098394bdd4e826f61ec..094ec23ac79a524dc0467d8d71f73442927b92a1 100644 --- a/hyper_parallel/core/tensor_parallel/_op_dispatch.py +++ b/hyper_parallel/core/tensor_parallel/_op_dispatch.py @@ -36,7 +36,6 @@ def disable_dtensor_dispatch(): _dtensor_dispatch = False def get_dtensor_dispatch(): - global _dtensor_dispatch return _dtensor_dispatch @@ -97,7 +96,8 @@ class OpDispatcher: self.layout_infer_ops = self.safe_load_yaml_from_dir() self.whitelist = ["InplaceAddExt", "InplaceSubExt", "InplaceMul", "InplaceDiv", "typeof", "DistCommIsend", "DistCommIrecv", "DistCommBroadcast", "DistCommAllReduce", "DistCommAllGather", - "DistCommReduceScatter", "requires_grad_", "item", "__get__", "register_hook"] + "DistCommReduceScatter", "requires_grad_", "item", "__get__", "__set__", "register_hook", + "is_complex", "chunk"] for op_name, config in self.layout_infer_ops.items(): class_name = config['distributed_op_class'] module_name = "hyper_parallel.core.tensor_parallel.ops." 
+ config['distributed_op_file'] @@ -139,7 +139,6 @@ class OpDispatcher: cache_manager = LayoutCacheManager.get_instance() layout_cache = cache_manager.get_layout_cache() - global platform func_name = platform.get_op_name(func) if func_name not in layout_cache: layout_cache[func_name] = {} @@ -208,7 +207,6 @@ class OpDispatcher: cache_manager = LayoutCacheManager.get_instance() layout_cache = cache_manager.get_layout_cache() - global platform func_name = platform.get_op_name(func) if func_name not in layout_cache: layout_cache[func_name] = {} @@ -241,7 +239,7 @@ class OpDispatcher: return DTensor.from_local(py_output, output_layout) - def _with_layout_infer_reshape(self, func: callable, *args, **kwargs) -> Tensor: + def _with_layout_infer_reshape(self, func: callable, *args) -> Tensor: """_with_layout_infer_reshape""" input_tensor = args[0] shape = args[1] @@ -263,7 +261,6 @@ class OpDispatcher: cache_manager = LayoutCacheManager.get_instance() layout_cache = cache_manager.get_layout_cache() - global platform func_name = platform.get_op_name(func) if func_name not in layout_cache: layout_cache[func_name] = {} @@ -325,7 +322,6 @@ class OpDispatcher: cache_manager = LayoutCacheManager.get_instance() layout_cache = cache_manager.get_layout_cache() - global platform func_name = platform.get_op_name(func) if func_name not in layout_cache: layout_cache[func_name] = {} @@ -359,7 +355,7 @@ class OpDispatcher: return DTensor.from_local(py_output, output_layout) - def _with_layout_infer_slice(self, func: callable, *args, **kwargs) -> Tensor: + def _with_layout_infer_slice(self, func: callable, *args) -> Tensor: """_with_layout_infer_slice""" input_tensor = args[0] begin = args[1] @@ -385,7 +381,6 @@ class OpDispatcher: cache_manager = LayoutCacheManager.get_instance() layout_cache = cache_manager.get_layout_cache() - global platform func_name = platform.get_op_name(func) if func_name not in layout_cache: layout_cache[func_name] = {} @@ -434,7 +429,6 @@ class OpDispatcher: 
:param kwargs: :return: """ - global platform op_name = platform.get_op_name(op_call) if op_name in self.whitelist or get_dtensor_dispatch() is False: input_args = [arg.to_local() if isinstance(arg, DTensor) else arg for arg in args] @@ -450,11 +444,11 @@ class OpDispatcher: if suffix == "WithShape": return self._with_layout_infer_with_shape(op_call, *args, **kwargs) if suffix == "Reshape": - return self._with_layout_infer_reshape(op_call, *args, **kwargs) + return self._with_layout_infer_reshape(op_call, *args) if suffix == "WithTupleExpand": return self._with_layout_infer_with_tuple_expand(op_call, *args, **kwargs) if suffix == "Slice": - return self._with_layout_infer_slice(op_call, *args, **kwargs) + return self._with_layout_infer_slice(op_call, *args) raise RuntimeError(f"Operator {op_name} specified wrong suffix in parallel yaml.") -_OP_DISPATCHER = OpDispatcher() \ No newline at end of file +_OP_DISPATCHER = OpDispatcher() diff --git a/hyper_parallel/platform/__init__.py b/hyper_parallel/platform/__init__.py index 89da8d340a695fb3510dc445aa00c7700bf25ed1..d65bdd9e5b440ba91e2e045951465e081d8d75f6 100644 --- a/hyper_parallel/platform/__init__.py +++ b/hyper_parallel/platform/__init__.py @@ -12,4 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ -from hyper_parallel.platform.platform import get_platform, set_platform +"platform api" +from hyper_parallel.platform.platform import get_platform diff --git a/hyper_parallel/platform/mindspore/hsdp/param.py b/hyper_parallel/platform/mindspore/hsdp/param.py index 416b6c0722249dc7920904d1ed759f3d9825a7a1..b700c0c3f7159e569301716ddcbafdd6f774e401 100644 --- a/hyper_parallel/platform/mindspore/hsdp/param.py +++ b/hyper_parallel/platform/mindspore/hsdp/param.py @@ -16,9 +16,7 @@ from mindspore import ops from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor -from mindspore.communication import get_rank, get_group_size from mindspore.common.initializer import initializer -from hyper_parallel.core.dtensor import DTensor from hyper_parallel.core.hsdp.hsdp_param import HSDPParam @@ -26,118 +24,6 @@ class MindSporeHSDPParam(HSDPParam): """ MindSpore HSDP parameter. """ - def _init_rank_info(self): - """init parameter rank info""" - self.rank_id = get_rank() - self.hsdp_rank = self.rank_id - self.local_rank = self.rank_id - self.tp_rank = 0 - if not isinstance(self.param, DTensor) or self.param.layout is None: - self.rank_size = get_group_size() - return - - if len(self.param.layout.rank_list) == 1: - self.rank_size = 1 - return - - try: - self.local_rank = self.param.layout.rank_list.index(self.rank_id) - except ValueError as e: - raise ValueError(f"HSDP invalid rank {self.rank_id} with rank list {self.param.layout.rank_list}.") from e - - tensor_map = self.param.layout.tensor_map - sharded_axis_set = set() - for axis in tensor_map: - if isinstance(axis, int) and axis != -1: - sharded_axis_set.add(axis) - continue - if isinstance(axis, tuple): - for item in axis: - sharded_axis_set.add(item) - self.sharded_axis_set = sharded_axis_set - self.rank_size = 1 - self.unsharded_reverse_axis_list = [] - self.global_rank_stride_list = [] - self.hsdp_rank_stride_list 
= [] - self.tp_rank_stride_list = [] - device_dims = len(self.param.layout.device_matrix) - stride = 1 - hsdp_stride = 1 - tp_stride = 1 - for axis in range(device_dims): - r_axis = device_dims - 1 - axis - self.global_rank_stride_list.append(stride) - self.hsdp_rank_stride_list.append(hsdp_stride) - self.tp_rank_stride_list.append(tp_stride) - stride = stride * self.param.layout.device_matrix[r_axis] - if axis in self.sharded_axis_set: - tp_stride = tp_stride * self.param.layout.device_matrix[r_axis] - continue - - hsdp_stride = hsdp_stride * self.param.layout.device_matrix[r_axis] - self.unsharded_reverse_axis_list.append(r_axis) - self.rank_size = self.rank_size * self.param.layout.device_matrix[r_axis] - self.global_rank_stride_list.reverse() - self.hsdp_rank_stride_list.reverse() - self.tp_rank_stride_list.reverse() - self.unsharded_reverse_axis_list.reverse() - - rank_indices = [] - index = self.local_rank - for stride in self.global_rank_stride_list: - rank_indices.append(index // stride) - index = index % stride - self.rank_indices = rank_indices - hsdp_rank = 0 - for axis in self.unsharded_reverse_axis_list: - hsdp_rank = hsdp_rank + rank_indices[axis] * self.hsdp_rank_stride_list[axis] - self.hsdp_rank = hsdp_rank - tp_rank = 0 - for axis in range(device_dims): - if axis in self.sharded_axis_set: - r_axis = device_dims - 1 - axis - tp_rank = tp_rank + rank_indices[r_axis] * self.tp_rank_stride_list[r_axis] - self.tp_rank = tp_rank - - def _hsdp_rank_to_global_rank(self, hsdp_rank_list): - """transform from hsdp rank to global rank""" - rank_list = [] - for hsdp_rank in hsdp_rank_list: - local_index = hsdp_rank - local_indices_dict = {} - for axis in self.unsharded_reverse_axis_list: - stride = self.hsdp_rank_stride_list[axis] - local_indices_dict[axis] = local_index // stride - local_index = local_index % stride - global_rank = 0 - for axis, index in enumerate(self.rank_indices): - index = local_indices_dict.get(axis, index) - global_rank = global_rank + 
index * self.global_rank_stride_list[axis] - if self.param.layout is not None: - if global_rank >= len(self.param.layout.rank_list): - raise ValueError(f"HSDP invalid index {global_rank} with" - f"rank list len {len(self.param.layout.rank_list)}.") - global_rank = self.param.layout.rank_list[global_rank] - rank_list.append(global_rank) - return rank_list - - def _get_op_rank_list(self): - """get data parallel rank list""" - if isinstance(self.param, DTensor): - rank_base = self.hsdp_rank // self.shard_size * self.shard_size - hsdp_rank_list = [i + rank_base for i in range(self.shard_size)] - return self._hsdp_rank_to_global_rank(hsdp_rank_list) - return super()._get_op_rank_list() - - def _get_dp_rank_list(self): - """get optimizer parallel rank list""" - if isinstance(self.param, DTensor): - rank_stride = self.shard_size - rank_base = self.hsdp_rank % rank_stride - hsdp_rank_list = [i * rank_stride + rank_base for i in range(self.dp_size)] - return self._hsdp_rank_to_global_rank(hsdp_rank_list) - return super()._get_dp_rank_list() - def _init_sharded_param(self): """add and init sharded param""" if not self.param.has_init: diff --git a/hyper_parallel/platform/mindspore/hsdp/scheduler.py b/hyper_parallel/platform/mindspore/hsdp/scheduler.py index 372eb194859ea75144b2deacb5984002a4af80d9..449ed6b5b635df7760a2116d603d718bca02c585 100644 --- a/hyper_parallel/platform/mindspore/hsdp/scheduler.py +++ b/hyper_parallel/platform/mindspore/hsdp/scheduler.py @@ -18,7 +18,7 @@ from mindspore import ops from mindspore.common.tensor import Tensor from hyper_parallel.core.hsdp.hsdp_utils import OptimizerLevel from hyper_parallel.core.hsdp.hsdp_scheduler import HSDPScheduler -from hyper_parallel.platform.mindspore.platform import MindSporePlatform +from hyper_parallel.platform import get_platform from hyper_parallel.platform.mindspore.platform_graph import MindSporeGraphPlatform from hyper_parallel.platform.mindspore.hsdp.state import MindSporeHSDPState from 
hyper_parallel.platform.mindspore.hsdp.grad_hook import MindSporeHSDPGradHook @@ -31,7 +31,7 @@ class MindSporeHSDPScheduler(HSDPScheduler): def _init_platform(self): """init platform""" if self.config.use_eager_hook: - self.platform = MindSporePlatform() + self.platform = get_platform() else: self.platform = MindSporeGraphPlatform() @@ -57,14 +57,14 @@ class MindSporeHSDPScheduler(HSDPScheduler): """get param forward hook.""" def stateless_param_forward_hook(origin_param): - output, _ = self.platform.all_gather_into_tensor(origin_param, group=hsdp_param.sharded_group_info) + output, _ = self.platform.all_gather_into_tensor(origin_param, hsdp_param.sharded_group_info) return output def stateful_param_forward_hook(origin_param): if hsdp_param.unsharded_param_available: return hsdp_param.unsharded_param - unshared_data, _ = self.platform.all_gather_into_tensor(origin_param, group=hsdp_param.sharded_group_info) + unshared_data, _ = self.platform.all_gather_into_tensor(origin_param, hsdp_param.sharded_group_info) ops.assign(hsdp_param.unsharded_param, unshared_data) ops.assign(hsdp_param.unsharded_param_available, Tensor(True)) return hsdp_param.unsharded_param diff --git a/hyper_parallel/platform/mindspore/platform.py b/hyper_parallel/platform/mindspore/platform.py index 910311ced8f328611fd7a60f83d0cc9a19b56df5..7f6783045fce6a0892a1178855a93402ebf4a042 100644 --- a/hyper_parallel/platform/mindspore/platform.py +++ b/hyper_parallel/platform/mindspore/platform.py @@ -22,18 +22,20 @@ from mindspore.common.initializer import initializer from mindspore.communication import get_group_size from mindspore.communication import create_group as new_group from mindspore.communication import get_rank as get_rank_id -import mindspore.communication.comm_func as comm_func -from hyper_parallel.platform.platform import Platform +from mindspore.communication import comm_func +from hyper_parallel.platform.platform import Platform, PlatformType from 
hyper_parallel.platform.mindspore.dtensor import DTensorBase from hyper_parallel.platform.mindspore.parameter_init import init_parameters as _init_parameters +# pylint: disable=C0103 class MindSporePlatform(Platform): """MindSpore platform api""" Tensor = Tensor Parameter = Parameter Module = Cell DTensorBase = DTensorBase + platform_type = PlatformType.MINDSPORE @staticmethod def get_rank(): diff --git a/hyper_parallel/platform/platform.py b/hyper_parallel/platform/platform.py index f30c01094787766759fa754e7ae35ae2d678f6b7..d134e124e57209286ded9d655d3a218f0048279b 100644 --- a/hyper_parallel/platform/platform.py +++ b/hyper_parallel/platform/platform.py @@ -13,18 +13,23 @@ # limitations under the License. # ============================================================================ """framework platform api""" +from enum import auto, Enum + + +class PlatformType(Enum): + """ + PlatformType + """ + MINDSPORE = auto() + PYTORCH = auto() + + platform = None -def set_platform(platform_type): - global platform - if "torch" in platform_type: - from hyper_parallel.platform.torch.platform import TorchPlatform - platform = TorchPlatform() - else: - from hyper_parallel.platform.mindspore.platform import MindSporePlatform - platform = MindSporePlatform() +# pylint: disable=C0415 def get_platform(): + """get framework platform""" global platform if platform is not None: return platform @@ -37,7 +42,7 @@ def get_platform(): platform = TorchPlatform() return platform -EXISTING_COMM_GROUPS = dict() +EXISTING_COMM_GROUPS = {} class Platform: @@ -48,39 +53,43 @@ class Platform: @staticmethod def get_rank(): - pass + raise NotImplementedError("Platform subclasses must implement get_rank") @staticmethod def get_world_size(): - pass + raise NotImplementedError("Platform subclasses must implement get_world_size") @staticmethod def get_op_name(func): - pass + raise NotImplementedError("Platform subclasses must implement get_op_name") @staticmethod def 
differentiable_all_gather_concat(data, group, concat_size, concat_dim): - pass + raise NotImplementedError("Platform subclasses must implement differentiable_all_gather_concat") @staticmethod def chunk(data, split_dim, split_size, index): - pass + raise NotImplementedError("Platform subclasses must implement chunk") @staticmethod def differentiable_all_to_all(input_data, output_shape, group): - pass + raise NotImplementedError("Platform subclasses must implement differentiable_all_to_all") @staticmethod - def differentiable_all_reduce(data, op, group, async_op=False): - pass + def differentiable_all_reduce(data, op, group): + raise NotImplementedError("Platform subclasses must implement differentiable_all_reduce") @staticmethod def differentiable_reduce_scatter(data, dev_num, axis, op, group): - pass + raise NotImplementedError("Platform subclasses must implement differentiable_reduce_scatter") @staticmethod def init_parameters(module, stage_index): - pass + """platform ms need init parameter interface""" + if module is None: + raise ValueError("input module must not be None.") + if stage_index < 0: + raise ValueError("input stage_index must be non-negative.") @staticmethod def register_forward_pre_hook(cell, hook): @@ -100,13 +109,11 @@ class Platform: @staticmethod def get_param_local_shape(param): - """get param local shape""" - return param.shape + raise NotImplementedError("Platform subclasses must implement get_param_local_shape") @staticmethod def get_param_local_data(param): - """get param local shape""" - return param + raise NotImplementedError("Platform subclasses must implement get_param_local_data") @staticmethod def update_param_data(param, data): @@ -115,38 +122,39 @@ class Platform: @staticmethod def get_param_type_size(param): - pass + raise NotImplementedError("Platform subclasses must implement get_param_type_size") @staticmethod def new_zero_parameter(param_shape, param_type, requires_grad): - pass + raise NotImplementedError("Platform subclasses 
must implement new_zero_parameter") @staticmethod def new_tensor(tensor_shape, tensor_type): - pass + raise NotImplementedError("Platform subclasses must implement new_tensor") @staticmethod def all_gather_into_tensor(data, group_info, async_op=False): - pass + raise NotImplementedError("Platform subclasses must implement all_gather_into_tensor") @staticmethod def all_reduce(data, group_info, async_op=False): - pass + raise NotImplementedError("Platform subclasses must implement all_reduce") @staticmethod def reduce_scatter_tensor(data, group_info, async_op=False): - pass + raise NotImplementedError("Platform subclasses must implement reduce_scatter_tensor") def _create_group(self, rank_list, group_name=None): - pass + raise NotImplementedError("Platform subclasses must implement _create_group") def new_stream(self): - pass + raise NotImplementedError("Platform subclasses must implement new_stream") def get_stream_context(self): - pass + raise NotImplementedError("Platform subclasses must implement get_stream_context") def create_group(self, rank_list, group_name=None): + """create comm group with rank list""" if group_name is None: group_key = hash(tuple(rank_list)) else: @@ -166,6 +174,7 @@ class Platform: Platform.current_grad_handle.wait() if Platform.post_grad_handle_process is None: return + # pylint: disable=E1102 Platform.post_grad_handle_process() def set_grad_reduce_handle(self, handle, post_process=None): @@ -179,6 +188,7 @@ class Platform: Platform.post_grad_handle_process = post_process def wait_grad_handle(self): + """wait grad handle""" if Platform.current_grad_handle is None: return if Platform.grad_sync_stream is None: diff --git a/hyper_parallel/platform/torch/function_override.py b/hyper_parallel/platform/torch/function_override.py new file mode 100644 index 0000000000000000000000000000000000000000..ee464f8ba251ba0b148e04034505636ce725ffe8 --- /dev/null +++ b/hyper_parallel/platform/torch/function_override.py @@ -0,0 +1,60 @@ +# Copyright 2025 
Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Torch function override""" +from torch.nn.modules import _functions +from torch.nn.modules._functions import BackwardHookFunction + + +class DTensorBackwardHookFunction(BackwardHookFunction): + """override BackwardHookFunction for dtensor""" + + @classmethod + def apply(cls, *args, **kwargs): + """override apply function""" + # pylint: disable=C0415 + from hyper_parallel import DTensor + + input_args = [] + input_layouts = [] + + for arg in args: + if arg is None: + input_layouts.append(None) + input_args.append(arg) + continue + + if not hasattr(arg, "_layout"): + input_layouts.append(None) + input_args.append(arg) + else: + layout = arg.layout + input_layouts.append(layout) + input_args.append(arg.to_local()) + + origin_output = BackwardHookFunction.apply(*input_args, **kwargs) + + if len(origin_output) != len(input_args): + raise RuntimeError("number of output should equal to number of input") + + if isinstance(origin_output, (tuple, list)): + output = () + for i, output_item in enumerate(origin_output): + output += (DTensor.from_local(output_item, input_layouts[i]),) + return output + return origin_output + + +def override_functions(): + _functions.BackwardHookFunction = DTensorBackwardHookFunction diff --git a/hyper_parallel/platform/torch/hsdp/param.py b/hyper_parallel/platform/torch/hsdp/param.py index 
00344ea29c2b399248cdf328f08924bb389df9c2..4a9030a50b779897434b4046436089c4b5b08383 100644 --- a/hyper_parallel/platform/torch/hsdp/param.py +++ b/hyper_parallel/platform/torch/hsdp/param.py @@ -22,31 +22,21 @@ class TorchHSDPParam(HSDPParam): """ Torch HSDP parameter. """ - def _init_rank_info(self): - """init parameter rank info""" - self.rank_id = dist.get_rank() - self.rank_size = dist.get_world_size() - self.hsdp_rank = self.rank_id - self.local_rank = self.rank_id - self.tp_rank = 0 - def _init_sharded_param(self): """add and init sharded param""" slice_index = self.hsdp_rank % self.shard_size - param_slice = torch.trunk(self.param, self.shard_size, 0)[slice_index] - self.param.data = param_slice + local_param = self.platform.get_param_local_data(self.param) + param_slice = torch.chunk(local_param, self.shard_size, 0)[slice_index] + self.platform.update_param_data(self.param, param_slice) self.sharded_param = param_slice def _init_unsharded_param(self): """add and init unshared param""" - self.unsharded_param = torch.empty(self.param_shape, dtype=self.param.dtype) - - def _update_param_data(self, param, data): - """update param data""" - param.data = data + self.unsharded_param = torch.empty(self.param_shape, dtype=self.param.dtype, device=self.param.device) - def _get_unsharded_param_data(self, comm_async): + def _get_unsharded_param_data(self, async_op): """get unsharded param data with async comm""" - handle = dist.all_gather_into_tensor(self.unsharded_param, self.param, group=self.sharded_group, - async_op=comm_async) + local_param = self.platform.get_param_local_data(self.param) + handle = dist.all_gather_into_tensor(self.unsharded_param, local_param, group=self.sharded_group_info.group, + async_op=async_op) return self.unsharded_param, handle diff --git a/hyper_parallel/platform/torch/hsdp/scheduler.py b/hyper_parallel/platform/torch/hsdp/scheduler.py index b22a99e06de1232a525ccffdb01351206e57ccb1..eccb88ee9d3bb53d54771d80d0297cac8a78f33c 100644 --- 
a/hyper_parallel/platform/torch/hsdp/scheduler.py +++ b/hyper_parallel/platform/torch/hsdp/scheduler.py @@ -14,8 +14,8 @@ # ============================================================================ """HSDP scheduler""" from hyper_parallel.core.hsdp.hsdp_scheduler import HSDPScheduler -from hyper_parallel.platform.torch.platform import TorchPlatform from hyper_parallel.platform.torch.hsdp.state import TorchHSDPState +from hyper_parallel.platform import get_platform class TorchHSDPScheduler(HSDPScheduler): @@ -23,7 +23,7 @@ class TorchHSDPScheduler(HSDPScheduler): def _init_platform(self): """init platform""" - self.platform = TorchPlatform() + self.platform = get_platform() def _new_cell_state(self): """new cell state""" diff --git a/hyper_parallel/platform/torch/platform.py b/hyper_parallel/platform/torch/platform.py index 767c6c4097304c07d736911b7c831bbe709b2b7f..b0147a7d4ac64f9550d5a13c28f06cc648db015b 100644 --- a/hyper_parallel/platform/torch/platform.py +++ b/hyper_parallel/platform/torch/platform.py @@ -20,17 +20,21 @@ from torch.nn import Parameter, Module from torch._ops import OpOverload, OpOverloadPacket import torch.distributed.nn.functional as dist_func import torch.distributed as dist -from hyper_parallel.platform.platform import Platform from hyper_parallel.platform.torch.dtensor import DTensorBase from hyper_parallel.platform.torch.group_utils import create_sub_groups +from hyper_parallel.platform.platform import Platform, PlatformType +from hyper_parallel.platform.torch.function_override import override_functions +override_functions() +# pylint: disable=C0103 class TorchPlatform(Platform): """Torch platform api""" Tensor = Tensor Parameter = Parameter Module = Module DTensorBase = DTensorBase + platform_type = PlatformType.PYTORCH @staticmethod def get_rank(): @@ -40,23 +44,36 @@ class TorchPlatform(Platform): def get_world_size(): return dist.get_world_size() + @staticmethod + def get_param_local_shape(param): + """get param local shape""" + if 
isinstance(param, DTensorBase): + return param.local_shape + return param.shape + + @staticmethod + def get_param_local_data(param): + """get param local shape""" + if isinstance(param, DTensorBase): + return param.to_local() + return param + @staticmethod def get_op_name(func): if hasattr(func, "__name__"): return func.__name__ - elif isinstance(func, OpOverload): + if isinstance(func, OpOverload): full_name = func.name core_name = full_name.split("::")[-1].split(".")[0] return core_name - elif isinstance(func, OpOverloadPacket): + if isinstance(func, OpOverloadPacket): return func.name.split("::")[-1] - else: - func_str = str(func) - if "built-in function" in func_str: - return func_str.split()[-1].strip(">") - elif "function" in func_str: - return func_str.split()[1] - return "unknown_op" + func_str = str(func) + if "built-in function" in func_str: + return func_str.split()[-1].strip(">") + if "function" in func_str: + return func_str.split()[1] + return "unknown_op" @staticmethod def differentiable_all_gather_concat(data, group, concat_size, concat_dim): @@ -104,6 +121,7 @@ class TorchPlatform(Platform): @staticmethod def get_param_type_size(param): + # pylint: disable=W0212 return torch._utils._element_size(param.dtype) @staticmethod @@ -114,7 +132,7 @@ class TorchPlatform(Platform): def new_tensor(tensor_shape, tensor_type): return torch.empty(tensor_shape, tensor_type) - def _create_group(self, rank_list, group_name): + def _create_group(self, rank_list, group_name=None): group_dict = create_sub_groups(rank_list) return group_dict[tuple(rank_list)] @@ -122,20 +140,20 @@ class TorchPlatform(Platform): def all_gather_into_tensor(data, group_info, async_op=False): output_shape = list(data.shape) output_shape[0] = output_shape[0] * group_info.rank_size - output = torch.empty(output_shape, dtype=data.dtype) + output = torch.empty(output_shape, dtype=data.dtype, device=data.device) handle = dist.all_gather_into_tensor(output, data, group=group_info.group, 
async_op=async_op) return output, handle @staticmethod def all_reduce(data, group_info, async_op=False): - handle = dist.all_reduce(data.data, group=group_info.group, async_op=async_op) + handle = dist.all_reduce(data, group=group_info.group, async_op=async_op) return data, handle @staticmethod def reduce_scatter_tensor(data, group_info, async_op=False): output_shape = list(data.shape) output_shape[0] = output_shape[0] // group_info.rank_size - output = torch.empty(output_shape, dtype=data.dtype) + output = torch.empty(output_shape, dtype=data.dtype, device=data.device) handle = dist.reduce_scatter_tensor(output, data, group=group_info.group, async_op=async_op) return output, handle diff --git a/tests/torch/__init__.py b/tests/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..83b15297dbb5877a8f4175bfffa35a331ca2a156 --- /dev/null +++ b/tests/torch/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ diff --git a/tests/torch/common_net.py b/tests/torch/common_net.py new file mode 100644 index 0000000000000000000000000000000000000000..e0120ef08269f406cf73875a400166f398cfbd25 --- /dev/null +++ b/tests/torch/common_net.py @@ -0,0 +1,30 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""common net""" +import torch +from torch import nn + + +class SimpleModel(nn.Module): + """simple model""" + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.ones(8, 8).npu()) + + def forward(self, x): + x = torch.matmul(x, self.weight) + x = torch.relu(x) + x = torch.sum(x) + return x diff --git a/tests/torch/hsdp/dp.py b/tests/torch/hsdp/dp.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea653603d86cb22d965a80c1491dbf85aaed312 --- /dev/null +++ b/tests/torch/hsdp/dp.py @@ -0,0 +1,82 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test data parallel""" +import numpy as np +import torch +from torch import nn +from torch import optim +# pylint: disable=W0611 +import torch_npu +from hyper_parallel import hsdp, DTensor, Layout, SkipDTensorDispatch +from tests.torch.common_net import SimpleModel +from tests.torch.utils import init_dist + + +def test_data_parallel(): + """test data parallel""" + rank, _ = init_dist() + step = 2 + + # -----------------------------------standalone---------------------------------- + standalone_model = SimpleModel().npu() + standalone_optimizer = optim.SGD(standalone_model.parameters(), lr=0.01) + standalone_x = torch.ones(8, 8).npu() + for _ in range (step): + standalone_loss = standalone_model(standalone_x) + standalone_loss.backward() + standalone_grad = standalone_model.weight.grad.data # using for validation + standalone_optimizer.step() + standalone_optimizer.zero_grad() + + # --------------------------------------dist------------------------------------- + dist_model = SimpleModel().npu() + + layout = Layout((2, 4), ("dp", "tp")) + x_layout = layout("None", "None") + w_layout = layout("None", "tp") + dist_x = DTensor.from_local(torch.ones(8, 8).npu(), x_layout) + local_w = torch.ones(8, 2).npu() + # pylint: disable=W0212 + for key, param in dist_model._parameters.items(): + if param is not None and not isinstance(param, DTensor): + dist_model.register_parameter( + key, + nn.Parameter(DTensor.from_local(local_w, w_layout)), + ) + dist_model = hsdp(dist_model, 
shard_size=1) + dist_optimizer = optim.SGD(dist_model.parameters(), lr=0.01) + + for _ in range (step): + dist_loss = dist_model(dist_x) + dist_loss = dist_loss.reduce_partial() # handle partial state + + # handle backward input + repeat_num = dist_loss.layout.repeat_num() + backward_input = torch.tensor(1.0 / repeat_num) + dist_loss.backward(backward_input) + + dist_grad = dist_model.weight.grad.data + + with SkipDTensorDispatch(): + dist_optimizer.step() + dist_optimizer.zero_grad() + + assert np.allclose(standalone_loss.cpu().detach().numpy(), + dist_loss.to_local().cpu().detach().numpy(), # use to_local() + 0.001, 0.001) + offset = rank % 4 * 2 + assert np.allclose(standalone_grad.cpu().detach().numpy()[:, offset: offset + 2], + dist_grad.cpu().detach().numpy(), + 0.001, 0.001) diff --git a/tests/torch/hsdp/test_dp.py b/tests/torch/hsdp/test_dp.py new file mode 100644 index 0000000000000000000000000000000000000000..1debf2b3af1d4ba9732f0ba95fde832fd56bc158 --- /dev/null +++ b/tests/torch/hsdp/test_dp.py @@ -0,0 +1,29 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""test data parallel""" +from tests.torch.utils import torchrun_case +from tests.common.mark_utils import arg_mark + + +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") +def test_data_parallel(): + ''' + Feature: data parallel + Description: + Expectation: Run success. + ''' + master_port = 12345 + case_name = "dp.py::test_data_parallel" + torchrun_case(master_port, case_name) diff --git a/tests/torch/tensor_parallel/base_dtensor.py b/tests/torch/tensor_parallel/base_dtensor.py index 1844cec2f5bd220c388a78167423f38fbbca548a..8c0bc56a3cf3b4a2a5763f059bd30881408601d9 100644 --- a/tests/torch/tensor_parallel/base_dtensor.py +++ b/tests/torch/tensor_parallel/base_dtensor.py @@ -12,24 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -import os +"""test torch dtensor""" +import numpy as np import torch +from torch import nn +from torch import optim +# pylint: disable=W0611 import torch_npu # 昇腾NPU核心适配 -import torch.distributed as dist -import torch.nn as nn -import torch.optim as optim -import numpy as np from hyper_parallel import DTensor, Layout, SkipDTensorDispatch -"""test torch dtensor""" - -def init_dist(): - """init dist""" - dist.init_process_group() - rank = dist.get_rank() - torch.npu.set_device(rank) - device_id = rank % 8 # 8卡对应device_id 0-7 - torch.npu.set_device(device_id) - return rank, device_id +from tests.torch.utils import init_dist class SimpleModel(nn.Module): @@ -56,17 +47,20 @@ class SimpleModel(nn.Module): return x -def base_dtensor(): - rank, device_id = init_dist() - is_master = (rank == 0) +def test_base_dtensor(): + ''' + Feature: dtensor infer layout and redistribute. + Description: + Expectation: Run success. 
+ ''' + init_dist() step = 2 - world_size = dist.get_world_size() # -----------------------------------standalone---------------------------------- standalone_model = SimpleModel().npu() standalone_optimizer = optim.SGD(standalone_model.parameters(), lr=0.01) standalone_x = torch.ones(8, 8).npu() - for i in range (step): + for _ in range (step): standalone_loss = standalone_model(standalone_x) standalone_loss.backward() standalone_grad = standalone_model.weight.grad.data # using for validation @@ -81,16 +75,17 @@ def base_dtensor(): w_layout = layout("None", "tp") dist_x = DTensor.from_local(torch.ones(8, 8).npu(), x_layout) local_w = torch.ones(8, 1).npu() + # pylint: disable=W0212 for key, param in dist_model._parameters.items(): if param is not None and not isinstance(param, DTensor): dist_model.register_parameter( key, nn.Parameter(DTensor.from_local(local_w, w_layout)), ) - + dist_optimizer = optim.SGD(dist_model.parameters(), lr=0.01) - for i in range (step): + for _ in range (step): dist_loss = dist_model(dist_x) dist_loss = dist_loss.reduce_partial() # handle partial state @@ -112,16 +107,3 @@ def base_dtensor(): assert np.allclose(standalone_grad.cpu().detach().numpy(), dist_grad.cpu().detach().numpy(), 0.001, 0.001) - - -def test_base_dtensor(): - ''' - Feature: dtensor infer layout and redistribute. - Description: - Expectation: Run success. - ''' - base_dtensor() - - -if __name__ == "__main__": - test_base_dtensor() diff --git a/tests/torch/tensor_parallel/test_base_dtensor.py b/tests/torch/tensor_parallel/test_base_dtensor.py index 197d63b79aaeeb1bf800a05a4e9af535bc63fe77..e2093129df975b97b3a845f810511c0a9e1b5dd9 100644 --- a/tests/torch/tensor_parallel/test_base_dtensor.py +++ b/tests/torch/tensor_parallel/test_base_dtensor.py @@ -13,17 +13,11 @@ # limitations under the License. 
# ============================================================================ """test base dtensor""" -import os - - -def run_case(master_port): - cmd = f"torchrun --nproc-per-node=8 " \ - f"--master_addr=127.0.0.1 --master_port={master_port} " \ - f"base_dtensor.py" - ret = os.system(cmd) - assert ret == 0 +from tests.torch.utils import torchrun_case +from tests.common.mark_utils import arg_mark +@arg_mark(plat_marks=["platform_ascend910b"], level_mark="level0", card_mark="allcards", essential_mark="essential") def test_base_dtensor(): ''' Feature: dtensor dispatch/infer_layout/redistribute. @@ -31,8 +25,5 @@ def test_base_dtensor(): Expectation: Run success. ''' master_port = 11333 - run_case(master_port) - - -if __name__ == "__main__": - test_base_dtensor() + case_name = "base_dtensor.py::test_base_dtensor" + torchrun_case(master_port, case_name) diff --git a/tests/torch/utils.py b/tests/torch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..501078628c9e10d77976c1cb38a559e5488a29c8 --- /dev/null +++ b/tests/torch/utils.py @@ -0,0 +1,36 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""test utils""" +import os +import torch +import torch.distributed as dist + + +def init_dist(): + """init dist""" + dist.init_process_group() + rank = dist.get_rank() + torch.npu.set_device(rank) + device_id = rank % 8 + torch.npu.set_device(device_id) + return rank, device_id + + +def torchrun_case(master_port, case_name): + cmd = f"torchrun --nproc-per-node=8 " \ + f"--master_addr=127.0.0.1 --master_port={master_port} " \ + f"-m pytest -s {case_name}" + ret = os.system(cmd) + assert ret == 0