
MindSpore / mindspore


[CT][MS][SoftMarginLoss] Intermittent float16 precision failures on GPU in both execution modes

DONE
Bug-Report
Created on 2022-08-08 09:32

Describe the current behavior (Mandatory)

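import numpy as np
import torch
from mindspore import Tensor
from mindspore.nn import SoftMarginLoss  # assumed: the harness may import a different SoftMarginLoss
# OpsFactory, dtype_to_nptype, GradOfFirstInput and allclose_nparray are helpers from the
# test suite's shared package (../share/...); their exact import paths are not shown in this report.

class SoftMarginLossMock(OpsFactory):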
    def __init__(self, attributes={'reduction': 'mean'}, inputs=None, grads=None):
        self.ms_type = inputs[0].dtype
        super().__init__(dtype=dtype_to_nptype(self.ms_type))

        self.reduction = attributes.get('reduction', 'mean')
        self.logits = inputs[0]
        self.logits_np = inputs[0].asnumpy()
        self.labels = inputs[1]
        self.labels_np = inputs[1].asnumpy()

        if grads is None:
            self.out_grad_np = None
        elif self.reduction == "sum" or self.reduction == "mean":
            self.out_grad_np = np.array(grads)
        elif self.reduction == "none":
            self.out_grad_np = grads

    def forward_mindspore_impl(self):
        net = SoftMarginLoss(reduction=self.reduction)
        out = net(self.logits, self.labels)
        return out.asnumpy()

    def forward_pytorch_impl(self):
        net = torch.nn.SoftMarginLoss(reduction=self.reduction)
        logits = torch.from_numpy(self.logits_np.astype(np.float32))
        labels = torch.from_numpy(self.labels_np.astype(np.float32))
        output = net(logits, labels)
        return output.detach().numpy().astype(self.dtype)

    def forward_cmp(self):
        out_pytorch = self.forward_pytorch_impl()
        out_mindspore = self.forward_mindspore_impl()
        allclose_nparray(out_pytorch, out_mindspore, self.loss, self.loss)

    def grad_mindspore_impl(self):
        if self.out_grad_np is None:
            self.out_grad_np = self.forward_pytorch_impl()
        net = SoftMarginLoss(reduction=self.reduction)
        grad_net = GradOfFirstInput(net)
        grad_net.set_train()
        out_grad = grad_net(self.logits, self.labels, Tensor(self.out_grad_np))
        return out_grad.asnumpy()

    def grad_pytorch_impl(self):
        if self.out_grad_np is None:
            self.out_grad_np = self.forward_pytorch_impl()
        net = torch.nn.SoftMarginLoss(reduction=self.reduction)
        logits = torch.from_numpy(self.logits_np.astype(np.float32))
        labels = torch.from_numpy(self.labels_np.astype(np.float32))
        logits.requires_grad = True
        output = net(logits, labels)
        output_grad = torch.from_numpy(self.out_grad_np.copy().astype(np.float32))
        output.backward(gradient=output_grad)
        return logits.grad.numpy().astype(self.dtype)

    def grad_cmp(self):
        input_grad_mindspore = self.grad_mindspore_impl()
        input_grad_pytorch = self.grad_pytorch_impl()
        allclose_nparray(input_grad_pytorch, input_grad_mindspore, self.loss, self.loss)
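When two frameworks disagree in float16, a third reference computed in float64 can show which side drifts. The sketch below is not part of the test suite; the function names are hypothetical. It assumes the PyTorch definition of SoftMarginLoss, log(1 + exp(-y * x)) per element, followed by a sum or mean. The gradient uses d/dx log(1 + exp(-y*x)) = -y / (1 + exp(y*x)), multiplied by the upstream gradient that the harness passes in.

```python
import numpy as np


def soft_margin_loss_ref(x, y, reduction="mean"):
    """Hypothetical float64 reference: log(1 + exp(-y * x)), then reduced."""
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    # np.logaddexp(0, z) computes log(1 + exp(z)) without overflowing for large z.
    elementwise = np.logaddexp(0.0, -y * x)
    if reduction == "sum":
        return elementwise.sum()
    if reduction == "mean":
        return elementwise.mean()
    return elementwise  # reduction == "none"


def soft_margin_loss_grad_ref(x, y, upstream, reduction="mean"):
    """Hypothetical float64 gradient of the reduced loss w.r.t. x, times the upstream gradient."""
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    # Local derivative: -y * sigmoid(-y * x) == -y / (1 + exp(y * x)).
    local = -y / (1.0 + np.exp(y * x))
    if reduction == "mean":
        # The mean divides the summed loss by the element count, so does its gradient.
        local = local / x.size
    return np.asarray(upstream, dtype=np.float64) * local
```

Comparing both `input_grad_pytorch` and `input_grad_mindspore` against such a reference (after upcasting the float16 inputs) separates framework error from float16 rounding of the expected values themselves.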

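1.
On GPU, in both PyNative and Graph modes, the test intermittently fails the precision check when the input dtype is float16
(the test case was run ten times: 3 failures, 7 passes).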

collected 10 items

test_softmarginloss.py::test_softmarginloss_input_2d_dtype_float16[1-10] PASSED                                                                     [ 10%]
test_softmarginloss.py::test_softmarginloss_input_2d_dtype_float16[2-10] PASSED                                                                     [ 20%]
test_softmarginloss.py::test_softmarginloss_input_2d_dtype_float16[3-10] PASSED                                                                     [ 30%]
test_softmarginloss.py::test_softmarginloss_input_2d_dtype_float16[4-10] PASSED                                                                     [ 40%]
test_softmarginloss.py::test_softmarginloss_input_2d_dtype_float16[5-10] FAILED                                                                     [ 50%]
test_softmarginloss.py::test_softmarginloss_input_2d_dtype_float16[6-10] PASSED                                                                     [ 60%]
test_softmarginloss.py::test_softmarginloss_input_2d_dtype_float16[7-10] FAILED                                                                     [ 70%]
test_softmarginloss.py::test_softmarginloss_input_2d_dtype_float16[8-10] FAILED                                                                     [ 80%]
test_softmarginloss.py::test_softmarginloss_input_2d_dtype_float16[9-10] PASSED                                                                     [ 90%]
test_softmarginloss.py::test_softmarginloss_input_2d_dtype_float16[10-10] PASSED                                                                    [100%]

======================================================================== FAILURES =========================================================================
____________________________________________________ test_softmarginloss_input_2d_dtype_float16[5-10] _____________________________________________________

    def test_softmarginloss_input_2d_dtype_float16():
        input_list = []
        logits = Tensor(np.random.randn(6, 8), dtype=mstype.float16)
        labels = Tensor(np.random.randn(6, 8), dtype=mstype.float16)
        input_list.append(logits)
        input_list.append(labels)
        fact = SoftMarginLossMock(attributes={'reduction': 'sum'}, inputs=input_list)
        fact.loss = 1e-3
        fact.forward_cmp()
>       fact.grad_cmp()

test_softmarginloss.py:342:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../share/ops/nn/softmarginloss_ops.py:78: in grad_cmp
    allclose_nparray(input_grad_pytorch, input_grad_mindspore, self.loss, self.loss)
../share/utils.py:31: in allclose_nparray
    _count_unequal_element(data_expected, data_me, rtol, atol)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

data_expected = array([[ -0.8887,   6.652 , -15.14  , -20.62  ,   3.377 ,  13.02  ,
         40.5   , -26.78  ],
       [ 32.34  ,   1...996 ],
       [  8.805 , -11.79  ,  46.84  ,   4.184 ,   3.857 , -10.2   ,
         -8.9   , -55.94  ]], dtype=float16)
data_me = array([[ -0.889 ,   6.652 , -15.15  , -20.64  ,   3.377 ,  13.016 ,
         40.5   , -26.75  ],
       [ 32.34  ,   1...    ],
       [  8.805 , -11.79  ,  46.84  ,   4.184 ,   3.854 , -10.21  ,
         -8.9   , -55.9   ]], dtype=float16)
rtol = 0.001, atol = 0.001

    def _count_unequal_element(data_expected, data_me, rtol, atol):
        assert data_expected.shape == data_me.shape
        total_count = len(data_expected.flatten())
        error = np.abs(data_expected - data_me)
        greater = np.greater(error, atol + np.abs(data_me) * rtol)
        loss_count = np.count_nonzero(greater)
        assert (loss_count / total_count) < rtol, \
            "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
>               format(data_expected[greater], data_me[greater], error[greater])
E       AssertionError:
E       data_expected_std:[-26.78]
E       data_me_error:[-26.75]
E       loss:[0.03125]
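The reported error of 0.03125 is exactly two float16 steps at this magnitude. With `reduction='sum'` and no explicit `grads`, the harness feeds the forward loss value back as the upstream gradient, which scales every gradient element into the tens (for example -55.94 above), where float16 resolution is coarse. A quick numpy check of the arithmetic, using the values from the log; the variable names are illustrative only:

```python
import numpy as np

# Values of the failing element, copied from the log above.
expected = np.float16(-26.78)  # PyTorch reference (stored as -26.78125)
actual = np.float16(-26.75)    # MindSpore result
rtol = atol = 1e-3

# Gap between adjacent float16 values: 2**-6 in [16, 32), 2**-5 in [32, 64).
ulp_at_26 = float(np.spacing(np.float16(26.75)))
ulp_at_55 = float(np.spacing(np.float16(55.94)))

# Same per-element rule as _count_unequal_element.
error = abs(float(expected) - float(actual))
threshold = atol + abs(float(actual)) * rtol

# The test passes only if mismatched / total < rtol; with 6 * 8 = 48 elements,
# a single mismatch gives 1/48, which is far above 1e-3.
single_mismatch_ok = (1 / 48) < rtol

print(ulp_at_26, ulp_at_55, error, threshold, single_mismatch_ok)
```

So a 2-ulp difference on one element out of 48 is enough to fail, even though both results are reasonable float16 roundings of the same value.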

2.
test_softmarginloss_input_5d_dtype_float16

    def test_softmarginloss_input_5d_dtype_float16():
        input_list = []
        logits = Tensor(np.random.randn(64, 128, 5, 4, 2), dtype=mstype.float16)
        labels = Tensor(np.random.randn(64, 128, 5, 4, 2), dtype=mstype.float16)
        input_list.append(logits)
        input_list.append(labels)
        fact = SoftMarginLossMock(inputs=input_list)
        fact.loss = 1e-3
        fact.forward_cmp()
>       fact.grad_cmp()

test_softmarginloss.py:421:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../share/ops/nn/softmarginloss_ops.py:78: in grad_cmp
    allclose_nparray(input_grad_pytorch, input_grad_mindspore, self.loss, self.loss)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

data_expected = array([[[[[-2.38e-07,  4.77e-07],
          [ 5.96e-07,  1.67e-06],
          [ 1.19e-06, -4.17e-07],
          [-5.96...         [ 1.97e-06,  7.75e-07],
          [-1.37e-06,  5.96e-08],
          [ 2.38e-07, -5.36e-07]]]]], dtype=float16)
data_me = array([[[[[-0.,  0.],
          [ 0.,  0.],
          [ 0., -0.],
          [-0., -0.]],

         [[-0., -0.],
      ... 0.,  0.]],

         [[-0., -0.],
          [ 0.,  0.],
          [-0.,  0.],
          [ 0., -0.]]]]], dtype=float16)
rtol = 0.001, atol = 0.001, equal_nan = True

    def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
        if np.any(np.isnan(data_expected)) or np.any(np.isnan(data_me)):
>           assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
E           AssertionError
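Two observations follow from this log. The reference gradients are roughly 1e-7 to 1e-6 because the mean reduction divides by 64 * 128 * 5 * 4 * 2 = 327,680 elements, which puts them near the float16 subnormal range (the smallest positive float16 is 2**-24, about 5.96e-08, and it appears in the output). Zeros from MindSpore would still be within `atol=1e-3` of them. However, the assertion that fired is the one inside the `np.isnan` branch of `allclose_nparray`, so at least one of the two arrays contained NaN; with `equal_nan=True`, `np.allclose` only passes when NaNs sit at the same positions in both arrays. A small numpy illustration with made-up arrays:

```python
import numpy as np

rtol = atol = 1e-3

# Subnormal-sized reference gradients vs. zeros: well inside atol, so magnitude alone passes.
subnormal_ref = np.array([-2.38e-07, 4.77e-07, 5.96e-08], dtype=np.float16)
zeros = np.zeros(3, dtype=np.float16)
print(np.allclose(subnormal_ref, zeros, rtol, atol, equal_nan=True))  # True

# equal_nan=True only forgives NaNs that appear at the same position in both arrays.
nan_a = np.array([1.0, np.nan], dtype=np.float16)
nan_b = np.array([1.0, 0.0], dtype=np.float16)
print(np.allclose(nan_a, nan_a, rtol, atol, equal_nan=True))  # True
print(np.allclose(nan_a, nan_b, rtol, atol, equal_nan=True))  # False

# Values below half of 2**-24 round to zero when cast to float16.
underflows = np.float16(1e-8) == 0
```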

Environment (Mandatory)

  • Hardware Environment (Ascend/GPU/CPU):

Please delete the backend not involved:
/device ascend/GPU/CPU/kirin/other chips

  • Software Environment (Mandatory):
    -- MindSpore version (e.g., 1.7.0.Bxxx) :
    -- Python version (e.g., Python 3.7.5) :
    -- OS platform and distribution (e.g., Linux Ubuntu 16.04):
    -- GCC/Compiler version (if compiled from source):

  • Execute Mode (Mandatory) (PyNative/Graph):

Please delete the mode not involved:
/mode pynative
/mode graph

Related testcase (Mandatory)

Steps to reproduce the issue (Mandatory)

Describe the expected behavior (Mandatory)

Related log / screenshot (Mandatory)

Special notes for this issue (Optional)

Comments (3)

xuebao_zhang created Bug-Report
xuebao_zhang added the kind/bug label

Please assign a maintainer to check this issue.
@xuebao_zhang

Please add labels (comp or sig); you can visit https://gitee.com/mindspore/community/blob/master/sigs/dx/docs/labels.md to find more.
To get your code reviewed as soon as possible, please tag your Pull Request with a component (comp) or special interest group (sig) label; labeled PRs are pushed directly to the responsible reviewers.
Taking a component submission as an example, if you submitted code for the data component, you can comment:
//comp/data
You can also invite the data SIG to review the code by writing:
//sig/data
You can also mark the PR type, for example bugfix or feature request:
//kind/bug or //kind/feature
You now know how to add labels with commands. Go ahead and add labels in a comment below!

xuebao_zhang edited the description

This is a test-case problem. Change the ops file as follows:
if self.out_grad_np is None:
    self.out_grad_np = np.array(
        np.random.randn(*list(self.forward_pytorch_impl().shape))).astype(self.dtype)
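The effect of this change on the 2-D `reduction='sum'` case can be sketched in plain numpy: the old default used the summed loss (tens, for 48 elements) as the upstream gradient, while the fix draws it from N(0, 1). The snippet is an illustration under that assumption, not the harness itself; it uses the local derivative -y / (1 + exp(y*x)) of log(1 + exp(-y*x)), and all names are hypothetical. It does not address the NaN seen in the 5-D `mean` case.

```python
import numpy as np

rng = np.random.default_rng(0)
# Mimic float16 test inputs, then compute in float32 like the PyTorch reference.
x = rng.standard_normal((6, 8)).astype(np.float16).astype(np.float32)
y = rng.standard_normal((6, 8)).astype(np.float16).astype(np.float32)

# Per-element gradient of log(1 + exp(-y*x)) with respect to x.
local = -y / (1.0 + np.exp(y * x))

# Old default for reduction='sum': upstream gradient = the loss value itself.
loss_sum = np.logaddexp(0.0, -y * x).sum()
old_grad = (loss_sum * local).astype(np.float16)

# Fixed default: upstream gradient drawn from N(0, 1) with the output's (scalar) shape.
new_upstream = np.array(rng.standard_normal()).astype(np.float16)
new_grad = (float(new_upstream) * local).astype(np.float16)

# Larger gradient magnitudes mean coarser float16 spacing, hence larger absolute errors.
print(np.abs(old_grad).max(), np.spacing(np.abs(old_grad).max()))
print(np.abs(new_grad).max(), np.spacing(np.abs(new_grad).max()))
```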

xuebao_zhang changed the task status from TODO to DONE

