diff --git a/model_examples/BEVFusion/README.md b/model_examples/BEVFusion/README.md index cbb0a934b1ca2d0f955c9f93eed6219425e59662..3859017b9ddc5f8c3ed7e02ee696305966542d1a 100644 --- a/model_examples/BEVFusion/README.md +++ b/model_examples/BEVFusion/README.md @@ -152,7 +152,7 @@ cd ../ ```shell # 主节点拉起脚本,默认训练1个epochs bash test/nnodes_train_performance_16p_base_fp32.sh --batch-size=4 --num-npu=8 --nnodes=2 --node-rank=0 --port=port --master-addr=master_addr # master-addr 必须指定,其余可省略以使用默认值 - # 主节点拉起脚本,默认训练1个epochs + # 副节点拉起脚本,默认训练1个epochs bash test/nnodes_train_performance_16p_base_fp32.sh --batch-size=4 --num-npu=8 --nnodes=2 --node-rank=1 --port=port --master-addr=master_addr # node-rank,master-addr 必须指定,其余可省略以使用默认值 ``` @@ -160,22 +160,24 @@ cd ../ 单机8卡 | NAME | Modality | Voxel type (voxel size) | 训练方式 | Epoch | global batch size | NDS | mAP | FPS | |------------------|-----------|-------------------------|------|-------|-------|-------|-------|-------| -| 8p-Atlas 800T A2 | lidar-cam | 0.075 | FP32 | 6 | 32 | 69.44 | 66.45 | 22.38 | +| 8p-Atlas 800T A2 | lidar-cam | 0.075 | FP32 | 6 | 32 | 69.44 | 66.45 | 23.62 | | 8p-竞品A | lidar-cam | 0.075 | FP32 | 6 | 32 | 69.78 | 67.36 | 22.54 | 双机16卡 | NAME | Modality | Voxel type (voxel size) | 训练方式 | Epoch | global batch size |FPS | 线性度 | |------------------|-----------|-------------------------|------|-------|-------|-------|-------| -| 8p-Atlas 800T A2 | lidar-cam | 0.075 | FP32 | 1 | 64 | 43.76 | 97.07% | +| 8p-Atlas 800T A2 | lidar-cam | 0.075 | FP32 | 1 | 64 | 45.86 | 97.07% | # 版本说明 ## 变更 +2025.8.29:模型优化,更新单机及双机性能。 + 2025.8.1:模型性能优化,更新单机性能及精度。 2025.7.10:更新单机性能及精度。 -2025.5.20:支持双机,更新单机性能。 +2025.5.20:支持双机,更新单机及双机性能。 2024.12.5:首次发布。 diff --git a/model_examples/BEVFusion/bevfusion.patch b/model_examples/BEVFusion/bevfusion.patch index 64e42ce4d4b0c7d9e17296608c42f0ad646a50ac..6d248c1780dfcaa3bf3606148c8a0c4bb167c906 100644 --- a/model_examples/BEVFusion/bevfusion.patch +++ b/model_examples/BEVFusion/bevfusion.patch @@ 
-12,42 +12,264 @@ index 56e8440b..b3a6382a 100644 if ground_plane is not None: xyz = sampled_gt_bboxes[:, :3] diff --git a/mmdet3d/models/layers/sparse_block.py b/mmdet3d/models/layers/sparse_block.py -index 6ed7c8f4..13f69b0d 100644 +index 6ed7c8f4..6a5ba828 100644 --- a/mmdet3d/models/layers/sparse_block.py +++ b/mmdet3d/models/layers/sparse_block.py -@@ -2,17 +2,22 @@ - from typing import Optional, Tuple, Union - - from mmcv.cnn import build_conv_layer, build_norm_layer +@@ -1,224 +1,349 @@ +-# Copyright (c) OpenMMLab. All rights reserved. +-from typing import Optional, Tuple, Union +- +-from mmcv.cnn import build_conv_layer, build_norm_layer -from mmdet.models.backbones.resnet import BasicBlock, Bottleneck +-from torch import nn +- +-from mmdet3d.utils import OptConfigType +-from .spconv import IS_SPCONV2_AVAILABLE +- +-if IS_SPCONV2_AVAILABLE: +- from spconv.pytorch import SparseConvTensor, SparseModule, SparseSequential +-else: +- from mmcv.ops import SparseConvTensor, SparseModule, SparseSequential +- +- +-def replace_feature(out: SparseConvTensor, +- new_features: SparseConvTensor) -> SparseConvTensor: +- if 'replace_feature' in out.__dir__(): +- # spconv 2.x behaviour +- return out.replace_feature(new_features) +- else: +- out.features = new_features +- return out +- +- +-class SparseBottleneck(Bottleneck, SparseModule): +- """Sparse bottleneck block for PartA^2. +- +- Bottleneck block implemented with submanifold sparse convolution. +- +- Args: +- inplanes (int): Inplanes of block. +- planes (int): Planes of block. +- stride (int or Tuple[int]): Stride of the first block. Defaults to 1. +- downsample (Module, optional): Down sample module for block. +- Defaults to None. +- indice_key (str): Indice key for spconv. Default to None. +- conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for +- convolution layer. Defaults to None. +- norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for +- normalization layer. Defaults to None. 
+- """ +- +- expansion = 4 +- +- def __init__(self, +- inplanes: int, +- planes: int, +- stride: Union[int, Tuple[int]] = 1, +- downsample: nn.Module = None, +- indice_key=None, +- conv_cfg: OptConfigType = None, +- norm_cfg: OptConfigType = None) -> None: +- +- SparseModule.__init__(self) +- if conv_cfg is None: +- conv_cfg = dict(type='SubMConv3d') +- conv_cfg.setdefault('indice_key', indice_key) +- if norm_cfg is None: +- norm_cfg = dict(type='BN1d') +- Bottleneck.__init__( +- self, +- inplanes, +- planes, +- stride=stride, +- downsample=downsample, +- conv_cfg=conv_cfg, +- norm_cfg=norm_cfg) +- +- def forward(self, x: SparseConvTensor) -> SparseConvTensor: +- identity = x.features +- +- out = self.conv1(x) +- out = replace_feature(out, self.bn1(out.features)) +- out = replace_feature(out, self.relu(out.features)) +- +- out = self.conv2(out) +- out = replace_feature(out, self.bn2(out.features)) +- out = replace_feature(out, self.relu(out.features)) +- +- out = self.conv3(out) +- out = replace_feature(out, self.bn3(out.features)) +- +- if self.downsample is not None: +- identity = self.downsample(x).features +- +- out = replace_feature(out, out.features + identity) +- out = replace_feature(out, self.relu(out.features)) +- +- return out +- +- +-class SparseBasicBlock(BasicBlock, SparseModule): +- """Sparse basic block for PartA^2. +- +- Sparse basic block implemented with submanifold sparse convolution. +- +- Args: +- inplanes (int): Inplanes of block. +- planes (int): Planes of block. +- stride (int or Tuple[int]): Stride of the first block. Defaults to 1. +- downsample (Module, optional): Down sample module for block. +- Defaults to None. +- indice_key (str): Indice key for spconv. Default to None. +- conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for +- convolution layer. Defaults to None. +- norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for +- normalization layer. Defaults to None. 
+- """ +- +- expansion = 1 +- +- def __init__(self, +- inplanes: int, +- planes: int, +- stride: Union[int, Tuple[int]] = 1, +- downsample: nn.Module = None, +- indice_key: Optional[str] = None, +- conv_cfg: OptConfigType = None, +- norm_cfg: OptConfigType = None) -> None: +- SparseModule.__init__(self) +- if conv_cfg is None: +- conv_cfg = dict(type='SubMConv3d') +- conv_cfg.setdefault('indice_key', indice_key) +- if norm_cfg is None: +- norm_cfg = dict(type='BN1d') +- BasicBlock.__init__( +- self, +- inplanes, +- planes, +- stride=stride, +- downsample=downsample, +- conv_cfg=conv_cfg, +- norm_cfg=norm_cfg) +- +- def forward(self, x: SparseConvTensor) -> SparseConvTensor: +- identity = x.features +- +- assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}' +- out = self.conv1(x) +- out = replace_feature(out, self.norm1(out.features)) +- out = replace_feature(out, self.relu(out.features)) +- +- out = self.conv2(out) +- out = replace_feature(out, self.norm2(out.features)) +- +- if self.downsample is not None: +- identity = self.downsample(x).features +- +- out = replace_feature(out, out.features + identity) +- out = replace_feature(out, self.relu(out.features)) +- +- return out +- +- +-def make_sparse_convmodule(in_channels: int, +- out_channels: int, +- kernel_size: Union[int, Tuple[int]], +- indice_key: Optional[str] = None, +- stride: Union[int, Tuple[int]] = 1, +- padding: Union[int, Tuple[int]] = 0, +- conv_type: str = 'SubMConv3d', +- norm_cfg: OptConfigType = None, +- order: Tuple[str] = ('conv', 'norm', 'act'), +- **kwargs) -> SparseSequential: +- """Make sparse convolution module. +- +- Args: +- in_channels (int): The number of input channels. +- out_channels (int): The number of out channels. +- kernel_size (int | Tuple[int]): Kernel size of convolution. +- indice_key (str): The indice key used for sparse tensor. +- stride (int or tuple[int]): The stride of convolution. +- padding (int or tuple[int]): The padding number of input. 
+- conv_type (str): Sparse conv type in spconv. Defaults to 'SubMConv3d'. +- norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for +- normalization layer. Defaults to None. +- order (Tuple[str]): The order of conv/norm/activation layers. It is a +- sequence of "conv", "norm" and "act". Common examples are +- ("conv", "norm", "act") and ("act", "conv", "norm"). +- Defaults to ('conv', 'norm', 'act'). +- +- Returns: +- spconv.SparseSequential: sparse convolution module. +- """ +- assert isinstance(order, tuple) and len(order) <= 3 +- assert set(order) | {'conv', 'norm', 'act'} == {'conv', 'norm', 'act'} +- +- conv_cfg = dict(type=conv_type, indice_key=indice_key) +- if norm_cfg is None: +- norm_cfg = dict(type='BN1d') +- +- layers = list() +- for layer in order: +- if layer == 'conv': +- if conv_type not in [ +- 'SparseInverseConv3d', 'SparseInverseConv2d', +- 'SparseInverseConv1d' +- ]: +- layers.append( +- build_conv_layer( +- conv_cfg, +- in_channels, +- out_channels, +- kernel_size, +- stride=stride, +- padding=padding, +- bias=False)) +- else: +- layers.append( +- build_conv_layer( +- conv_cfg, +- in_channels, +- out_channels, +- kernel_size, +- bias=False)) +- elif layer == 'norm': +- layers.append(build_norm_layer(norm_cfg, out_channels)[1]) +- elif layer == 'act': +- layers.append(nn.ReLU(inplace=True)) +- +- layers = SparseSequential(*layers) +- return layers ++# Copyright (c) OpenMMLab. All rights reserved. 
++from typing import Optional, Tuple, Union ++ ++from mmcv.cnn import build_conv_layer, build_norm_layer +import inspect +from mmengine.model import BaseModule +import torch - from torch import nn ++from torch import nn +import torch.utils.checkpoint as cp - - from mmdet3d.utils import OptConfigType --from .spconv import IS_SPCONV2_AVAILABLE ++ ++from mmdet3d.utils import OptConfigType +from mx_driving.spconv import SparseSequential, SubMConv3d, SparseConv3d, SparseModule +from mmdet.models.backbones.resnet import BasicBlock, Bottleneck +from mmengine.registry import Registry +from mx_driving.spconv import SparseConvTensor - --if IS_SPCONV2_AVAILABLE: -- from spconv.pytorch import SparseConvTensor, SparseModule, SparseSequential --else: -- from mmcv.ops import SparseConvTensor, SparseModule, SparseSequential - ++ ++ +MODELS = Registry('Sparse conv layer') +MODELS.register_module('SubMConv3d', module=SubMConv3d) +MODELS.register_module('SparseConv3d', module=SparseConv3d) - - def replace_feature(out: SparseConvTensor, - new_features: SparseConvTensor) -> SparseConvTensor: -@@ -23,6 +28,87 @@ def replace_feature(out: SparseConvTensor, - out.features = new_features - return out - ++ ++def replace_feature(out: SparseConvTensor, ++ new_features: SparseConvTensor) -> SparseConvTensor: ++ if 'replace_feature' in out.__dir__(): ++ # spconv 2.x behaviour ++ return out.replace_feature(new_features) ++ else: ++ out.features = new_features ++ return out ++ +class BasicBlock(BaseModule): + expansion = 1 + @@ -129,31 +351,205 @@ index 6ed7c8f4..13f69b0d 100644 + out = self.relu(out) + + return out - - class SparseBottleneck(Bottleneck, SparseModule): - """Sparse bottleneck block for PartA^2. -@@ -199,7 +285,7 @@ def make_sparse_convmodule(in_channels: int, - 'SparseInverseConv1d' - ]: - layers.append( -- build_conv_layer( ++ ++class SparseBottleneck(Bottleneck, SparseModule): ++ """Sparse bottleneck block for PartA^2. 
++ ++ Bottleneck block implemented with submanifold sparse convolution. ++ ++ Args: ++ inplanes (int): Inplanes of block. ++ planes (int): Planes of block. ++ stride (int or Tuple[int]): Stride of the first block. Defaults to 1. ++ downsample (Module, optional): Down sample module for block. ++ Defaults to None. ++ indice_key (str): Indice key for spconv. Default to None. ++ conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for ++ convolution layer. Defaults to None. ++ norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for ++ normalization layer. Defaults to None. ++ """ ++ ++ expansion = 4 ++ ++ def __init__(self, ++ inplanes: int, ++ planes: int, ++ stride: Union[int, Tuple[int]] = 1, ++ downsample: nn.Module = None, ++ indice_key=None, ++ conv_cfg: OptConfigType = None, ++ norm_cfg: OptConfigType = None) -> None: ++ ++ SparseModule.__init__(self) ++ if conv_cfg is None: ++ conv_cfg = dict(type='SubMConv3d') ++ conv_cfg.setdefault('indice_key', indice_key) ++ if norm_cfg is None: ++ norm_cfg = dict(type='BN1d') ++ Bottleneck.__init__( ++ self, ++ inplanes, ++ planes, ++ stride=stride, ++ downsample=downsample, ++ conv_cfg=conv_cfg, ++ norm_cfg=norm_cfg) ++ ++ def forward(self, x: SparseConvTensor) -> SparseConvTensor: ++ identity = x.features ++ ++ out = self.conv1(x) ++ out = replace_feature(out, self.bn1(out.features)) ++ out = replace_feature(out, self.relu(out.features)) ++ ++ out = self.conv2(out) ++ out = replace_feature(out, self.bn2(out.features)) ++ out = replace_feature(out, self.relu(out.features)) ++ ++ out = self.conv3(out) ++ out = replace_feature(out, self.bn3(out.features)) ++ ++ if self.downsample is not None: ++ identity = self.downsample(x).features ++ ++ out = replace_feature(out, out.features + identity) ++ out = replace_feature(out, self.relu(out.features)) ++ ++ return out ++ ++ ++class SparseBasicBlock(BasicBlock, SparseModule): ++ """Sparse basic block for PartA^2. 
++ ++ Sparse basic block implemented with submanifold sparse convolution. ++ ++ Args: ++ inplanes (int): Inplanes of block. ++ planes (int): Planes of block. ++ stride (int or Tuple[int]): Stride of the first block. Defaults to 1. ++ downsample (Module, optional): Down sample module for block. ++ Defaults to None. ++ indice_key (str): Indice key for spconv. Default to None. ++ conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for ++ convolution layer. Defaults to None. ++ norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for ++ normalization layer. Defaults to None. ++ """ ++ ++ expansion = 1 ++ ++ def __init__(self, ++ inplanes: int, ++ planes: int, ++ stride: Union[int, Tuple[int]] = 1, ++ downsample: nn.Module = None, ++ indice_key: Optional[str] = None, ++ conv_cfg: OptConfigType = None, ++ norm_cfg: OptConfigType = None) -> None: ++ SparseModule.__init__(self) ++ if conv_cfg is None: ++ conv_cfg = dict(type='SubMConv3d') ++ conv_cfg.setdefault('indice_key', indice_key) ++ if norm_cfg is None: ++ norm_cfg = dict(type='BN1d') ++ BasicBlock.__init__( ++ self, ++ inplanes, ++ planes, ++ stride=stride, ++ downsample=downsample, ++ conv_cfg=conv_cfg, ++ norm_cfg=norm_cfg) ++ ++ def forward(self, x: SparseConvTensor) -> SparseConvTensor: ++ identity = x.features ++ ++ assert x.features.dim() == 2, f'x.features.dim()={x.features.dim()}' ++ out = self.conv1(x) ++ out = replace_feature(out, self.norm1(out.features)) ++ out = replace_feature(out, self.relu(out.features)) ++ ++ out = self.conv2(out) ++ out = replace_feature(out, self.norm2(out.features)) ++ ++ if self.downsample is not None: ++ identity = self.downsample(x).features ++ ++ out = replace_feature(out, out.features + identity) ++ out = replace_feature(out, self.relu(out.features)) ++ ++ return out ++ ++ ++def make_sparse_convmodule(in_channels: int, ++ out_channels: int, ++ kernel_size: Union[int, Tuple[int]], ++ indice_key: Optional[str] = None, ++ stride: Union[int, Tuple[int]] = 1, ++ 
padding: Union[int, Tuple[int]] = 0, ++ conv_type: str = 'SubMConv3d', ++ norm_cfg: OptConfigType = None, ++ order: Tuple[str] = ('conv', 'norm', 'act'), ++ **kwargs) -> SparseSequential: ++ """Make sparse convolution module. ++ ++ Args: ++ in_channels (int): The number of input channels. ++ out_channels (int): The number of out channels. ++ kernel_size (int | Tuple[int]): Kernel size of convolution. ++ indice_key (str): The indice key used for sparse tensor. ++ stride (int or tuple[int]): The stride of convolution. ++ padding (int or tuple[int]): The padding number of input. ++ conv_type (str): Sparse conv type in spconv. Defaults to 'SubMConv3d'. ++ norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for ++ normalization layer. Defaults to None. ++ order (Tuple[str]): The order of conv/norm/activation layers. It is a ++ sequence of "conv", "norm" and "act". Common examples are ++ ("conv", "norm", "act") and ("act", "conv", "norm"). ++ Defaults to ('conv', 'norm', 'act'). ++ ++ Returns: ++ spconv.SparseSequential: sparse convolution module. 
++ """ ++ assert isinstance(order, tuple) and len(order) <= 3 ++ assert set(order) | {'conv', 'norm', 'act'} == {'conv', 'norm', 'act'} ++ ++ conv_cfg = dict(type=conv_type, indice_key=indice_key) ++ if norm_cfg is None: ++ norm_cfg = dict(type='BN1d') ++ ++ layers = list() ++ for layer in order: ++ if layer == 'conv': ++ if conv_type not in [ ++ 'SparseInverseConv3d', 'SparseInverseConv2d', ++ 'SparseInverseConv1d' ++ ]: ++ layers.append( + build_sparse_conv_layer( - conv_cfg, - in_channels, - out_channels, -@@ -209,7 +295,7 @@ def make_sparse_convmodule(in_channels: int, - bias=False)) - else: - layers.append( -- build_conv_layer( ++ conv_cfg, ++ in_channels, ++ out_channels, ++ kernel_size, ++ stride=stride, ++ padding=padding, ++ bias=False)) ++ else: ++ layers.append( + build_sparse_conv_layer( - conv_cfg, - in_channels, - out_channels, -@@ -222,3 +308,42 @@ def make_sparse_convmodule(in_channels: int, - - layers = SparseSequential(*layers) - return layers ++ conv_cfg, ++ in_channels, ++ out_channels, ++ kernel_size, ++ bias=False)) ++ elif layer == 'norm': ++ layers.append(build_norm_layer(norm_cfg, out_channels)[1]) ++ elif layer == 'act': ++ layers.append(nn.ReLU(inplace=True)) ++ ++ layers = SparseSequential(*layers) ++ return layers + + +def build_sparse_conv_layer(cfg, *args, **kwargs): @@ -651,7 +1047,7 @@ index 00000000..44801857 + $CONFIG \ + --launcher pytorch ${@:7} diff --git a/tools/train.py b/tools/train.py -index b2ced54b..fcf09854 100644 +index b2ced54b..506747e8 100644 --- a/tools/train.py +++ b/tools/train.py @@ -8,6 +8,13 @@ from mmengine.config import Config, DictAction @@ -668,21 +1064,20 @@ index b2ced54b..fcf09854 100644 from mmdet3d.utils import replace_ceph_backend -@@ -132,4 +139,16 @@ def main(): +@@ -132,4 +139,15 @@ def main(): if __name__ == '__main__': - main() + torch_npu.npu.set_compile_mode(jit_compile=False) -+ torch_npu.npu.config.allow_internal_format = False + + pb = PatcherBuilder().add_module_patch("torch", 
Patch(batch_matmul)) + if os.environ.get('PERFORMANCE_MODE', '0') == '1': + # Performance-Testing mode: use Patcher to set breakpoints + pb = pb.brake_at(1000) -+ with pb.build(): ++ with pb.build(allow_internal_format=True): + main() + else: + # Training mode: run the main function directly -+ with pb.build(): ++ with pb.build(allow_internal_format=True): + main()