
[Crowd Intelligence][Compute - Feature Completion] Conv3D

TODO
Requirement
Created on 2022-06-30 16:31

Conv3D

Tasks

Infer migration
CPU: NativeKernelMod adaptation + Resize function
GPU: NativeKernelMod adaptation + Resize function + additional data types

Dynamic shape: operator refactoring, not an interface change

Introduction

1. Feature description

3D convolution layer

2. Interface description
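The full interface description is delivered as an attachment below. For orientation, here is a minimal sketch of how the operator is invoked, inferred from the attribute names used in the tests later in this thread; the argument values are illustrative, not the authoritative signature:

    # Hedged sketch; argument values mirror a test case below and are illustrative only.
    import numpy as np
    from mindspore import Tensor, ops

    conv3d = ops.Conv3D(out_channel=8, kernel_size=(4, 6, 2), mode=1, pad_mode="pad",
                        pad=0, stride=(5, 4, 9), dilation=1, group=1, data_format="NCDHW")
    x = Tensor(np.random.randn(32, 8, 32, 32, 32).astype(np.float32))
    w = Tensor(np.random.randn(8, 8, 4, 6, 2).astype(np.float32))
    out = conv3d(x, w)  # NCDHW output, here (32, 8, 6, 7, 4)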

Attachment
黎冠新 2023-04-01 19:17

Comments (15)

hedongdong created the Requirement
hedongdong added the sig/ops label
hedongdong added collaborator panzhihui

Please assign maintainer to check this issue.
@hedongdong

Please add labels (comp or sig); you can visit https://gitee.com/mindspore/community/blob/master/sigs/dx/docs/labels.md to find more.
To get your code reviewed quickly, please tag the Pull Request with a component (comp) or special interest group (sig) label; a labeled PR is routed directly to its owner for review.
Taking a component code submission as an example, if you are submitting code for the data component, you can comment:
//comp/data
You can also invite the data SIG to review the code by writing:
//sig/data
You can additionally mark the type of this PR, e.g. bugfix or feature request:
//kind/bug or //kind/feature
Congratulations, you now know how to apply labels with commands. Go ahead and add them in a comment below!

panzhihui edited the description
hedongdong added collaborator hedongdong
hedongdong changed the assignee from hedongdong to 黎冠新

Data types supported by the benchmark operator: half, float32, float64
Data types to be added: float32, float64

panzhihui edited the description
panzhihui removed collaborator panzhihui

import numpy as np
import torch
import torch.nn.functional as torch_f
from mindspore import Tensor

def __init__(self, input_shape, out_channel, kernel_size, mode=1,
             pad_mode='pad', pad=0, stride=1, dilation=1,
             group=1, dtype=np.float16, weight_dtype=np.float16,
             data_format='NCDHW'):
    super().__init__(dtype=dtype)
    self.in_n, self.in_c, self.in_d, self.in_h, self.in_w = input_shape
    self.kernel_d, self.kernel_h, self.kernel_w = kernel_size
    self.out_channel = out_channel
    self.kernel_size = kernel_size
    self.mode = mode
    self.pad_mode = pad_mode
    self.pad1 = pad
    self.pad = pad
    self.stride = stride
    self.dilation = dilation
    self.group = group
    self.data_format = data_format
    self.output_grad_np = None

    self.input_np = np.random.randn(*input_shape).astype(np.float16).astype(dtype)
    weight_c_in = int(self.in_c / self.group)
    self.weight = np.random.randn(out_channel, weight_c_in, *kernel_size).astype(
        np.float16).astype(weight_dtype)

    # Normalize dilation to a (D, H, W) tuple for PyTorch.
    self.dilation_torch = self.dilation
    if isinstance(self.dilation, int):
        self.dilation_torch = (self.dilation, self.dilation, self.dilation)

    # Normalize padding for PyTorch. A 6-element MindSpore pad
    # (head, tail, top, bottom, left, right) is reordered to torch_f.pad's
    # last-dimension-first convention (W, H, D).
    if isinstance(self.pad, int):
        self.padding_torch = [self.pad, self.pad, self.pad]
    elif isinstance(self.pad, tuple) and len(self.pad) == 3:
        self.padding_torch = self.pad1
    else:
        self.padding_torch = self.pad1[4:] + self.pad1[2:4] + self.pad1[:2]

    # MindSpore expects a 6-element pad; expand a symmetric 3-tuple.
    if isinstance(self.pad, tuple) and len(self.pad) == 3:
        self.pad = (
            self.pad[0], self.pad[0], self.pad[1], self.pad[1], self.pad[2], self.pad[2])
    else:
        self.pad = pad

    tmp_stride = self.stride
    if isinstance(self.stride, int):
        tmp_stride = (stride, stride, stride)

    if pad_mode == 'same':
        # TensorFlow-style 'same' padding: per axis, the total padding is
        # max(dilation * (kernel - 1) + 1 - stride, 0) when the input size is
        # divisible by the stride, otherwise the stride is replaced by the
        # remainder in_size % stride.
        if (self.in_d % tmp_stride[0]) == 0:
            pad_along_d = max(self.dilation_torch[0] * (self.kernel_d - 1) + 1 - tmp_stride[0],
                              0)
        else:
            pad_along_d = max(
                self.dilation_torch[0] * (self.kernel_d - 1) + 1 - (self.in_d % tmp_stride[0]),
                0)
        if self.in_h % tmp_stride[1] == 0:
            pad_along_height = max(
                self.dilation_torch[1] * (self.kernel_h - 1) + 1 - tmp_stride[1], 0)
        else:
            pad_along_height = max(
                self.dilation_torch[1] * (self.kernel_h - 1) + 1 - (self.in_h % tmp_stride[1]),
                0)
        if self.in_w % tmp_stride[2] == 0:
            pad_along_width = max(
                self.dilation_torch[2] * (self.kernel_w - 1) + 1 - tmp_stride[2], 0)
        else:
            pad_along_width = max(
                self.dilation_torch[2] * (self.kernel_w - 1) + 1 - (self.in_w % tmp_stride[2]),
                0)
        pad_top = pad_along_height // 2
        pad_bottom = pad_along_height - pad_top
        pad_left = pad_along_width // 2
        pad_right = pad_along_width - pad_left
        pad_head = pad_along_d // 2
        pad_tail = pad_along_d - pad_head
        # Conv3d's symmetric padding argument is ordered (D, H, W); keep the
        # larger halves when the total padding is odd.
        self.padding_torch = [pad_tail, pad_bottom, pad_right]
def forward_pytorch_impl(self):
    inputa = torch.from_numpy(self.input_np.astype(np.float32))
    weight = torch.from_numpy(self.weight.astype(np.float32))

    if isinstance(self.pad, tuple) and len(self.pad1) == 6:
        # Asymmetric 6-element padding: pad the input explicitly, then
        # convolve with padding=0.
        torch_input = torch_f.pad(inputa, self.padding_torch)
        net = torch.nn.Conv3d(in_channels=self.in_c, out_channels=self.out_channel,
                              kernel_size=self.kernel_size,
                              stride=self.stride, padding=0,
                              dilation=self.dilation_torch, groups=self.group, bias=False)
        net.register_parameter('weight', torch.nn.Parameter(weight))
        output = net(torch_input)
    else:
        # Symmetric padding can be passed straight to Conv3d.
        net = torch.nn.Conv3d(in_channels=self.in_c, out_channels=self.out_channel, bias=False,
                              kernel_size=self.kernel_size, stride=self.stride,
                              padding=self.padding_torch,
                              dilation=self.dilation_torch, groups=self.group)
        net.register_parameter('weight', torch.nn.Parameter(weight))
        output = net(inputa)
    return output.detach().numpy().astype(self.dtype)
def forward_mindspore_impl(self):
    input_x = Tensor(self.input_np)
    weight = Tensor(self.weight)
    net = Conv3dnet(out_channel=self.out_channel, kernel_size=self.kernel_size, mode=self.mode,
                    pad_mode=self.pad_mode,
                    pad=self.pad, stride=self.stride,
                    dilation=self.dilation, group=self.group, data_format=self.data_format)

    net.set_train()
    out = net(input_x, weight)
    return out.asnumpy()
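As a sanity check on the 'same' branch in __init__ above, the per-axis padding arithmetic reduces to a small helper; a standalone sketch (function name hypothetical):

    # Hypothetical helper mirroring the 'same' padding arithmetic above.
    def same_pad_along(in_size, kernel, stride, dilation):
        effective_kernel = dilation * (kernel - 1) + 1
        if in_size % stride == 0:
            return max(effective_kernel - stride, 0)
        return max(effective_kernel - in_size % stride, 0)

    # Example: depth axis with in_d=32, kernel_d=4, stride=5, dilation=1:
    # effective kernel is 4, 32 % 5 = 2, so total pad = max(4 - 2, 0) = 2.
    assert same_pad_along(32, 4, 5, 1) == 2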
黎冠新 uploaded attachment Conv3D交付件.zip
hedongdong edited the title
hedongdong edited the description

Continuing with more details:

    def forward_cmp(self):
        out_pytorch = self.forward_pytorch_impl()
        out_mindspore = self.forward_mindspore_impl()
        if self.dtype == np.float16:
            allclose_nparray(out_pytorch, out_mindspore, self.loss, self.loss)
        else:
            allclose_nparray(out_pytorch, out_mindspore, self.loss * 10, self.loss * 10)
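allclose_nparray is a shared test helper that is not pasted in this thread; a minimal stand-in with the same call shape would be:

    # Minimal stand-in for the shared helper; the real utility may also report
    # the count or ratio of out-of-tolerance elements.
    def allclose_nparray(data_expected, data_actual, rtol, atol):
        if not np.allclose(data_actual, data_expected, rtol=rtol, atol=atol):
            raise AssertionError("outputs differ beyond rtol=%s atol=%s" % (rtol, atol))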

Verifying the accuracy issue again:
[screenshots of the test runs]
After changing the corresponding tests to this style and running each of the two cases 20 times, no accuracy problem was found.
For this operator, I still suggest verifying and submitting according to the samples provided under the issue link. @黎冠新

MindSpore version of the backward pass:

    def grad_mindspore_impl(self):
        input_x = Tensor(self.input_np)
        weight = Tensor(self.weight)
        self.output_grad_np = np.random.randn(
            *self.forward_mindspore_impl().shape).astype(np.float16).astype(self.dtype)
        output_grad = Tensor(self.output_grad_np.astype(self.dtype))
        net = Conv3dnet(out_channel=self.out_channel, kernel_size=self.kernel_size, mode=self.mode,
                        pad_mode=self.pad_mode,
                        pad=self.pad, stride=self.stride,
                        dilation=self.dilation, group=self.group, data_format=self.data_format)
        grad_net = GradOfAllInputsAndParams(net)
        grad_net.set_train()
        out_grad = grad_net(input_x, weight, output_grad)
        return out_grad[0][0].asnumpy(), out_grad[0][1].asnumpy()
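GradOfAllInputsAndParams is the usual MindSpore test wrapper around GradOperation; roughly (a sketch, the shared utility may differ in detail):

    import mindspore.nn as nn
    from mindspore import ParameterTuple
    from mindspore.ops import GradOperation

    class GradOfAllInputsAndParams(nn.Cell):
        # Sketch of the shared test utility: returns gradients w.r.t. all
        # inputs and all trainable parameters, taking sens (the output grad)
        # as an extra call argument.
        def __init__(self, network):
            super().__init__()
            self.grad = GradOperation(get_all=True, get_by_list=True, sens_param=True)
            self.network = network
            self.params = ParameterTuple(network.trainable_params())

        def construct(self, *inputs):
            return self.grad(self.network, self.params)(*inputs)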

PyTorch version of the backward pass:

    def grad_pytorch_impl(self):
        inputa = torch.from_numpy(self.input_np.copy().astype(np.float32))
        weight = torch.nn.Parameter(torch.from_numpy(self.weight.copy().astype(np.float32)))
        if isinstance(self.pad, tuple) and len(self.pad1) == 6:
            inputa.requires_grad = True
            inputa_py = torch_f.pad(inputa, self.padding_torch)
            net = torch.nn.Conv3d(in_channels=self.in_c, out_channels=self.out_channel,
                                  kernel_size=self.kernel_size,
                                  stride=self.stride, padding=0,
                                  dilation=self.dilation_torch, groups=self.group, bias=False)
            net.register_parameter('weight', weight)

            output = net(inputa_py)
            output_grad = torch.from_numpy(self.output_grad_np.copy().astype(np.float32))
            output.backward(gradient=output_grad)
        else:
            net = torch.nn.Conv3d(in_channels=self.in_c, out_channels=self.out_channel,
                                  kernel_size=self.kernel_size, stride=self.stride,
                                  padding=self.padding_torch,
                                  dilation=self.dilation_torch, groups=self.group, bias=False)
            net.register_parameter('weight', weight)

            inputa.requires_grad = True
            output = net(inputa)
            output_grad = torch.from_numpy(
                self.output_grad_np.copy().astype(np.float16).astype(np.float32))
            output.backward(gradient=output_grad)
        return inputa.grad.numpy().astype(self.dtype), weight.grad.numpy().astype(self.dtype)
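grad_cmp itself is not pasted in this thread; given the two implementations above, it is presumably along these lines (MindSpore must run first, since it generates output_grad_np):

    def grad_cmp(self):
        # Hedged sketch: grad_mindspore_impl creates output_grad_np, which the
        # PyTorch implementation then reuses, so order matters here.
        input_grad_ms, weight_grad_ms = self.grad_mindspore_impl()
        input_grad_pt, weight_grad_pt = self.grad_pytorch_impl()
        allclose_nparray(input_grad_pt, input_grad_ms, self.loss, self.loss)
        allclose_nparray(weight_grad_pt, weight_grad_ms, self.loss, self.loss)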
黎冠新 deleted attachment Conv3D交付件.zip
黎冠新 uploaded attachment Conv3D交付件.zip

pytest operations/test_conv3d.py --count=50
[screenshots of the repeated test runs]

Issues in the latest code:
1. On GPU, the float64 cases still occasionally produce abnormally large values when executed repeatedly.
To reproduce:
pip install pytest-repeat
pytest test_xxx.py::test_case_name --count=10

2. After the code update, float64 and float32 show accuracy problems on GPU.

For the abnormal results on repeated runs, print the values of input_size_list_, output_size_list_, workspace_size_list_ and the other member variables in Resize, and check whether they stay the same across repeated runs.
conv3d_grad_filter_gpu_kernel.h

def test_conv3d_float64_1():
    x_shape = [32, 8, 32, 32, 32]
    fil_shape = [8, 8, 4, 6, 2]
    dout_shape = [32, 8, 6, 7, 4]
    input_x = Tensor(np.random.randn(*x_shape)).astype(np.float64)
    input_fil = Tensor(np.random.randn(*fil_shape)).astype(np.float64)
    input_dout = Tensor(np.random.randn(*dout_shape)).astype(np.float64)
    attributes = {"out_channel": 8, "kernel_size": (4, 6, 2), "mode": 1, "pad_mode": "pad", "pad": 0,
                  "stride": (5,4,9), "dilation": 1, "group": 1, "data_format": "NCDHW"}
    fact = Conv3DMock(attributes=attributes, inputs=[input_x, input_fil, input_dout])
    fact.forward_cmp()
    fact.grad_cmp()

def test_conv3d_float64_2():
    x_shape = [32, 8, 32, 32, 32]
    fil_shape = [8, 8, 4, 6, 2]
    dout_shape = [32, 8, 7, 4, 4]
    input_x = Tensor(np.random.randn(*x_shape)).astype(np.float64)
    input_fil = Tensor(np.random.randn(*fil_shape)).astype(np.float64)
    input_dout = Tensor(np.random.randn(*dout_shape)).astype(np.float64)
    attributes = {"out_channel": 8, "kernel_size": (4, 6, 2), "mode": 1, "pad_mode": "pad", "pad": 2,
                  "stride": (5,4,9), "dilation": (1,4,3), "group": 1, "data_format": "NCDHW"}
    fact = Conv3DMock(attributes=attributes, inputs=[input_x, input_fil, input_dout])
    fact.forward_cmp()
    fact.grad_cmp()

def test_conv3d_float64_3():
    x_shape = [32, 8, 32, 32, 32]
    fil_shape = [8, 8, 4, 5, 4]
    dout_shape = [32, 8, 35, 22, 11]
    input_x = Tensor(np.random.randn(*x_shape)).astype(np.float64)
    input_fil = Tensor(np.random.randn(*fil_shape)).astype(np.float64)
    input_dout = Tensor(np.random.randn(*dout_shape)).astype(np.float64)
    attributes = {"out_channel": 8, "kernel_size": (4, 5, 4), "mode": 1, "pad_mode": "pad", "pad": 3,
                  "stride": 1, "dilation": (1,4,9), "group": 1, "data_format": "NCDHW"}
    fact = Conv3DMock(attributes=attributes, inputs=[input_x, input_fil, input_dout])
    fact.forward_cmp()
    fact.grad_cmp()

def test_conv3d_float64_4():
    x_shape = [16, 3, 20, 64, 32]
    fil_shape = [6, 3, 4, 6, 4]
    dout_shape = [16, 6, 6, 12, 8]
    input_x = Tensor(np.random.randn(*x_shape)).astype(np.float64)
    input_fil = Tensor(np.random.randn(*fil_shape)).astype(np.float64)
    input_dout = Tensor(np.random.randn(*dout_shape)).astype(np.float64)
    attributes = {"out_channel": 6, "kernel_size": (4, 6, 4), "mode": 1, "pad_mode": "pad", "pad": 0,
                  "stride": (3,5,4), "dilation": 1, "group": 1, "data_format": "NCDHW"}
    fact = Conv3DMock(attributes=attributes, inputs=[input_x, input_fil, input_dout])
    fact.forward_cmp()
    fact.grad_cmp()
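The dout_shape values in these cases follow the standard convolution output-size formula; a quick check for test_conv3d_float64_1 (helper name hypothetical):

    # Hypothetical helper: standard conv output size per axis,
    # out = floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    def conv3d_out_dim(in_size, kernel, stride, pad=0, dilation=1):
        return (in_size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

    # test_conv3d_float64_1: x (32, 8, 32, 32, 32), kernel (4, 6, 2),
    # stride (5, 4, 9), pad 0 -> dout (32, 8, 6, 7, 4)
    assert conv3d_out_dim(32, 4, 5) == 6
    assert conv3d_out_dim(32, 6, 4) == 7
    assert conv3d_out_dim(32, 2, 9) == 4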
黎冠新 deleted attachment Conv3D交付件.zip
黎冠新 uploaded attachment Conv3D交付件.zip
黎冠新 deleted attachment Conv3D交付件.zip
黎冠新 uploaded attachment Conv3D交付件.zip
