diff --git a/docs/get_started/patcher.md b/docs/get_started/patcher.md
index 12ca7667022c77cd32848273f0f8ddfcb4423202..ac050ac55d9696237e438820518623b40e851fbc 100644
--- a/docs/get_started/patcher.md
+++ b/docs/get_started/patcher.md
@@ -73,7 +73,7 @@ import torch
 import torch_npu
 import mx_driving
 from mx_driving.patcher import PatcherBuilder, Patch
-from mx_driving.patcher import ddp, ddp_forward
+from mx_driving.patcher import ddp
 from mx_driving.patcher import resnet_add_relu, resnet_maxpool, nuscenes_dataset
 from mx_driving.patcher import dc, mdc, msda
 
@@ -82,7 +82,7 @@ bev_former_patcher_builder = (
     .add_module_patch("mmcv.ops", Patch(msda), Patch(dc), Patch(mdc))
     .add_module_patch("mmdet.models.backbones.resnet", Patch(resnet_add_relu), Patch(resnet_maxpool))
     .add_module_patch("mmdet3d.datasets.nuscenes_dataset", Patch(nuscenes_dataset))
-    .add_module_patch("mmcv.parallel", Patch(ddp), Patch(ddp_forward))
+    .add_module_patch("mmcv.parallel.distributed", Patch(ddp))
 )
 ```
 
@@ -100,7 +100,7 @@ if __name__ == '__main__':
 ```
 
 ### patcher使能特性说明
-- ddp, ddp_forward用于修改mmcv框架中并行相关代码适配NPU训练。
+- ddp用于修改mmcv框架中并行相关代码适配NPU训练。
 - resnet_add_relu, resnet_maxpool用于resnet结构中特定算子的优化,替换为DrivingSDK中高性能算子。
 - dc, mdc, msda用于mmcv中DeformConv2d,ModulatedDeformConv2d,MultiScaleDeformableAttn算子替换为DrivingSDK中高性能算子。
 - nuscenes_dataset用于针对BEVFormer模型的性能优化。
diff --git a/model_examples/DiffusionDrive/migrate_to_ascend/patch.py b/model_examples/DiffusionDrive/migrate_to_ascend/patch.py
index eaa6f5d677c71449556092bda20ac7d1c18b7350..9b82ca05c8643d3f814e1da8511937ee7a687835 100644
--- a/model_examples/DiffusionDrive/migrate_to_ascend/patch.py
+++ b/model_examples/DiffusionDrive/migrate_to_ascend/patch.py
@@ -16,7 +16,7 @@ import torch_npu
 import mx_driving
 from mx_driving import deformable_aggregation
 from mx_driving.patcher import PatcherBuilder, Patch
-from mx_driving.patcher import index, batch_matmul, numpy_type, ddp, stream, ddp_forward
+from mx_driving.patcher import index, batch_matmul, numpy_type, ddp, stream
 from mx_driving.patcher import resnet_add_relu, resnet_maxpool
 
 
@@ -402,22 +402,6 @@ def get_hccl_init_dist(runner: ModuleType):
     return None
 
 
-def run_ddp_forward(parallel: ModuleType, options: Dict):
-
-    def _run_ddp_forward(self, *inputs, **kwargs):
-        module_to_run = self.module
-
-        if self.device_ids:
-            inputs, kwargs = self.to_kwargs(  # type: ignore
-                inputs, kwargs, self.device_ids[0])
-            return module_to_run(*inputs[0], **kwargs[0])  # type: ignore
-        else:
-            return module_to_run(*inputs, **kwargs)
-
-    if hasattr(parallel, "MMDistributedDataParallel"):
-        parallel.MMDistributedDataParallel._run_ddp_forward = _run_ddp_forward
-
-
 def instance_queue(queue: ModuleType, options: Dict):
 
     def prepare_motion(
@@ -470,7 +454,8 @@ def generate_patcher_builder(performance=False):
         PatcherBuilder()
         .add_module_patch("torch", Patch(index), Patch(batch_matmul))
         .add_module_patch("numpy", Patch(numpy_type))
-        .add_module_patch("mmcv.parallel", Patch(ddp), Patch(stream), Patch(ddp_forward), Patch(run_ddp_forward))
+        .add_module_patch("mmcv.parallel", Patch(stream))
+        .add_module_patch("mmcv.parallel.distributed", Patch(ddp))
         .add_module_patch("mmdet.models.backbones.resnet", Patch(resnet_add_relu), Patch(resnet_maxpool))
         .add_module_patch("projects.mmdet3d_plugin.models.attention", Patch(flash_attn))
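Context for the two builder changes above: a patch function in this codebase takes the imported target module and an options dict, and after this change `ddp` is registered against the `mmcv.parallel.distributed` submodule directly while `stream` keeps targeting `mmcv.parallel`. A minimal sketch of that contract, using only the `PatcherBuilder`/`Patch` API visible in the diff; `example_patch` and `SomeClass` are hypothetical placeholders, not part of the PR:

```python
from types import ModuleType
from typing import Dict

from mx_driving.patcher import PatcherBuilder, Patch


def example_patch(target: ModuleType, options: Dict):
    # A patch function rebinds attributes on the already-imported target
    # module. Per this PR's convention, a missing attribute now raises
    # instead of silently skipping the patch.
    if hasattr(target, "SomeClass"):
        original = target.SomeClass.forward

        def forward(self, *args, **kwargs):
            return original(self, *args, **kwargs)  # wrap or replace here

        target.SomeClass.forward = forward
    else:
        raise AttributeError("SomeClass not found")


# Registration mirrors the new split: one patch per concrete module path.
builder = PatcherBuilder().add_module_patch(
    "mmcv.parallel.distributed", Patch(example_patch)
)
```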
diff --git a/mx_driving/patcher/__init__.py b/mx_driving/patcher/__init__.py
index 102c6cc4c5eeda456092b0f92777dde39ff8549c..54dbb9e21b7085853ab209a253f7d3ffc1983b13 100644
--- a/mx_driving/patcher/__init__.py
+++ b/mx_driving/patcher/__init__.py
@@ -40,7 +40,6 @@ __all__ = [
     "pseudo_sampler",
     "numpy_type",
     "ddp",
-    "ddp_forward",
     "stream",
     "resnet_add_relu",
     "resnet_maxpool",
@@ -51,7 +50,7 @@ __all__ = [
     "optimizer",
 ]
 
-from mx_driving.patcher.distribute import ddp, ddp_forward
+from mx_driving.patcher.distribute import ddp
 from mx_driving.patcher.functions import stream
 from mx_driving.patcher.mmcv import dc, mdc, msda, patch_mmcv_version
 from mx_driving.patcher.mmdet import pseudo_sampler, resnet_add_relu, resnet_maxpool, resnet_fp16
@@ -69,7 +68,8 @@ default_patcher_builder = (
     .add_module_patch("torch", Patch(index), Patch(batch_matmul))
     .add_module_patch("numpy", Patch(numpy_type))
     .add_module_patch("mmdet.core.bbox.samplers", Patch(pseudo_sampler))
-    .add_module_patch("mmcv.parallel", Patch(ddp), Patch(stream), Patch(ddp_forward))
+    .add_module_patch("mmcv.parallel", Patch(stream))
+    .add_module_patch("mmcv.parallel.distributed", Patch(ddp))
     .add_module_patch("mmdet.models.backbones.resnet", Patch(resnet_add_relu), Patch(resnet_maxpool))
     .add_module_patch("mmdet3d.datasets.nuscenes_dataset", Patch(nuscenes_dataset))
     .add_module_patch("mmdet3d.evaluation.metrics", Patch(nuscenes_metric))
diff --git a/mx_driving/patcher/brake.py b/mx_driving/patcher/brake.py
index da7aaf189bda0c786b8f58dbc6b8b130243c21f4..5b448611d409df936f8738db3990e9bf244cc359 100644
--- a/mx_driving/patcher/brake.py
+++ b/mx_driving/patcher/brake.py
@@ -142,9 +142,23 @@ def brake(runner: ModuleType, options: Dict):
 
     if hasattr(runner, "EpochBasedRunner"):
         runner.EpochBasedRunner.train = train
+    else:
+        raise AttributeError("EpochBasedRunner not found")
+
     if hasattr(runner, "EpochBasedTrainLoop"):
         runner.EpochBasedTrainLoop.run_epoch = run_epoch
+    else:
+        raise AttributeError("EpochBasedTrainLoop not found")
+
+
     if hasattr(runner, "IterBasedTrainLoop"):
         runner.IterBasedTrainLoop.run = run
+    else:
+        raise AttributeError("IterBasedTrainLoop not found")
+
+
     if hasattr(runner, "IterBasedRunner"):
         runner.IterBasedRunner.run = run_iter
+    else:
+        raise AttributeError("IterBasedRunner not found")
+
\ No newline at end of file
diff --git a/mx_driving/patcher/distribute.py b/mx_driving/patcher/distribute.py
index 93c5c67c03d2c4c3a51037cccd8cb6f7afab47d6..913152a8694cc9edc1533602a53f3ba36afee5b0 100644
--- a/mx_driving/patcher/distribute.py
+++ b/mx_driving/patcher/distribute.py
@@ -3,14 +3,10 @@ from types import ModuleType
 from typing import Dict
 
 
-def ddp(mmcvparallel: ModuleType, options: Dict):
-    if hasattr(mmcvparallel, "distributed"):
-        import mmcv.device
-        mmcvparallel.distributed.MMDistributedDataParallel = mmcv.device.npu.NPUDistributedDataParallel
-
-
-def ddp_forward(mmcvparallel: ModuleType, options: Dict):
-    def new_forward(self, *inputs, **kwargs):
+def ddp(module: ModuleType, options: Dict):
+    # For mmcv 1.x: module path is mmcv.parallel.distributed
+
+    def _run_ddp_forward(self, *inputs, **kwargs):
         module_to_run = self.module
 
         if self.device_ids:
@@ -20,5 +16,10 @@ def ddp_forward(mmcvparallel: ModuleType, options: Dict):
         else:
             return module_to_run(*inputs, **kwargs)
 
-    if hasattr(mmcvparallel, "distributed"):
-        mmcvparallel.distributed.MMDistributedDataParallel._run_ddp_forward = new_forward
\ No newline at end of file
+
+    if hasattr(module, "MMDistributedDataParallel"):
+        import mmcv.device
+        module.MMDistributedDataParallel._run_ddp_forward = _run_ddp_forward
+        module.MMDistributedDataParallel = mmcv.device.npu.NPUDistributedDataParallel
+    else:
+        raise AttributeError("MMDistributedDataParallel not found")
\ No newline at end of file
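The merged `ddp` above now does both jobs of the old `ddp`/`ddp_forward` pair: it installs the `_run_ddp_forward` replacement and rebinds `MMDistributedDataParallel` to `NPUDistributedDataParallel`. The dispatch rule the replacement implements can be checked standalone; `_Demo` and its `to_kwargs` are hypothetical stand-ins for the DDP plumbing, not mmcv code:

```python
class _Demo:
    """Stand-in mimicking just enough of MMDistributedDataParallel."""

    def __init__(self, device_ids):
        self.device_ids = device_ids
        self.module = lambda *args, **kwargs: (args, kwargs)

    def to_kwargs(self, inputs, kwargs, device_id):
        # Hypothetical stand-in for DistributedDataParallel.to_kwargs,
        # which scatters inputs/kwargs to the given device.
        return [inputs], [kwargs]

    def _run_ddp_forward(self, *inputs, **kwargs):
        # Same rule as the patched method: rescatter when device_ids is
        # set, pass arguments through untouched otherwise.
        if self.device_ids:
            inputs, kwargs = self.to_kwargs(inputs, kwargs, self.device_ids[0])
            return self.module(*inputs[0], **kwargs[0])
        return self.module(*inputs, **kwargs)


# Both paths deliver the same arguments to the wrapped module.
assert _Demo([0])._run_ddp_forward(1, k=2) == ((1,), {"k": 2})
assert _Demo([])._run_ddp_forward(1, k=2) == ((1,), {"k": 2})
```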
diff --git a/mx_driving/patcher/functions.py b/mx_driving/patcher/functions.py
index 1925c57a41e3095eb9a07c3063823420974b279e..7d689c0396ddb5c2a175f5758dd6359f74c73aaa 100644
--- a/mx_driving/patcher/functions.py
+++ b/mx_driving/patcher/functions.py
@@ -32,3 +32,5 @@ def stream(mmcvparallel: ModuleType, options: Dict):
 
     if hasattr(mmcvparallel._functions, "Scatter"):
         mmcvparallel._functions.Scatter.forward = new_forward
+    else:
+        raise AttributeError("Scatter not found")
\ No newline at end of file
diff --git a/mx_driving/patcher/mmcv.py b/mx_driving/patcher/mmcv.py
index dee8bf4ca033f03fa4101254694d568c78aeebbb..3e61c44eb6d0fec9d4ba51a2d406e3b9900b73c6 100644
--- a/mx_driving/patcher/mmcv.py
+++ b/mx_driving/patcher/mmcv.py
@@ -42,6 +42,8 @@ def msda(mmcvops: ModuleType, options: Dict):
         MultiScaleDeformableAttnFunction.forward = apply_mxdriving_msda_forward_param(MultiScaleDeformableAttnFunction.forward)
         MultiScaleDeformableAttnFunction.backward = apply_mxdriving_msda_backward_param(MultiScaleDeformableAttnFunction.backward)
         mmcvops.multi_scale_deform_attn.MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction
+    else:
+        raise AttributeError("multi_scale_deform_attn not found")
 
 
 def dc(mmcvops: ModuleType, options: Dict):
@@ -50,6 +52,8 @@ def dc(mmcvops: ModuleType, options: Dict):
     if hasattr(mmcvops, "deform_conv"):
         mmcvops.deform_conv.DeformConv2dFunction = DeformConv2dFunction
         mmcvops.deform_conv.deform_conv2d = deform_conv2d
+    else:
+        raise AttributeError("deform_conv not found")
 
 
 def mdc(mmcvops: ModuleType, options: Dict):
@@ -58,3 +62,5 @@ def mdc(mmcvops: ModuleType, options: Dict):
     if hasattr(mmcvops, "modulated_deform_conv"):
         mmcvops.modulated_deform_conv.ModulatedDeformConv2dFunction = ModulatedDeformConv2dFunction
         mmcvops.modulated_deform_conv.modulated_deform_conv2d = modulated_deform_conv2d
+    else:
+        raise AttributeError("modulated_deform_conv not found")
\ No newline at end of file
diff --git a/mx_driving/patcher/mmdet.py b/mx_driving/patcher/mmdet.py
index 8311a9fdaccd65a10945b5c2526d6c792b7f4294..6e5224f193befd72cc9dcaee55ca432bb48bdacb 100644
--- a/mx_driving/patcher/mmdet.py
+++ b/mx_driving/patcher/mmdet.py
@@ -19,6 +19,8 @@ def pseudo_sampler(mmdetsamplers: ModuleType, options: Dict):
             return sampling_result
 
         mmdetsamplers.pseudo_sampler.PseudoSampler.sample = sample
+    else:
+        raise AttributeError("pseudo_sampler not found")
 
 
 def resnet_add_relu(mmdetresnet: ModuleType, options: Dict):
@@ -50,6 +52,8 @@ def resnet_add_relu(mmdetresnet: ModuleType, options: Dict):
             return out
 
         mmdetresnet.BasicBlock.forward = forward
+    else:
+        raise AttributeError("BasicBlock not found")
 
     if hasattr(mmdetresnet, "Bottleneck"):
 
@@ -93,6 +97,8 @@ def resnet_add_relu(mmdetresnet: ModuleType, options: Dict):
             return out
 
         mmdetresnet.Bottleneck.forward = forward
+    else:
+        raise AttributeError("Bottleneck not found")
 
 
 def resnet_maxpool(mmdetresnet: ModuleType, options: Dict):
@@ -119,6 +125,8 @@ def resnet_maxpool(mmdetresnet: ModuleType, options: Dict):
             return tuple(out)
 
         mmdetresnet.ResNet.forward = forward
+    else:
+        raise AttributeError("ResNet not found")
 
 
 def resnet_fp16(mmdetresnet: ModuleType, options: Dict):
@@ -144,3 +152,5 @@ def resnet_fp16(mmdetresnet: ModuleType, options: Dict):
             return tuple([out.float() for out in tuple(outs)])
 
         mmdetresnet.ResNet.forward = forward
+    else:
+        raise AttributeError("ResNet not found")
\ No newline at end of file
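The guard-then-raise shape added to `functions.py`, `mmcv.py`, and `mmdet.py` is identical everywhere; a condensed restatement of the convention (the `require_attr` helper is hypothetical, not part of this PR):

```python
from types import ModuleType


def require_attr(module: ModuleType, name: str):
    # Fail loudly when a patch target is absent, matching the PR's switch
    # from silent skips to AttributeError.
    if not hasattr(module, name):
        raise AttributeError(f"{name} not found")
    return getattr(module, name)


# e.g. require_attr(mmcvops, "deform_conv").DeformConv2dFunction = ...
```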
diff --git a/mx_driving/patcher/mmdet3d.py b/mx_driving/patcher/mmdet3d.py
index d02d1258e44ac11c0d8d383e5cbc8737bf002995..c04b552ae4678e6b1505e9210ba1cf13801cb306 100644
--- a/mx_driving/patcher/mmdet3d.py
+++ b/mx_driving/patcher/mmdet3d.py
@@ -34,6 +34,8 @@ def nuscenes_dataset(mmdet3ddatasets: ModuleType, options: Dict):
             return box_list
 
         mmdet3ddatasets.output_to_nusc_box = output_to_nusc_box
+    else:
+        raise AttributeError("output_to_nusc_box not found")
 
 
 def nuscenes_metric(mmdet3dmetrics: ModuleType, options: Dict):
@@ -96,3 +98,5 @@ def nuscenes_metric(mmdet3dmetrics: ModuleType, options: Dict):
             return box_list, attrs
 
         mmdet3dmetrics.output_to_nusc_box = output_to_nusc_box
+    else:
+        raise AttributeError("output_to_nusc_box not found")
\ No newline at end of file
diff --git a/mx_driving/patcher/nuscenes.py b/mx_driving/patcher/nuscenes.py
index d00e2095e28ec8549ff726427d5de91e088f0c19..0effee7e7721f86fcc5242dadda125b27fc7206d 100644
--- a/mx_driving/patcher/nuscenes.py
+++ b/mx_driving/patcher/nuscenes.py
@@ -60,4 +60,6 @@ def nuscenes_mot_metric(nusceneseval: ModuleType, options: Dict):
         return r
 
     if hasattr(nusceneseval, "mot"):
-        nusceneseval.mot.MOTAccumulatorCustom.merge_event_dataframes = merge_event_dataframes_new
\ No newline at end of file
+        nusceneseval.mot.MOTAccumulatorCustom.merge_event_dataframes = merge_event_dataframes_new
+    else:
+        raise AttributeError("mot not found")
\ No newline at end of file
diff --git a/mx_driving/patcher/optimizer.py b/mx_driving/patcher/optimizer.py
index 7b915373d0514667d243b9f681310145f45b4e2c..1ed5257abc73d8a022f579659e84b77d0a282785 100644
--- a/mx_driving/patcher/optimizer.py
+++ b/mx_driving/patcher/optimizer.py
@@ -251,6 +251,8 @@ def optimizer_hooks(mmcvhooks: ModuleType, options: Dict):
                 # clear grads
                 runner.model.zero_grad()
                 runner.optimizer.zero_grad()
+    else:
+        raise AttributeError("optimizer not found")
 
 
 def optimizer_wrapper(mmcvoptwrapper: ModuleType, options: Dict):
@@ -274,3 +276,5 @@ def optimizer_wrapper(mmcvoptwrapper: ModuleType, options: Dict):
             self.clip_grads = _get_clip_func(self.optimizer)
 
         OptimWrapper.__init__ = new_init
+    else:
+        raise AttributeError("OptimWrapper not found")
\ No newline at end of file
diff --git a/mx_driving/patcher/profiler.py b/mx_driving/patcher/profiler.py
index 915c8eabaeecbf3edcfe6fc22f9983b52232dc56..542544de9dd68745996a78945cd78c08998fd29d 100644
--- a/mx_driving/patcher/profiler.py
+++ b/mx_driving/patcher/profiler.py
@@ -180,9 +180,23 @@ def profiler(runner: ModuleType, options: Dict):
 
     if hasattr(runner, "EpochBasedRunner"):
         runner.EpochBasedRunner.train = train
+    else:
+        raise AttributeError("EpochBasedRunner not found")
+
     if hasattr(runner, "EpochBasedTrainLoop"):
         runner.EpochBasedTrainLoop.run_epoch = run_epoch
+    else:
+        raise AttributeError("EpochBasedTrainLoop not found")
+
+
     if hasattr(runner, "IterBasedTrainLoop"):
         runner.IterBasedTrainLoop.run = run
+    else:
+        raise AttributeError("IterBasedTrainLoop not found")
+
+
     if hasattr(runner, "IterBasedRunner"):
         runner.IterBasedRunner.run = run_iter
+    else:
+        raise AttributeError("IterBasedRunner not found")
+
\ No newline at end of file
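With the strict guards above, `brake` and `profiler` now raise on the first of the four runner/loop classes that is absent from the target module. A minimal restatement showing the fail-fast order; `_apply_runner_patches` is a hypothetical condensation, not the PR's code:

```python
import types
from typing import Dict


def _apply_runner_patches(runner: types.ModuleType, options: Dict):
    # Condensed tail of brake()/profiler(): each guard now has an else
    # branch that raises, so the first missing class stops the patch.
    for name in ("EpochBasedRunner", "EpochBasedTrainLoop",
                 "IterBasedTrainLoop", "IterBasedRunner"):
        if not hasattr(runner, name):
            raise AttributeError(f"{name} not found")
        # ...assign the corresponding train/run replacement here...


fake_runner = types.ModuleType("fake_runner")  # has none of the classes
try:
    _apply_runner_patches(fake_runner, {})
except AttributeError as err:
    print(err)  # -> EpochBasedRunner not found
```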
diff --git a/mx_driving/patcher/tensor.py b/mx_driving/patcher/tensor.py
index bfe42e7102ed1754d44ffa7066331bdd74478413..c41cc60ea9d269eb4d48cb71a60a28e452bb6b11 100644
--- a/mx_driving/patcher/tensor.py
+++ b/mx_driving/patcher/tensor.py
@@ -18,7 +18,10 @@ def index(torch: ModuleType, options: Dict):
                 return torch.masked_select(self, indices).view(-1, self.shape[1])
         return fn(self, indices)  # fallback to the original function
 
-    torch.Tensor.__getitem__ = new_fn
+    if hasattr(torch, "Tensor"):
+        torch.Tensor.__getitem__ = new_fn
+    else:
+        raise AttributeError("Tensor not found")
 
 
 def check_shape_bmm(a, b):
@@ -44,6 +47,13 @@ def batch_matmul(torch: ModuleType, options: Dict):
             return original_fn(a, b)
         return wrapper
 
-    torch.matmul = create_wrapper(torch.matmul)
-    torch.Tensor.matmul = create_wrapper(torch.Tensor.matmul)
-    torch.Tensor.__matmul__ = create_wrapper(torch.Tensor.__matmul__)
+    if hasattr(torch, "matmul"):
+        torch.matmul = create_wrapper(torch.matmul)
+    else:
+        raise AttributeError("matmul not found")
+
+    if hasattr(torch, "Tensor"):
+        torch.Tensor.matmul = create_wrapper(torch.Tensor.matmul)
+        torch.Tensor.__matmul__ = create_wrapper(torch.Tensor.__matmul__)
+    else:
+        raise AttributeError("Tensor not found")
\ No newline at end of file
diff --git a/tests/torch/test_patcher_distribute.py b/tests/torch/test_patcher_distribute.py
index 7dba095c7a818ded2c30a449b952d1583040cee5..7aa47fc2f1e493a3df23b371569c5d7b97c5111e 100644
--- a/tests/torch/test_patcher_distribute.py
+++ b/tests/torch/test_patcher_distribute.py
@@ -6,7 +6,7 @@ from unittest.mock import ANY, patch, MagicMock, PropertyMock
 import torch
 import torch_npu
 from torch_npu.testing.testcase import TestCase, run_tests
-from mx_driving.patcher import ddp, ddp_forward
+from mx_driving.patcher import ddp
 
 
 def assertIsNotInstance(obj, cls):
@@ -22,58 +22,9 @@ class TestDistribute(TestCase):
 
     def test_ddp_patch(self):
         # Apply monkey patch
-        ddp(self.mock_mmcvparallel, {})
+        ddp(self.mock_mmcvparallel.distributed, {})
         assertIsNotInstance(self.mock_mmcvparallel.distributed.MMDistributedDataParallel, MagicMock)
 
-    def test_ddp_forward_patch(self):
-        # Apply the ddp_forward patch
-        ddp_forward(self.mock_mmcvparallel, {})
-
-        # Get the patched _run_ddp_forward method
-        new_forward = self.mock_mmcvparallel.distributed.MMDistributedDataParallel._run_ddp_forward
-
-        # Verify _run_ddp_forward is correctly replaced
-        assertIsNotInstance(
-            new_forward,
-            MagicMock
-        )
-
-        # Create mock instance and inputs
-        mock_self = MagicMock()
-        mock_self.device_ids = [0]  # Simulate device IDs present
-        mock_self.module = MagicMock(return_value="module_output")
-
-        # Mock the to_kwargs method
-        mock_self.to_kwargs = MagicMock(return_value=(
-            [("processed_input",)],
-            [{"processed_kwarg": "value"}]
-        ))
-
-        # Call the patched forward method
-        result = new_forward(mock_self, "input1", "input2", kwarg1="value1")
-
-        # Check to_kwargs is called correctly
-        mock_self.to_kwargs.assert_called_once_with(
-            ("input1", "input2"),
-            {"kwarg1": "value1"},
-            0
-        )
-
-        # Check module is called correctly
-        mock_self.module.assert_called_once_with(
-            "processed_input",
-            processed_kwarg="value"
-        )
-
-        # Verify return value
-        self.assertEqual(result, "module_output")
-
-        # Test case with no device_ids
-        mock_self.reset_mock()
-        mock_self.device_ids = []
-        result = new_forward(mock_self, "input3", kwarg2="value2")
-        mock_self.module.assert_called_once_with("input3", kwarg2="value2")
-
 
 if __name__ == '__main__':
     run_tests()
\ No newline at end of file
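The deleted `test_ddp_forward_patch` loses coverage of the forward dispatch; if that coverage is still wanted, something along these lines could exercise it through the merged `ddp`. A sketch only: it assumes `mmcv.device.npu` is importable in the test environment (since `ddp` imports it) and reuses the suite's MagicMock conventions:

```python
from unittest.mock import MagicMock

from mx_driving.patcher import ddp


def test_merged_ddp_forward_sketch():
    distributed = MagicMock()  # stands in for mmcv.parallel.distributed
    target_cls = distributed.MMDistributedDataParallel  # capture before rebind
    ddp(distributed, {})

    new_forward = target_cls._run_ddp_forward  # the installed replacement
    mock_self = MagicMock()
    mock_self.device_ids = []  # empty device_ids -> passthrough branch
    mock_self.module = MagicMock(return_value="module_output")

    assert new_forward(mock_self, "x", k="v") == "module_output"
    mock_self.module.assert_called_once_with("x", k="v")
```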
missing""" - mock_mmcvparallel = MagicMock() - mock_mmcvparallel._functions = MagicMock() - delattr(mock_mmcvparallel._functions, "Scatter") - - from mx_driving.patcher import stream - try: - stream(mock_mmcvparallel, {}) - except AttributeError: - self.fail("stream should handle missing Scatter class gracefully") - if __name__ == "__main__": run_tests() \ No newline at end of file