From 5db3a51fe925ecb482e936a52cadf5a97d856a0e Mon Sep 17 00:00:00 2001 From: fangruohuan Date: Sat, 14 Dec 2024 12:19:26 +0800 Subject: [PATCH] add ut data cache --- ci/access_control_test.py | 6 +- tests/torch/data_cache.py | 171 ++++++++++++++++++ tests/torch/test_add_relu.py | 52 +++--- tests/torch/test_assign_score_withk.py | 6 +- tests/torch/test_bev_pool.py | 4 + tests/torch/test_bev_pool_v2.py | 4 + tests/torch/test_bev_pool_v3.py | 20 +- tests/torch/test_border_align.py | 17 +- tests/torch/test_box_iou.py | 7 +- tests/torch/test_boxes_overlap_bev.py | 58 +++--- tests/torch/test_cal_anchors_heading.py | 6 +- tests/torch/test_deformable_aggregation.py | 38 ++-- .../torch/test_deformable_aggregation_grad.py | 73 ++++---- .../test_furthest_point_sample_with_dist.py | 10 +- tests/torch/test_furthest_point_sampling.py | 8 +- tests/torch/test_fused_bias_leaky_relu.py | 76 ++++---- .../torch/test_geometric_kernel_attention.py | 65 +++++-- tests/torch/test_group_points.py | 22 ++- tests/torch/test_group_points_grad.py | 20 +- tests/torch/test_hard_voxelize.py | 7 +- tests/torch/test_hypot.py | 88 ++++----- tests/torch/test_knn.py | 17 +- .../torch/test_multi_scale_deformable_attn.py | 47 +++-- tests/torch/test_npu_dyn_voxelization.py | 10 +- tests/torch/test_npu_dynamic_scatter.py | 5 +- tests/torch/test_npu_max_pool2d.py | 4 + tests/torch/test_npu_nms3d.py | 3 +- tests/torch/test_npu_nms3d_normal.py | 7 +- tests/torch/test_pixel_group.py | 12 +- tests/torch/test_point_to_voxel.py | 5 + tests/torch/test_points_in_box.py | 38 ++-- tests/torch/test_points_in_box_all.py | 48 +++-- tests/torch/test_roi_align_rotated.py | 14 +- tests/torch/test_roiaware_pool3d.py | 12 +- tests/torch/test_roiaware_pool3d_grad.py | 11 +- tests/torch/test_roipoint_pool3d.py | 4 + tests/torch/test_rotated_iou.py | 17 +- tests/torch/test_scatter_max.py | 9 +- tests/torch/test_scatter_mean.py | 40 ++-- tests/torch/test_three_interpolate.py | 7 +- tests/torch/test_three_nn.py | 21 ++- tests/torch/test_unique_voxel.py | 7 + tests/torch/test_vec_pool_backward.py | 36 ++-- tests/torch/test_voxel_pooling_train.py | 5 + tests/torch/test_voxel_to_point.py | 5 + 45 files changed, 791 insertions(+), 351 deletions(-) create mode 100644 tests/torch/data_cache.py diff --git a/ci/access_control_test.py b/ci/access_control_test.py index dc341e47..20560149 100644 --- a/ci/access_control_test.py +++ b/ci/access_control_test.py @@ -24,6 +24,7 @@ import warnings from abc import ABCMeta, abstractmethod from pathlib import Path +NUM_DEVICE = 8 BASE_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) # project root TEST_DIR = os.path.join(BASE_DIR, "tests", "torch") @@ -195,9 +196,8 @@ def exec_ut(files): stdout_queue = queue.Queue() event_timer = threading.Event() - p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, stdout=subprocess.PIPE) + p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.STDOUT) start_thread(wait_thread, p, event_timer) - start_thread(enqueue_output, p.stdout, stdout_queue) try: event_timer.wait(2000) @@ -220,6 +220,8 @@ def exec_ut(files): has_failed = 0 for ut_type, ut_files in files.items(): for ut_file in ut_files: + if not os.path.basename(ut_file).startswith("test_"): + continue cmd = get_ut_cmd(ut_type, ut_file) ut_info = " ".join(cmd[4:]).replace(" -- -k", "") ret = run_cmd_with_timeout(cmd) diff --git a/tests/torch/data_cache.py b/tests/torch/data_cache.py new file mode 100644 index 00000000..b0ff151f --- /dev/null +++ b/tests/torch/data_cache.py @@ -0,0 +1,171 @@ +import hashlib +import json +import os +import pickle +import traceback +import inspect +from functools import wraps + +import numpy as np +import torch + + +def serialize_param(param): + # 将输入参数转字符串, 需要考虑列表或者字典的情况 + if isinstance(param, (list, tuple)): + return ''.join(serialize_param(item) for item in param) + elif isinstance(param, dict): + items = sorted(param.items()) + return ''.join(serialize_param(v) for k, v in items) + else: + return hash_object(param) + "_" + + +def hash_object(obj): + # 将对象转换为字符串 + try: + obj_bytes = pickle.dumps(obj) + hasher = hashlib.sha256() + hasher.update(obj_bytes) + hex_digest = hasher.hexdigest()[:10] + except Exception as e: + # 有些对象没法被sha256哈希化,暂时跳过 + hex_digest = "" + + return hex_digest + + +def save_data(data, save_path, case_name): + file_names = [] + # 考虑有多个返回值的情况 + if isinstance(data, tuple) or isinstance(data, list): + for i, result in enumerate(data): + if isinstance(result, np.ndarray): + filename_i = case_name + str(i) + ".npy" + file_save_path = os.path.join(save_path, filename_i) + np.save(file_save_path, result) + elif isinstance(result, torch.Tensor): + filename_i = case_name + str(i) + ".pth" + file_save_path = os.path.join(save_path, filename_i) + torch.save(result, file_save_path) + elif isinstance(result, (int, float, list, tuple)): + filename_i = case_name + str(i) + ".json" + file_save_path = os.path.join(save_path, filename_i) + result_ = {} + result_["result"] = result + with open(file_save_path, 'w') as json_file: + json.dump(result_, json_file) + else: + raise ValueError(f"Save cache data failed, return data type should be np.ndarray, torch.Tensor, int or float, but got {type(result)}") + + file_names.append(filename_i) + + elif isinstance(data, np.ndarray): + filename = case_name + ".npy" + file_save_path = os.path.join(save_path, filename) + np.save(file_save_path, data) + file_names.append(filename) + elif isinstance(data, torch.Tensor): + filename = case_name + ".pth" + file_save_path = os.path.join(save_path, filename) + torch.save(data, file_save_path) + file_names.append(filename) + elif isinstance(data, (int, float, list, tuple)): + filename = case_name + ".json" + file_save_path = os.path.join(save_path, filename) + result_ = {} + result_["result"] = data + with open(file_save_path, 'w') as json_file: + json.dump(result_, json_file) + file_names.append(filename) + else: + raise ValueError(f"Save cache data failed, return data type should be np.ndarray, torch.Tensor, int or float, but got {type(data)}") + + + return file_names + + +def load_data(save_path, file_names): + results = [] + for file_name in file_names: + file_path = os.path.join(save_path, file_name) + if not os.path.exists(file_path): + raise FileNotFoundError(f"Load cache data failed. {file_path} not exists.") + if file_name.endswith(".npy"): + result = np.load(file_path) + results.append(result) + elif file_name.endswith(".pth"): + result = torch.load(file_path) + results.append(result) + elif file_name.endswith(".json"): + with open(file_path, 'r') as json_file: + result = json.load(json_file)["result"] + results.append(result) + + return results + + +def golden_data_cache(ut_name, save_path=None, refresh_data=False): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + # 基于args和kwargs生成缓存数据的casename, 表示Golden Function输入的参数 + case_name = '' + for arg in args: + case_name += (serialize_param(arg)) + for k, v in kwargs.items(): + case_name += (serialize_param(v)) + + save_path_ = save_path + if save_path_ is None: + if os.getenv('MXDRIVING_CACHE_PATH', None) is not None: + save_path_ = os.getenv('MXDRIVING_CACHE_PATH', None) + else: + current_file_path = os.path.abspath(__file__) + save_path_ = os.path.dirname(current_file_path) + ut_name_ = os.path.basename(ut_name) + ut_name_ = os.path.splitext(ut_name_)[0] + save_path_ = os.path.join(save_path_, "data_cache", ut_name_, func.__name__) + cache_data_path = os.path.join(save_path_, case_name + ".json") + + # 如果路径下没有缓存,则重新生成缓存数据 + if not os.path.exists(cache_data_path) or refresh_data: + if not os.path.exists(save_path_): + os.makedirs(save_path_) + # 用于存储该case下面所有cache data的数据名称 + cache_data_names = {} + + results = func(*args, **kwargs) + + # 保存数据 + try: + file_names = save_data(results, save_path_, case_name) + if len(file_names) > 0: + cache_data_names[case_name] = file_names + with open(cache_data_path, 'w') as f: + json.dump(cache_data_names, f) + print(f"Cache data saved in {save_path_}.") + except Exception as e: + print("Failed to save cache.") + traceback.print_exc() + + else: + with open(cache_data_path, 'r') as file: + cache_data_names = json.load(file) + file_names = cache_data_names[case_name] + # 读取数据 + try: + results = load_data(save_path_, file_names) + if len(results) == 1: + results = results[0] + else: + results = tuple(results) + print(f"Load cache data from {save_path_}.") + except Exception as e: + results = func(*args, **kwargs) + print("Failed to load cache, using golden function to generate data.") + traceback.print_exc() + + return results + return wrapper + return decorator \ No newline at end of file diff --git a/tests/torch/test_add_relu.py b/tests/torch/test_add_relu.py index 6c52d545..69508c2b 100644 --- a/tests/torch/test_add_relu.py +++ b/tests/torch/test_add_relu.py @@ -1,23 +1,38 @@ import unittest -import torch + import numpy as np -import torch_npu +import torch import torch.nn.functional as F - +import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving.fused + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] +@golden_data_cache(__file__) +def gen_inputs(shape, dtype): + x = np.random.uniform(1, 1, shape).astype(dtype) + x = torch.from_numpy(x) + y = np.random.uniform(1, 1, shape).astype(dtype) + y = torch.from_numpy(y) + return x, y + + +@golden_data_cache(__file__) +def gen_cpu_outputs(x, y): + cpu_result = F.relu(x.float() + y.float()) + return cpu_result + + class TestPointsInBox(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `AddRelu` is only supported on 910B, skip this ut!") def test_npu_add_relu_three_dim(self, device="npu"): - x = np.random.uniform(1, 1, [1, 100, 3]).astype(np.float32) - x = torch.from_numpy(x) - y = np.random.uniform(2.0, 2.0, [1, 100, 3]).astype(np.float32) - y = torch.from_numpy(y) - cpu_result = F.relu(x + y) + x, y = gen_inputs([1, 100, 3], np.float32) + cpu_result = gen_cpu_outputs(x, y) result = mx_driving.fused.npu_add_relu(x.npu(), y.npu()).cpu().numpy() self.assertRtolEqual(result, cpu_result.numpy()) result = mx_driving.npu_add_relu(x.npu(), y.npu()).cpu().numpy() @@ -25,11 +40,8 @@ class TestPointsInBox(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `AddRelu` is only supported on 910B, skip this ut!") def test_npu_add_relu_large_number(self, device="npu"): - x = np.random.uniform(1, 1, [18, 256, 232, 400]).astype(np.float32) - x = torch.from_numpy(x) - y = np.random.uniform(2.0, 2.0, [18, 256, 232, 400]).astype(np.float32) - y = torch.from_numpy(y) - cpu_result = F.relu(x + y) + x, y = gen_inputs([18, 256, 232, 100], np.float32) + cpu_result = gen_cpu_outputs(x, y) result = mx_driving.fused.npu_add_relu(x.npu(), y.npu()).cpu().numpy() self.assertRtolEqual(result, cpu_result.numpy()) result = mx_driving.npu_add_relu(x.npu(), y.npu()).cpu().numpy() @@ -37,11 +49,8 @@ class TestPointsInBox(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `AddRelu` is only supported on 910B, skip this ut!") def test_npu_add_relu_fp16_large_number(self, device="npu"): - x = np.random.uniform(1, 1, [18, 256, 232, 400]).astype(np.float16) - x = torch.from_numpy(x) - y = np.random.uniform(2.0, 2.0, [18, 256, 232, 400]).astype(np.float16) - y = torch.from_numpy(y) - cpu_result = F.relu(x.float() + y.float()) + x, y = gen_inputs([18, 256, 232, 100], np.float16) + cpu_result = gen_cpu_outputs(x, y) result = mx_driving.fused.npu_add_relu(x.npu(), y.npu()).cpu().numpy() self.assertRtolEqual(result, cpu_result.half().numpy()) result = mx_driving.npu_add_relu(x.npu(), y.npu()).cpu().numpy() @@ -49,11 +58,8 @@ class TestPointsInBox(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `AddRelu` is only supported on 910B, skip this ut!") def test_npu_add_relu_fp16_small_case(self, device="npu"): - x = np.random.uniform(1, 1, [18]).astype(np.float16) - x = torch.from_numpy(x) - y = np.random.uniform(2.0, 2.0, [18]).astype(np.float16) - y = torch.from_numpy(y) - cpu_result = F.relu(x.float() + y.float()) + x, y = gen_inputs([18], np.float16) + cpu_result = gen_cpu_outputs(x, y) result = mx_driving.fused.npu_add_relu(x.npu(), y.npu()).cpu().numpy() self.assertRtolEqual(result, cpu_result.half().numpy()) result = mx_driving.npu_add_relu(x.npu(), y.npu()).cpu().numpy() diff --git a/tests/torch/test_assign_score_withk.py b/tests/torch/test_assign_score_withk.py index c13f62bc..4a6f80ea 100644 --- a/tests/torch/test_assign_score_withk.py +++ b/tests/torch/test_assign_score_withk.py @@ -1,10 +1,13 @@ -import torch import numpy as np +import torch +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving # 'pylint: disable=too-many-arguments,huawei-too-many-arguments +@golden_data_cache(__file__) def gen_data(B, N, npoint, M, K, out_dim): points = np.random.rand(B, N, M, out_dim).astype(np.float32) centers = np.random.rand(B, N, M, out_dim).astype(np.float32) @@ -148,6 +151,7 @@ class TestAssignScoreWithk(TestCase): class TestAssignScoreWithkGrad(TestCase): # 'pylint: disable=too-many-arguments,huawei-too-many-arguments + @golden_data_cache(__file__) def cpu_backward_op(self, grad_out, scores, diff --git a/tests/torch/test_bev_pool.py b/tests/torch/test_bev_pool.py index 7767305a..e150af91 100644 --- a/tests/torch/test_bev_pool.py +++ b/tests/torch/test_bev_pool.py @@ -3,6 +3,7 @@ import unittest import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests import mx_driving.point @@ -13,6 +14,7 @@ DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] # pylint: disable=too-many-arguments,huawei-too-many-arguments +@golden_data_cache(__file__) def golden_bev_pool(feat, geom_feat, b, d, h, w, c): output = np.zeros((b, d, h, w, c), dtype=np.float32) ranks = geom_feat[:, 0] * (w * d * b) + geom_feat[:, 1] * (d * b) + geom_feat[:, 2] * b + geom_feat[:, 3] @@ -33,6 +35,7 @@ def golden_bev_pool(feat, geom_feat, b, d, h, w, c): # pylint: disable=too-many-arguments,huawei-too-many-arguments +@golden_data_cache(__file__) def golden_bev_pool_grad(feat, geom_feat, interval_starts, interval_lengths, grad_output, b, d, h, w, c): grad_feat = np.zeros_like(feat) for start, length in zip(interval_starts, interval_lengths): @@ -43,6 +46,7 @@ def golden_bev_pool_grad(feat, geom_feat, interval_starts, interval_lengths, gra return grad_feat +@golden_data_cache(__file__) def generate_bev_pool_data(n, b, d, h, w, c): feat = np.random.rand(n, c).astype(np.float32) geom_feat_b = np.random.randint(0, b, (n,)).astype(np.int32) diff --git a/tests/torch/test_bev_pool_v2.py b/tests/torch/test_bev_pool_v2.py index f34e8c9b..c722e133 100644 --- a/tests/torch/test_bev_pool_v2.py +++ b/tests/torch/test_bev_pool_v2.py @@ -1,6 +1,7 @@ import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests import mx_driving.point @@ -12,6 +13,7 @@ DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] # pylint: disable=too-many-arguments,huawei-too-many-arguments +@golden_data_cache(__file__) def golden_bev_pool_v2( depth, feat, ranks_depth, ranks_feat, ranks_bev, interval_starts, interval_lengths, b, d, h, w, c ): @@ -28,6 +30,7 @@ def golden_bev_pool_v2( # pylint: disable=too-many-arguments,huawei-too-many-arguments +@golden_data_cache(__file__) def golden_bev_pool_v2_grad( grad_out, depth, feat, ranks_depth, ranks_feat, ranks_bev, interval_starts, interval_lengths, b, d, h, w, c ): @@ -46,6 +49,7 @@ def golden_bev_pool_v2_grad( # pylint: disable=too-many-return-values +@golden_data_cache(__file__) def generate_bev_pool_data(B, D, H, W, C, N_RANKS): feat = np.random.rand(B, 1, H, W, C).astype(np.float32) depth = np.random.rand(B, 1, D, H, W).astype(np.float32) diff --git a/tests/torch/test_bev_pool_v3.py b/tests/torch/test_bev_pool_v3.py index 578f41fb..b588e6e4 100644 --- a/tests/torch/test_bev_pool_v3.py +++ b/tests/torch/test_bev_pool_v3.py @@ -2,6 +2,7 @@ import unittest import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests from mx_driving import bev_pool_v3 @@ -11,11 +12,13 @@ DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] # pylint: disable=too-many-arguments,huawei-too-many-arguments +@golden_data_cache(__file__) def golden_bev_pool_v3(depth, feat, ranks_depth, ranks_feat, ranks_bev, bev_feat_shape): B, D, H, W, C = bev_feat_shape N_RANKS = ranks_depth.shape[0] depth = depth.view([-1]) feat = feat.view([-1, C]) + out = torch.zeros([B * D * H * W, C]) for i in range(N_RANKS): d = depth[ranks_depth[i]] @@ -23,11 +26,20 @@ def golden_bev_pool_v3(depth, feat, ranks_depth, ranks_feat, ranks_bev, bev_feat b = ranks_bev[i] out[b] += d * f out = out.view(bev_feat_shape) + out = torch.permute(out, [0, 4, 1, 2, 3]) return out +@golden_data_cache(__file__) +def golden_bev_pool_v3_grad(bev_feat_cpu, grad_out, feat, depth): + bev_feat_cpu.backward(grad_out) + + return feat.grad, depth.grad + + # pylint: disable=too-many-return-values +@golden_data_cache(__file__) def generate_bev_pool_data(B, D, H, W, C, N_RANKS): depth = torch.rand([B, 1, D, H, W]) feat = torch.rand([B, 1, H, W, C]) @@ -56,6 +68,7 @@ class TestBEVPoolV2(TestCase): feat, depth, grad_out, ranks_depth, ranks_feat, ranks_bev, bev_feat_shape = generate_bev_pool_data( B, D, H, W, C, N_RANKS ) + feat_npu = feat.clone().to("npu") depth_npu = depth.clone().to("npu") grad_out_npu = grad_out.clone().to("npu") @@ -68,15 +81,16 @@ class TestBEVPoolV2(TestCase): depth_npu.requires_grad_() bev_feat_cpu = golden_bev_pool_v3(depth, feat, ranks_depth, ranks_feat, ranks_bev, bev_feat_shape) + bev_feat_grad_cpu, bev_depth_grad_cpu = golden_bev_pool_v3_grad(bev_feat_cpu, grad_out, feat, depth) + bev_feat_npu = bev_pool_v3( depth_npu, feat_npu, ranks_depth_npu, ranks_feat_npu, ranks_bev_npu, bev_feat_shape ) - bev_feat_cpu.backward(grad_out) bev_feat_npu.backward(grad_out_npu) self.assertRtolEqual(bev_feat_npu.detach().cpu().numpy(), bev_feat_cpu.detach().cpu().numpy()) - self.assertRtolEqual(feat_npu.grad.cpu().numpy(), feat.grad.cpu().numpy()) - self.assertRtolEqual(depth_npu.grad.cpu().numpy(), depth.grad.cpu().numpy()) + self.assertRtolEqual(feat_npu.grad.cpu().numpy(), bev_feat_grad_cpu.cpu().numpy()) + self.assertRtolEqual(depth_npu.grad.cpu().numpy(), bev_depth_grad_cpu.cpu().numpy()) if __name__ == "__main__": diff --git a/tests/torch/test_border_align.py b/tests/torch/test_border_align.py index 69fad41e..5929646f 100644 --- a/tests/torch/test_border_align.py +++ b/tests/torch/test_border_align.py @@ -1,18 +1,20 @@ """ Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. """ -import unittest +import copy import math -from typing import List +import unittest from functools import reduce -import copy +from typing import List + import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests -import mx_driving -from mx_driving.detection import border_align +import mx_driving +from mx_driving import border_align torch.npu.config.allow_internal_format = False @@ -21,16 +23,19 @@ DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] EPS = 1e-8 +@golden_data_cache(__file__) def generate_features(feature_shape): features = torch.rand(feature_shape) return features +@golden_data_cache(__file__) def generate_grad_outputs(output_shape): grad_outputs = torch.rand(output_shape) return grad_outputs +@golden_data_cache(__file__) def generate_rois(inputs): num_boxes = inputs.shape[0] * inputs.shape[2] * inputs.shape[3] xyxy = torch.rand(num_boxes, 4) @@ -41,6 +46,7 @@ def generate_rois(inputs): return rois +@golden_data_cache(__file__) def border_align_cpu_golden(inputs, rois, pooled_size_): n, c4, h, w = inputs.shape c = c4 // 4 @@ -207,6 +213,7 @@ def border_align_box(box, pool_size, inputs, output, argmax_idx): return inputs +@golden_data_cache(__file__) def border_align_grad_cpu_golden(boxes, pool_size, inputs, grad_output, argmax_idx): grad_inputs = torch.zeros_like(inputs).detach().cpu().numpy() grad_output = grad_output.transpose(1, 2).contiguous().detach().cpu().numpy() diff --git a/tests/torch/test_box_iou.py b/tests/torch/test_box_iou.py index 2a8f55c9..c57deb6f 100644 --- a/tests/torch/test_box_iou.py +++ b/tests/torch/test_box_iou.py @@ -1,13 +1,16 @@ +import unittest from collections import namedtuple -import unittest import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving.detection from mx_driving import box_iou_quadri + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] @@ -323,6 +326,7 @@ def cpu_box_iou_quadri_aligned(boxes1, boxes2, mode="iou"): return ious +@golden_data_cache(__file__) def cpu_box_iou_quadri(boxes1, boxes2, mode="iou", aligned=False): if aligned: return cpu_box_iou_quadri_aligned(boxes1, boxes2, mode) @@ -368,6 +372,7 @@ def boxes_to_pts(boxes): return np.array(pts) +@golden_data_cache(__file__) def gen_boxes_quadri(boxes_num): boxes = gen_boxes_rotated(boxes_num) boxes = boxes_to_pts(boxes) diff --git a/tests/torch/test_boxes_overlap_bev.py b/tests/torch/test_boxes_overlap_bev.py index 7689f9d2..c81b7573 100644 --- a/tests/torch/test_boxes_overlap_bev.py +++ b/tests/torch/test_boxes_overlap_bev.py @@ -1,14 +1,18 @@ import unittest -from math import cos, sin, fabs, atan2 -from typing import List from collections import namedtuple +from math import atan2, cos, fabs, sin +from typing import List + import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving.detection from mx_driving import boxes_overlap_bev + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] EPS = 1e-8 @@ -200,6 +204,7 @@ def box_overlap(box_a: List[float], box_b: List[float]): return fabs(area) / 2.0 +@golden_data_cache(__file__) def cpu_boxes_overlap_bev(boxes_a: List[List[float]], boxes_b: List[List[float]]): boxes_a_num = boxes_a.shape[0] boxes_b_num = boxes_b.shape[0] @@ -211,6 +216,33 @@ def cpu_boxes_overlap_bev(boxes_a: List[List[float]], boxes_b: List[List[float]] return ans + +@golden_data_cache(__file__) +def cpu_gen_boxes(shape): + boxes_a_num, boxes_b_num = shape + boxes_a = np.zeros((boxes_a_num, 5)) + boxes_b = np.zeros((boxes_b_num, 5)) + for i in range(boxes_a_num): + x1 = np.random.uniform(0, 50) + y1 = np.random.uniform(0, 50) + x2 = x1 + np.random.uniform(0, 50) + y2 = y1 + np.random.uniform(0, 50) + angle = np.random.uniform(0, 1) + boxes_a[i] = [x1, y1, x2, y2, angle] + + for i in range(boxes_b_num): + x1 = np.random.uniform(0, 50) + y1 = np.random.uniform(0, 50) + x2 = x1 + np.random.uniform(0, 50) + y2 = y1 + np.random.uniform(0, 50) + angle = np.random.uniform(0, 1) + boxes_b[i] = [x1, y1, x2, y2, angle] + + boxes_a_cpu = boxes_a.astype(np.float32) + boxes_b_cpu = boxes_b.astype(np.float32) + + return boxes_a_cpu, boxes_b_cpu + Inputs = namedtuple('Inputs', ['boxes_a', 'boxes_b']) @@ -248,27 +280,7 @@ class TestBoxesOverlapBev(TestCase): return test_results def gen_inputs(self, shape, dtype): - boxes_a_num, boxes_b_num = shape - boxes_a = np.zeros((boxes_a_num, 5)) - boxes_b = np.zeros((boxes_b_num, 5)) - for i in range(boxes_a_num): - x1 = np.random.uniform(0, 50) - y1 = np.random.uniform(0, 50) - x2 = x1 + np.random.uniform(0, 50) - y2 = y1 + np.random.uniform(0, 50) - angle = np.random.uniform(0, 1) - boxes_a[i] = [x1, y1, x2, y2, angle] - - for i in range(boxes_b_num): - x1 = np.random.uniform(0, 50) - y1 = np.random.uniform(0, 50) - x2 = x1 + np.random.uniform(0, 50) - y2 = y1 + np.random.uniform(0, 50) - angle = np.random.uniform(0, 1) - boxes_b[i] = [x1, y1, x2, y2, angle] - - boxes_a_cpu = boxes_a.astype(np.float32) - boxes_b_cpu = boxes_b.astype(np.float32) + boxes_a_cpu, boxes_b_cpu = cpu_gen_boxes(shape) boxes_a_npu = torch.from_numpy(boxes_a_cpu).npu() boxes_b_npu = torch.from_numpy(boxes_b_cpu).npu() diff --git a/tests/torch/test_cal_anchors_heading.py b/tests/torch/test_cal_anchors_heading.py index 4f2e7ec8..4b5abfd3 100644 --- a/tests/torch/test_cal_anchors_heading.py +++ b/tests/torch/test_cal_anchors_heading.py @@ -1,10 +1,13 @@ -import torch import numpy as np +import torch +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving class TestCalAnchorsHeading(TestCase): + @golden_data_cache(__file__) def cal_anchors_heading_cpu(self, anchors, origin_pos=None): if origin_pos is None: input_add_start = torch.cat((torch.zeros_like(anchors[:, :, 0:1, :]), anchors), dim=-2) @@ -31,6 +34,7 @@ class TestCalAnchorsHeading(TestCase): heading = mx_driving.cal_anchors_heading(anchors, origin_pos) return heading.cpu().numpy() + @golden_data_cache(__file__) def gen_data(self, batch_size, anchors_num, seq_length): anchors = np.random.uniform(-5, 5, (batch_size, anchors_num, seq_length, 2)) origin_pos = np.random.uniform(-5, 5, (batch_size, 2)) diff --git a/tests/torch/test_deformable_aggregation.py b/tests/torch/test_deformable_aggregation.py index 35977a37..6fa539c9 100644 --- a/tests/torch/test_deformable_aggregation.py +++ b/tests/torch/test_deformable_aggregation.py @@ -1,22 +1,38 @@ import unittest -import torch -import numpy as np +import numpy as np +import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests -import mx_driving.fused + import mx_driving +import mx_driving.fused DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] +@golden_data_cache(__file__) +def cpu_gen_inputs(B, C, anchor, pts, numGroups): + feature_maps = np.random.rand(B, 2816, C).astype(np.float32) + spatial_shape = torch.tensor([[[32, 88]]], dtype=torch.int32).numpy() + scale_start_index = torch.tensor([[0]], dtype=torch.int32).numpy() + sample_location = np.random.rand(B, anchor, pts, 1, 2).astype(np.float32) + weights = np.random.rand(B, anchor, pts, 1, 1, numGroups).astype(np.float32) + + return feature_maps, spatial_shape, scale_start_index, sample_location, weights + + class TestDeformableAggregation(TestCase): # pylint: disable=too-many-arguments,huawei-too-many-arguments - def golden_deformable_aggregation(self, out, batch_size, num_anchors, num_pts, num_cams, num_scale, num_embeds, + @golden_data_cache(__file__) + def golden_deformable_aggregation(self, batch_size, num_anchors, num_pts, num_cams, num_scale, num_embeds, num_groups, num_feat, feature_maps, spatial_shape, scale_start_index, sample_location, weights): - + + out = np.zeros((batch_size, num_anchors, num_embeds)).astype(np.float32) + num_kernels = batch_size * num_anchors * num_pts * num_cams * num_scale for idx in range(num_kernels): chanenl_offset = 0 @@ -106,6 +122,8 @@ class TestDeformableAggregation(TestCase): chanenl_offset += num_embeds // num_groups + return out + @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `DeformableAggregation` is only supported on 910B, skip this ut!") def test_deformable_aggregation(self): @@ -121,12 +139,7 @@ class TestDeformableAggregation(TestCase): for pts in ptsList: for anchor in anchorList: for numGroups in numGroupsList: - - feature_maps = np.random.rand(B, 2816, C).astype(np.float32) - spatial_shape = torch.tensor([[[32, 88]]], dtype=torch.int32).numpy() - scale_start_index = torch.tensor([[0]], dtype=torch.int32).numpy() - sample_location = np.random.rand(B, anchor, pts, 1, 2).astype(np.float32) - weights = np.random.rand(B, anchor, pts, 1, 1, numGroups).astype(np.float32) + feature_maps, spatial_shape, scale_start_index, sample_location, weights = cpu_gen_inputs(B, C, anchor, pts, numGroups) torch_feature_maps = torch.from_numpy(feature_maps).npu() torch_spatial_shape = torch.from_numpy(spatial_shape).npu() @@ -146,9 +159,8 @@ class TestDeformableAggregation(TestCase): weights = weights.flatten() feature_maps = feature_maps.flatten() - out_cpu = np.zeros((batch_size, num_anchors, num_embeds)).astype(np.float32) - self.golden_deformable_aggregation(out_cpu, batch_size, num_anchors, num_pts, num_cams, + out_cpu = self.golden_deformable_aggregation(batch_size, num_anchors, num_pts, num_cams, num_scale, num_embeds, num_groups, num_feat, feature_maps, spatial_shape, scale_start_index, sample_location, weights) diff --git a/tests/torch/test_deformable_aggregation_grad.py b/tests/torch/test_deformable_aggregation_grad.py index 23e24c5e..f2fe4991 100644 --- a/tests/torch/test_deformable_aggregation_grad.py +++ b/tests/torch/test_deformable_aggregation_grad.py @@ -3,17 +3,31 @@ import unittest import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests -import mx_driving.fused import mx_driving +import mx_driving.fused DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] +# 'pylint: disable=too-many-arguments,huawei-too-many-arguments +@golden_data_cache(__file__) +def gen_inputs(B, C, input_h, input_w, anchor, pts, numGroups): + feature_maps = np.random.rand(B, input_h * input_w, C).astype(np.float32) + spatial_shape = torch.tensor([[[input_h, input_w]]], dtype=torch.int32).numpy() + scale_start_index = torch.tensor([[0]], dtype=torch.int32).numpy() + sample_location = np.random.rand(B, anchor, pts, 1, 2).astype(np.float32) + weights = np.random.rand(B, anchor, pts, 1, 1, numGroups).astype(np.float32) + + return feature_maps, spatial_shape, scale_start_index, sample_location, weights + + class TestDeformableAggregation(TestCase): # pylint: disable=too-many-arguments,huawei-too-many-arguments + @golden_data_cache(__file__) def golden_deformable_aggregation_grad( self, batch_size, @@ -28,12 +42,25 @@ class TestDeformableAggregation(TestCase): spatial_shape, scale_start_index, sample_location, - weights, - grad_output, - grad_mc_ms_feat, - grad_sampling_location, - grad_weights, + weights ): + + out_cpu = np.zeros((batch_size, num_anchors, num_embeds)).astype(np.float32) + grad_mc_ms_feat = np.zeros_like(feature_maps) + grad_sampling_location = np.zeros_like(sample_location) + grad_weights = np.zeros_like(weights) + grad_output = np.ones_like(out_cpu) + + feature_maps = feature_maps.flatten() + spatial_shape = spatial_shape.flatten() + scale_start_index = scale_start_index.flatten() + sample_location = sample_location.flatten() + weights = weights.flatten() + grad_mc_ms_feat = grad_mc_ms_feat.flatten() + grad_sampling_location = grad_sampling_location.flatten() + grad_weights = grad_weights.flatten() + grad_output = grad_output.flatten() + num_kernels = batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale for idx in range(num_kernels): @@ -147,6 +174,8 @@ class TestDeformableAggregation(TestCase): grad_sampling_location[loc_offset] += w * grad_w_weight * top_grad_mc_ms_feat grad_sampling_location[loc_offset + 1] += h * grad_h_weight * top_grad_mc_ms_feat + return grad_mc_ms_feat, grad_sampling_location, grad_weights + @unittest.skipIf( DEVICE_NAME != 'Ascend910B', "OP `DeformableAggregationGrad` is only supported on 910B, skip this ut!", @@ -167,11 +196,8 @@ class TestDeformableAggregation(TestCase): for numGroups in numGroupsList: input_h = 16 input_w = 22 - feature_maps = np.random.rand(B, input_h * input_w, C).astype(np.float32) - spatial_shape = torch.tensor([[[input_h, input_w]]], dtype=torch.int32).numpy() - scale_start_index = torch.tensor([[0]], dtype=torch.int32).numpy() - sample_location = np.random.rand(B, anchor, pts, 1, 2).astype(np.float32) - weights = np.random.rand(B, anchor, pts, 1, 1, numGroups).astype(np.float32) + + feature_maps, spatial_shape, scale_start_index, sample_location, weights = gen_inputs(B, C, input_h, input_w, anchor, pts, numGroups) feature_maps_shape = feature_maps.shape torch_feature_maps = torch.from_numpy(feature_maps).npu() @@ -201,23 +227,8 @@ class TestDeformableAggregation(TestCase): num_pts = sample_location.shape[2] num_groups = weights.shape[5] - out_cpu = np.zeros((batch_size, num_anchors, num_embeds)).astype(np.float32) - grad_mc_ms_feat = np.zeros_like(feature_maps) - grad_sampling_location = np.zeros_like(sample_location) - grad_weights = np.zeros_like(weights) - grad_output = np.ones_like(out_cpu) - - feature_maps = feature_maps.flatten() - spatial_shape = spatial_shape.flatten() - scale_start_index = scale_start_index.flatten() - sample_location = sample_location.flatten() - weights = weights.flatten() - grad_mc_ms_feat = grad_mc_ms_feat.flatten() - grad_sampling_location = grad_sampling_location.flatten() - grad_weights = grad_weights.flatten() - grad_output = grad_output.flatten() - - self.golden_deformable_aggregation_grad( + + grad_mc_ms_feat, grad_sampling_location, grad_weights = self.golden_deformable_aggregation_grad( batch_size, num_anchors, num_pts, @@ -230,11 +241,7 @@ class TestDeformableAggregation(TestCase): spatial_shape, scale_start_index, sample_location, - weights, - grad_output, - grad_mc_ms_feat, - grad_sampling_location, - grad_weights, + weights ) diff --git a/tests/torch/test_furthest_point_sample_with_dist.py b/tests/torch/test_furthest_point_sample_with_dist.py index cdd6c400..7cb442f7 100644 --- a/tests/torch/test_furthest_point_sample_with_dist.py +++ b/tests/torch/test_furthest_point_sample_with_dist.py @@ -13,19 +13,22 @@ # limitations under the License. import unittest -import torch -import numpy as np +import numpy as np +import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving import mx_driving.point + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] class TestFurthestPointSampleWithDist(TestCase): - + @golden_data_cache(__file__) def create_input_data(self, shape): b, n = shape point_xyz = np.random.uniform(0, 10, [b, n, 3]).astype(np.float32) @@ -45,6 +48,7 @@ class TestFurthestPointSampleWithDist(TestCase): else: return a + @golden_data_cache(__file__) def supported_op_exec(self, point_dist, point_num): b, n, _ = point_dist.shape tmp = np.zeros([b, n]).astype(np.float32) diff --git a/tests/torch/test_furthest_point_sampling.py b/tests/torch/test_furthest_point_sampling.py index 41e1d6b5..6de26235 100644 --- a/tests/torch/test_furthest_point_sampling.py +++ b/tests/torch/test_furthest_point_sampling.py @@ -14,15 +14,18 @@ import unittest from abc import ABC, abstractmethod + import numpy as np import torch - import torch_npu -from torch_npu.testing.testcase import TestCase, run_tests +from data_cache import golden_data_cache from torch_npu.testing.common_utils import create_common_tensor +from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving import mx_driving.point + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] @@ -45,6 +48,7 @@ class CreateBenchMarkTest(ABC): else : return a[0] + @golden_data_cache(__file__) def getCpuRes(self): cpuRes = np.zeros([self.batch, self.numPoints], dtype=np.int32) nearestDistCopy = self.nearestDist.copy() diff --git a/tests/torch/test_fused_bias_leaky_relu.py b/tests/torch/test_fused_bias_leaky_relu.py index 65724a29..475d8a82 100644 --- a/tests/torch/test_fused_bias_leaky_relu.py +++ b/tests/torch/test_fused_bias_leaky_relu.py @@ -1,18 +1,42 @@ import unittest -import torch + import numpy as np -import torch_npu +import torch import torch.nn.functional as F - +import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving.fused + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] negative_slop = -0.1 scale = 0.25 +@golden_data_cache(__file__) +def cpu_gen_inputs(shape, bias_dim, feature_dtype, bias_dtype): + x = np.random.uniform(1, 1, shape).astype(feature_dtype) + x = torch.from_numpy(x) + bias = np.random.uniform(-2.0, 2.0, bias_dim).astype(bias_dtype) + bias = torch.from_numpy(bias) + bias_cpu = bias.reshape([-1 if i == 1 else 1 for i in range(x.ndim)]) + bias_cpu = bias_cpu.repeat([1 if i == 1 else x.size(i) for i in range(x.ndim)]) + + return x, bias, bias_cpu + + +@golden_data_cache(__file__) +def cpu_gen_outputs(x, bias_cpu): + print(x.shape, bias_cpu.shape) + cpu_result = F.leaky_relu(x.float() + bias_cpu.float(), negative_slop) + cpu_result = cpu_result * scale + + return cpu_result + + class TestFusedBiasLeakyRelu(TestCase): seed = 1024 np.random.seed(seed) @@ -20,47 +44,29 @@ class TestFusedBiasLeakyRelu(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `FusedBiasLeakyRelu` is only supported on 910B, skip this ut!") def test_npu_fused_bias_leaky_relu_three_dim(self, device="npu"): N, H, W = [1, 100, 3] - x = np.random.uniform(1, 1, [N, H, W]).astype(np.float32) - x = torch.from_numpy(x) - bias = np.random.uniform(-2.0, 2.0, H).astype(np.float32) - bias = torch.from_numpy(bias) - bias_cpu = bias.reshape([-1 if i == 1 else 1 for i in range(x.ndim)]) - bias_cpu = bias_cpu.repeat([1 if i == 1 else x.size(i) for i in range(x.ndim)]) + x, bias, bias_cpu = cpu_gen_inputs([1, 100, 3], H, np.float32, np.float32) - cpu_result = F.leaky_relu(x + bias_cpu, negative_slop) - cpu_result = cpu_result * scale + cpu_result = cpu_gen_outputs(x, bias_cpu) npu_result = mx_driving.fused.npu_fused_bias_leaky_relu(x.npu(), bias.npu(), negative_slop, scale).cpu().numpy() self.assertRtolEqual(npu_result, cpu_result.numpy()) @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `FusedBiasLeakyRelu` is only supported on 910B, skip this ut!") def test_npu_fused_bias_leaky_relu_large_number(self, device="npu"): - B, N, H, W = [18, 256, 232, 400] - x = np.random.uniform(1, 1, [B, N, H, W]).astype(np.float32) - x = torch.from_numpy(x) - bias = np.random.uniform(-2.0, 2.0, N).astype(np.float32) - bias = torch.from_numpy(bias) - bias_cpu = bias.reshape([-1 if i == 1 else 1 for i in range(x.ndim)]) - bias_cpu = bias_cpu.repeat([1 if i == 1 else x.size(i) for i in range(x.ndim)]) + B, N, H, W = [18, 256, 232, 100] + x, bias, bias_cpu = cpu_gen_inputs([18, 256, 232, 100], N, np.float32, np.float32) - cpu_result = F.leaky_relu(x + bias_cpu, negative_slop) - cpu_result = cpu_result * scale + cpu_result = cpu_gen_outputs(x, bias_cpu) npu_result = mx_driving.fused.npu_fused_bias_leaky_relu(x.npu(), bias.npu(), negative_slop, scale).cpu().numpy() self.assertRtolEqual(npu_result, cpu_result.numpy()) @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `FusedBiasLeakyRelu` is only supported on 910B, skip this ut!") def test_npu_fused_bias_leaky_relu_fp16_large_number(self, device="npu"): - B, N, H, W = [18, 256, 232, 400] - x = np.random.uniform(1, 1, [B, N, H, W]).astype(np.float16) - x = torch.from_numpy(x) - bias = np.random.uniform(-2.0, 2.0, N).astype(np.float32) - bias = torch.from_numpy(bias) - bias_cpu = bias.reshape([-1 if i == 1 else 1 for i in range(x.ndim)]) - bias_cpu = bias_cpu.repeat([1 if i == 1 else x.size(i) for i in range(x.ndim)]) + B, N, H, W = [18, 256, 232, 100] + x, bias, bias_cpu = cpu_gen_inputs([18, 256, 232, 100], N, np.float16, np.float32) - cpu_result = F.leaky_relu(x.float() + bias_cpu.float(), negative_slop) - cpu_result = cpu_result * scale + cpu_result = cpu_gen_outputs(x, bias_cpu) npu_result = mx_driving.fused.npu_fused_bias_leaky_relu(x.npu(), bias.npu(), negative_slop, scale).cpu().numpy() self.assertRtolEqual(npu_result, cpu_result.half().numpy()) @@ -68,15 +74,9 @@ class TestFusedBiasLeakyRelu(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `FusedBiasLeakyRelu` is only supported on 910B, skip this ut!") def test_npu_fused_bias_leaky_relu_fp16_small_case(self, device="npu"): N, H, W = [18, 200, 6] - x = np.random.uniform(1, 1, [N, H, W]).astype(np.float16) - x = torch.from_numpy(x) - bias = np.random.uniform(-2.0, 2.0, H).astype(np.float32) - bias = torch.from_numpy(bias) - bias_cpu = bias.reshape([-1 if i == 1 else 1 for i in range(x.ndim)]) - bias_cpu = bias_cpu.repeat([1 if i == 1 else x.size(i) for i in range(x.ndim)]) - - cpu_result = F.leaky_relu(x.float() + bias_cpu.float(), negative_slop) - cpu_result = cpu_result * scale + x, bias, bias_cpu = cpu_gen_inputs([18, 200, 6], H, np.float16, np.float32) + + cpu_result = cpu_gen_outputs(x, bias_cpu) npu_result = mx_driving.fused.npu_fused_bias_leaky_relu(x.npu(), bias.npu(), negative_slop, scale).cpu().numpy() self.assertRtolEqual(npu_result, cpu_result.half().numpy()) diff --git a/tests/torch/test_geometric_kernel_attention.py b/tests/torch/test_geometric_kernel_attention.py index fca08bf7..67a7f5d5 100644 --- a/tests/torch/test_geometric_kernel_attention.py +++ b/tests/torch/test_geometric_kernel_attention.py @@ -8,13 +8,40 @@ from collections import namedtuple import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests import mx_driving + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] +# pylint: disable=too-many-return-values +@golden_data_cache(__file__) +def cpu_gen_inputs(shape): + bs, num_queries, embed_dims, num_heads, num_levels, num_points = shape + + sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2) + if bs == 24: + spatial_shapes = torch.tensor([15, 25] * num_levels).reshape(num_levels, 2) + sampling_locations[:, :, :, :, :, 0] = sampling_locations[:, :, :, :, :, 0] * 21 - 3 # -3 ~ 18 + sampling_locations[:, :, :, :, :, 1] = sampling_locations[:, :, :, :, :, 1] * 31 - 3 # -3 ~ 28 + else: + spatial_shapes = torch.tensor([6, 10] * num_levels).reshape(num_levels, 2) + sampling_locations[:, :, :, :, :, 0] = sampling_locations[:, :, :, :, :, 0] * 12 - 3 # -3 ~ 9 + sampling_locations[:, :, :, :, :, 1] = sampling_locations[:, :, :, :, :, 1] * 16 - 3 # -3 ~ 13 + num_keys = sum((H * W).item() for H, W in spatial_shapes) + + value = torch.rand(bs, num_keys, num_heads, embed_dims) * 3 + attn_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points) * 1 + level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + grad_output = torch.rand(bs, num_queries, num_heads * embed_dims) * 10 + + return sampling_locations, spatial_shapes, value, attn_weights, level_start_index, grad_output + + +@golden_data_cache(__file__) def cpu_geometric_kernel_attention(value, spatial_shapes, level_start_index, sampling_locations, attn_weights): """CPU version of geometric kernel attention. @@ -66,6 +93,16 @@ def cpu_geometric_kernel_attention(value, spatial_shapes, level_start_index, sam return output.view(bs, num_queries, -1) +@golden_data_cache(__file__) +def cpu_geometric_kernel_attention_grad(cpu_output, grad_output, value, attn_weights): + cpu_output.backward(grad_output) + + grad_value = value.grad.float().numpy() + grad_attn_weights = attn_weights.grad.float().numpy() + + return grad_value, grad_attn_weights + + ExecResults = namedtuple('ExecResults', ['output', 'grad_value', 'grad_attn_weights']) Inputs = namedtuple('Inputs', ['value', 'spatial_shapes', 'level_start_index', 'sampling_locations', 'attn_weights', 'grad_output']) @@ -93,23 +130,8 @@ class TestGeometricKernelAttention(TestCase): return test_results def gen_inputs(self, shape, dtype): - bs, num_queries, embed_dims, num_heads, num_levels, num_points = shape - - sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2) - if bs == 24: - spatial_shapes = torch.tensor([15, 25] * num_levels).reshape(num_levels, 2) - sampling_locations[:, :, :, :, :, 0] = sampling_locations[:, :, :, :, :, 0] * 21 - 3 # -3 ~ 18 - sampling_locations[:, :, :, :, :, 1] = sampling_locations[:, :, :, :, :, 1] * 31 - 3 # -3 ~ 28 - else: - spatial_shapes = torch.tensor([6, 10] * num_levels).reshape(num_levels, 2) - sampling_locations[:, :, :, :, :, 0] = sampling_locations[:, :, :, :, :, 0] * 12 - 3 # -3 ~ 9 - sampling_locations[:, :, :, :, :, 1] = sampling_locations[:, :, :, :, :, 1] * 16 - 3 # -3 ~ 13 - num_keys = sum((H * W).item() for H, W in spatial_shapes) - - value = torch.rand(bs, num_keys, num_heads, embed_dims) * 3 - attn_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points) * 1 - level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) - grad_output = torch.rand(bs, num_queries, num_heads * embed_dims) * 10 + + sampling_locations, spatial_shapes, value, attn_weights, level_start_index, grad_output = cpu_gen_inputs(shape) cpu_value = value.float() cpu_spatial_shapes = spatial_shapes.long() @@ -144,11 +166,14 @@ class TestGeometricKernelAttention(TestCase): cpu_output = cpu_geometric_kernel_attention( value, spatial_shapes, level_start_index, sampling_locations, attn_weights ) - cpu_output.backward(grad_output) + grad_value, grad_attn_weights = cpu_geometric_kernel_attention_grad( + cpu_output, grad_output, value, attn_weights + ) + return ExecResults( output=cpu_output.detach().float().numpy(), - grad_value=value.grad.float().numpy(), - grad_attn_weights=attn_weights.grad.float().numpy() + grad_value=grad_value, + grad_attn_weights=grad_attn_weights ) def npu_to_exec(self, npu_inputs): diff --git a/tests/torch/test_group_points.py b/tests/torch/test_group_points.py index 12d92350..545257c7 100644 --- a/tests/torch/test_group_points.py +++ b/tests/torch/test_group_points.py @@ -1,9 +1,11 @@ import unittest -import torch -import numpy as np +import numpy as np +import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving import mx_driving.point @@ -11,7 +13,18 @@ import mx_driving.point DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] +# 'pylint: disable=too-many-arguments,huawei-too-many-arguments +@golden_data_cache(__file__) +def cpu_gen_inputs(B, C, N, mean, std_dev, npoints, nsample, dtype): + np_points = np.random.normal(mean, std_dev, (B, C, N)).astype(dtype) + np_indices = np.random.randint(0, N, (B, npoints, nsample)).astype(np.int32) + np_out = np.zeros((B, C, npoints, nsample)).astype(dtype) + + return np_points, np_indices, np_out + + class TestGroupPoints(TestCase): + @golden_data_cache(__file__) def cpu_group_points(self, points, indices, out): B, npoints, nsample = indices.shape @@ -44,10 +57,7 @@ class TestGroupPoints(TestCase): std_dev = np.random.uniform(0, 25) for j in range(2): - - np_points = np.random.normal(mean, std_dev, (B, C, N)).astype(astype[j]) - np_indices = np.random.randint(0, N, (B, npoints, nsample)).astype(np.int32) - np_out = np.zeros((B, C, npoints, nsample)).astype(astype[j]) + np_points, np_indices, np_out = cpu_gen_inputs(B, C, N, mean, std_dev, npoints, nsample, astype[j]) th_points = torch.from_numpy(np_points).npu().to(dtype[j]) th_indices = torch.from_numpy(np_indices).int().npu() diff --git a/tests/torch/test_group_points_grad.py b/tests/torch/test_group_points_grad.py index 8bf61d3e..1eb02138 100644 --- a/tests/torch/test_group_points_grad.py +++ b/tests/torch/test_group_points_grad.py @@ -1,17 +1,29 @@ import unittest -import torch -import numpy as np +import numpy as np +import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving._C DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] +@golden_data_cache(__file__) +def cpu_gen_inputs(B, C, N, npoints, nsample): + np_grad_out = np.random.rand(B, C, npoints, nsample).astype(np.float32) + np_indices = np.random.randint(0, N, (B, npoints, nsample)).astype(np.int32) + np_grad_features = np.zeros((B, C, N)).astype(np.float32) + + return np_grad_out, np_indices, np_grad_features + + class TestGroupPointsGrad(TestCase): # pylint: disable=too-many-arguments,huawei-too-many-arguments + @golden_data_cache(__file__) def golden_group_points_grad(self, np_grad_out, np_indices, np_grad_features, B, npoints, nsample): np_grad_out = np_grad_out.transpose(0, 2, 3, 1) @@ -40,9 +52,7 @@ class TestGroupPointsGrad(TestCase): for N in N_list: for npoints in npoints_list: for nsample in nsample_list: - np_grad_out = np.random.rand(B, C, npoints, nsample).astype(np.float32) - np_indices = np.random.randint(0, N, (B, npoints, nsample)).astype(np.int32) - np_grad_features = np.zeros((B, C, N)).astype(np.float32) + np_grad_out, np_indices, np_grad_features = cpu_gen_inputs(B, C, N, npoints, nsample) torch_grad_out = torch.from_numpy(np_grad_out).npu() torch_indices = torch.from_numpy(np_indices).npu() diff --git a/tests/torch/test_hard_voxelize.py b/tests/torch/test_hard_voxelize.py index 4ee50108..42376022 100644 --- a/tests/torch/test_hard_voxelize.py +++ b/tests/torch/test_hard_voxelize.py @@ -3,9 +3,12 @@ import unittest import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests -from mx_driving import Voxelization + import mx_driving.point +from mx_driving import Voxelization + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] @@ -15,6 +18,7 @@ class TestHardVoxelize(TestCase): point_nums = [1, 7, 6134, 99999] np.random.seed(seed) + @golden_data_cache(__file__) def gen(self, point_num): x = 108 * np.random.rand(point_num) - 54 y = 108 * np.random.rand(point_num) - 54 @@ -33,6 +37,7 @@ class TestHardVoxelize(TestCase): cnt, pts, voxs, num_per_vox = vlz(points_npu) return cnt, voxs.cpu().numpy(), cnt1, voxs1.cpu().numpy() + @golden_data_cache(__file__) def golden_hard_voxelize(self, points): point_num = points.shape[0] gridx = 1440 diff --git a/tests/torch/test_hypot.py b/tests/torch/test_hypot.py index 310eb87a..cef14ea4 100644 --- a/tests/torch/test_hypot.py +++ b/tests/torch/test_hypot.py @@ -1,83 +1,79 @@ -from copy import deepcopy import unittest +from copy import deepcopy + +import numpy as np import torch import torch_npu -import numpy as np +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] +@golden_data_cache(__file__) +def cpu_gen_inputs(data_range_x, data_range_y, x_shape, y_shape): + x = np.random.uniform(data_range_x[0], data_range_x[1], x_shape).astype(np.float32) + x = torch.from_numpy(x) + y = np.random.uniform(data_range_y[0], data_range_y[1], y_shape).astype(np.float32) + y = torch.from_numpy(y) + + return x, y + + +@golden_data_cache(__file__) +def cpu_gen_outputs(x, y): + z = torch.hypot(x, y).numpy() + return z + + class TestHypot(TestCase): def test_hypot_one_dim(self, device="npu"): - x = np.random.uniform(3, 3, [1]).astype(np.float32) - x = torch.from_numpy(x) - y = np.random.uniform(4, 4, [1]).astype(np.float32) - y = torch.from_numpy(y) - z = np.random.uniform(5, 5, [1]).astype(np.float32) + x, y = cpu_gen_inputs([3, 3], [4, 4], [1], [1]) + z = cpu_gen_outputs(x, y) npu_result = mx_driving.hypot(x.npu(), y.npu()).cpu() self.assertRtolEqual(npu_result.numpy(), z) def test_hypot_one_dim_broadcast(self, device="npu"): - x = np.random.uniform(3, 3, [1]).astype(np.float32) - x = torch.from_numpy(x) - y = np.random.uniform(4, 4, [10]).astype(np.float32) - y = torch.from_numpy(y) - z = np.random.uniform(5, 5, [10]).astype(np.float32) + x, y = cpu_gen_inputs([3, 3], [4, 4], [1], [10]) + z = cpu_gen_outputs(x, y) npu_result = mx_driving.hypot(x.npu(), y.npu()).cpu() self.assertRtolEqual(npu_result.numpy(), z) def test_hypot_three_dim(self, device="npu"): - x = np.random.uniform(3, 3, [35, 50, 80]).astype(np.float32) - x = torch.from_numpy(x) - y = np.random.uniform(4, 4, [35, 50, 80]).astype(np.float32) - y = torch.from_numpy(y) - z = np.random.uniform(5, 5, [35, 50, 80]).astype(np.float32) + x, y = cpu_gen_inputs([3, 3], [4, 4], [35, 50, 80], [35, 50, 80]) + z = cpu_gen_outputs(x, y) npu_result = mx_driving.hypot(x.npu(), y.npu()).cpu() self.assertRtolEqual(npu_result.numpy(), z) def test_hypot_random_three_dim(self, device="npu"): - x = np.random.uniform(1, 3, [35, 50, 80]).astype(np.float32) - x = torch.from_numpy(x) - y = np.random.uniform(1, 4, [35, 50, 80]).astype(np.float32) - y = torch.from_numpy(y) - z = torch.hypot(x, y).numpy() + x, y = cpu_gen_inputs([1, 3], [1, 4], [35, 50, 80], [35, 50, 80]) + z = cpu_gen_outputs(x, y) npu_result = mx_driving.hypot(x.npu(), y.npu()).cpu() self.assertRtolEqual(npu_result.numpy(), z) def test_hypot_random_three_dim_broadcast_x(self, device="npu"): - x = np.random.uniform(1, 3, [35, 1, 80]).astype(np.float32) - x = torch.from_numpy(x) - y = np.random.uniform(1, 4, [35, 50, 80]).astype(np.float32) - y = torch.from_numpy(y) - z = torch.hypot(x, y).numpy() + x, y = cpu_gen_inputs([1, 3], [1, 4], [35, 1, 80], [35, 50, 80]) + z = cpu_gen_outputs(x, y) npu_result = mx_driving.hypot(x.npu(), y.npu()).cpu() self.assertRtolEqual(npu_result.numpy(), z) def test_hypot_random_three_dim_broadcast_y(self, device="npu"): - x = np.random.uniform(1, 3, [35, 50, 80]).astype(np.float32) - x = torch.from_numpy(x) - y = np.random.uniform(1, 4, [35, 1, 80]).astype(np.float32) - y = torch.from_numpy(y) - z = torch.hypot(x, y).numpy() + x, y = cpu_gen_inputs([1, 3], [1, 4], [35, 50, 80], [35, 1, 80]) + z = cpu_gen_outputs(x, y) npu_result = mx_driving.hypot(x.npu(), y.npu()).cpu() self.assertRtolEqual(npu_result.numpy(), z) def test_hypot_large_random_dim_broadcast(self, device="npu"): - x = np.random.uniform(1, 3, [35, 50, 80, 1, 3]).astype(np.float32) - x = torch.from_numpy(x) - y = np.random.uniform(1, 4, [35, 1, 80, 171, 3]).astype(np.float32) - y = torch.from_numpy(y) - z = torch.hypot(x, y).numpy() + x, y = cpu_gen_inputs([1, 3], [1, 4], [35, 50, 80, 1, 3], [35, 1, 80, 171, 3]) + z = cpu_gen_outputs(x, y) npu_result = mx_driving.hypot(x.npu(), y.npu()).cpu() self.assertRtolEqual(npu_result.numpy(), z) def test_hypot_grad_base(self, device="npu"): - x = np.random.uniform(3, 3, [35, 50]).astype(np.float32) - x = torch.from_numpy(x) - y = np.random.uniform(4, 4, [35, 50]).astype(np.float32) - y = torch.from_numpy(y) + x, y = cpu_gen_inputs([3, 3], [4, 4], [35, 50], [35, 50]) z_grad = torch.randn([35, 50]) x.requires_grad = True y.requires_grad = True @@ -91,10 +87,7 @@ class TestHypot(TestCase): self.assertRtolEqual(y.grad.numpy(), y_npu.grad.numpy()) def test_hypot_grad_zero(self, device="npu"): - x = np.random.uniform(0, 0, [35, 50]).astype(np.float32) - x = torch.from_numpy(x) - y = np.random.uniform(0, 0, [35, 50]).astype(np.float32) - y = torch.from_numpy(y) + x, y = cpu_gen_inputs([0, 0], [0, 0], [35, 50], [35, 50]) z_grad = torch.randn([35, 50]) x.requires_grad = True y.requires_grad = True @@ -108,10 +101,7 @@ class TestHypot(TestCase): self.assertRtolEqual(y.grad.numpy(), y_npu.grad.numpy()) def test_hypot_grad_large_random_dim_broadcast(self, device="npu"): - x = np.random.uniform(-3, 3, [35, 50, 80, 1, 3]).astype(np.float32) - x = torch.from_numpy(x) - y = np.random.uniform(-4, 4, [35, 1, 80, 171, 3]).astype(np.float32) - y = torch.from_numpy(y) + x, y = cpu_gen_inputs([-3, 3], [-4, 4], [35, 50, 80, 1, 3], [35, 1, 80, 171, 3]) z_grad = torch.randn([35, 50, 80, 171, 3]) x.requires_grad = True y.requires_grad = True diff --git a/tests/torch/test_knn.py b/tests/torch/test_knn.py index 513d8b47..f4e379c4 100644 --- a/tests/torch/test_knn.py +++ b/tests/torch/test_knn.py @@ -1,18 +1,29 @@ -import torch import numpy as np +import torch +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving import mx_driving.common +@golden_data_cache(__file__) +def cpu_gen_inputs(attrs): + batch, npoint, N, nsample, transposed = attrs + idx = np.zeros((batch, npoint, nsample), dtype=np.int32) + dist2 = np.zeros((batch, npoint, nsample), dtype=np.float32) + + return idx, dist2 + + class TestKnn(TestCase): + @golden_data_cache(__file__) def cpu_op_exec(self, attrs, xyz, center_xyz): batch, npoint, N, nsample, transposed = attrs - idx = np.zeros((batch, npoint, nsample), dtype=np.int32) - dist2 = np.zeros((batch, npoint, nsample), dtype=np.float32) + idx, dist2 = cpu_gen_inputs(attrs) if transposed: xyz = np.transpose(xyz, axes=(0, 2, 1)) center_xyz = np.transpose(center_xyz, axes=(0, 2, 1)) diff --git a/tests/torch/test_multi_scale_deformable_attn.py b/tests/torch/test_multi_scale_deformable_attn.py index d338834d..e50b08a8 100644 --- a/tests/torch/test_multi_scale_deformable_attn.py +++ b/tests/torch/test_multi_scale_deformable_attn.py @@ -3,13 +3,32 @@ from collections import namedtuple import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests import mx_driving.fused + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] +# pylint: disable=too-many-return-values +@golden_data_cache(__file__) +def cpu_gen_inputs(shape): + bs, num_queries, embed_dims, num_heads, num_levels, num_points = shape + shapes = torch.tensor([60, 40] * num_levels).reshape(num_levels, 2) + num_keys = sum((H * W).item() for H, W in shapes) + + value = torch.rand(bs, num_keys, num_heads, embed_dims) * 0.01 + sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2) + attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points) + 1e-5 + offset = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])) + grad_output = torch.rand(bs, num_queries, num_heads * embed_dims) * 1e-3 + + return shapes, num_keys, value, sampling_locations, attention_weights, offset, grad_output + + +@golden_data_cache(__file__) def multi_scale_deformable_attn_pytorch( value: torch.Tensor, value_spatial_shapes: torch.Tensor, @@ -42,6 +61,17 @@ def multi_scale_deformable_attn_pytorch( return output.transpose(1, 2).contiguous() +@golden_data_cache(__file__) +def multi_scale_deformable_attn_pytorch_grad(cpu_output, cpu_grad_output, cpu_value, cpu_sampling_locations, cpu_attention_weights): + cpu_output.backward(cpu_grad_output) + grad_value = cpu_value.grad.float().numpy() + grad_sampling_locations = cpu_sampling_locations.grad.float().numpy() + grad_attention_weights = cpu_attention_weights.grad.float().numpy() + + return grad_value, grad_sampling_locations, grad_attention_weights + + + ExecResults = namedtuple("ExecResults", ["output", "grad_value", "grad_sampling_locations", "grad_attention_weights"]) Inputs = namedtuple("Inputs", ["value", "shapes", "offset", "sampling_locations", "attention_weights", "grad_output"]) @@ -81,14 +111,7 @@ class TestMultiScaleDeformableAttnFunction(TestCase): def gen_inputs(self, shape, dtype): bs, num_queries, embed_dims, num_heads, num_levels, num_points = shape - shapes = torch.tensor([60, 40] * num_levels).reshape(num_levels, 2) - num_keys = sum((H * W).item() for H, W in shapes) - - value = torch.rand(bs, num_keys, num_heads, embed_dims) * 0.01 - sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2) - attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points) + 1e-5 - offset = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])) - grad_output = torch.rand(bs, num_queries, num_heads * embed_dims) * 1e-3 + shapes, num_keys, value, sampling_locations, attention_weights, offset, grad_output = cpu_gen_inputs(shape) cpu_value = value.double() cpu_shapes = shapes.long() @@ -124,12 +147,12 @@ class TestMultiScaleDeformableAttnFunction(TestCase): cpu_output = multi_scale_deformable_attn_pytorch( cpu_value, cpu_shapes, cpu_sampling_locations, cpu_attention_weights ) - cpu_output.backward(cpu_grad_output) + grad_value, grad_sampling_locations, grad_attention_weights = multi_scale_deformable_attn_pytorch_grad(cpu_output, cpu_grad_output, cpu_value, cpu_sampling_locations, cpu_attention_weights) return ExecResults( output=cpu_output.detach().float().numpy(), - grad_value=cpu_value.grad.float().numpy(), - grad_sampling_locations=cpu_sampling_locations.grad.float().numpy(), - grad_attention_weights=cpu_attention_weights.grad.float().numpy(), + grad_value=grad_value, + grad_sampling_locations=grad_sampling_locations, + grad_attention_weights=grad_attention_weights, ) def npu_to_exec(self, npu_inputs): diff --git a/tests/torch/test_npu_dyn_voxelization.py b/tests/torch/test_npu_dyn_voxelization.py index 7ca150a3..4a47d2ef 100644 --- a/tests/torch/test_npu_dyn_voxelization.py +++ b/tests/torch/test_npu_dyn_voxelization.py @@ -1,10 +1,12 @@ -import unittest -import random import math +import random +import unittest + import torch import torch_npu - +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving.point from mx_driving import Voxelization @@ -12,6 +14,7 @@ from mx_driving import Voxelization DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] +@golden_data_cache(__file__) def dyn_voxelization_cpu(points, voxel_size, coors_range): num_points = points.size(0) @@ -62,6 +65,7 @@ class TestDynVoxelization(TestCase): coors1 = dynamic_voxelization_npu1(points) return coors, coors1 + @golden_data_cache(__file__) def gen_data(self, shape, dtype): points_cpu = torch.randint(-20, 100, shape, dtype=dtype) points_cpu = points_cpu + torch.rand(shape, dtype=dtype) diff --git a/tests/torch/test_npu_dynamic_scatter.py b/tests/torch/test_npu_dynamic_scatter.py index ed1e1459..8b3fc8fa 100644 --- a/tests/torch/test_npu_dynamic_scatter.py +++ b/tests/torch/test_npu_dynamic_scatter.py @@ -1,10 +1,11 @@ import unittest + import numpy as np import torch - import torch_npu -from torch_npu.testing.testcase import TestCase, run_tests +from data_cache import golden_data_cache from torch_npu.testing.common_utils import create_common_tensor +from torch_npu.testing.testcase import TestCase, run_tests import mx_driving.point diff --git a/tests/torch/test_npu_max_pool2d.py b/tests/torch/test_npu_max_pool2d.py index e9294f13..b26c799b 100644 --- a/tests/torch/test_npu_max_pool2d.py +++ b/tests/torch/test_npu_max_pool2d.py @@ -1,15 +1,19 @@ import torch import torch.nn as nn +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving.fused +@golden_data_cache(__file__) def gen_inputs(shape, dtype): torch.manual_seed(123) x_data_cpu = torch.rand(shape, dtype=dtype) return x_data_cpu +@golden_data_cache(__file__) def cpu_to_exec(x_data_cpu): f = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) cpu_output = f(x_data_cpu.float()) diff --git a/tests/torch/test_npu_nms3d.py b/tests/torch/test_npu_nms3d.py index dfe4cd8b..25e63e03 100644 --- a/tests/torch/test_npu_nms3d.py +++ b/tests/torch/test_npu_nms3d.py @@ -1,5 +1,5 @@ import unittest -from math import cos, sin, fabs, atan2 +from math import atan2, cos, fabs, sin from typing import List import numpy as np @@ -11,6 +11,7 @@ from torch_npu.testing.testcase import TestCase, run_tests import mx_driving import mx_driving.detection + torch.npu.config.allow_internal_format = False torch_npu.npu.set_compile_mode(jit_compile=False) DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] diff --git a/tests/torch/test_npu_nms3d_normal.py b/tests/torch/test_npu_nms3d_normal.py index 2dcdecf8..5fa41059 100644 --- a/tests/torch/test_npu_nms3d_normal.py +++ b/tests/torch/test_npu_nms3d_normal.py @@ -1,13 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. import unittest -import torch -import numpy as np +import numpy as np +import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving import mx_driving.detection + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] diff --git a/tests/torch/test_pixel_group.py b/tests/torch/test_pixel_group.py index 04a6b47e..2ab6ae2b 100644 --- a/tests/torch/test_pixel_group.py +++ b/tests/torch/test_pixel_group.py @@ -1,15 +1,17 @@ -from typing import List +import unittest from dataclasses import dataclass +from typing import List -import unittest +import numpy as np import torch import torch_npu -import numpy as np - +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving import mx_driving.detection + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] @@ -24,6 +26,7 @@ class KernelParams: distance_threshold: float +@golden_data_cache(__file__) def pixel_group_cpu_golden(params: KernelParams): score = params.score mask = params.mask @@ -96,6 +99,7 @@ def pixel_group_npu_golden(params: KernelParams): return output1, output2 +@golden_data_cache(__file__) def generate_data(H, W, dim, num): score = np.random.uniform(0, 1, [H, W]).astype(np.float32) score = torch.from_numpy(score) diff --git a/tests/torch/test_point_to_voxel.py b/tests/torch/test_point_to_voxel.py index ce8a0578..0ab63d36 100644 --- a/tests/torch/test_point_to_voxel.py +++ b/tests/torch/test_point_to_voxel.py @@ -3,9 +3,12 @@ import unittest import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving._C + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] @@ -14,12 +17,14 @@ class TestPointToVoxel(TestCase): np.random.seed(seed) point_nums = [1, 7, 6134, 99999] + @golden_data_cache(__file__) def gen(self, point_num): x = np.random.randint(-1, 256, (point_num,)) y = np.random.randint(-1, 256, (point_num,)) z = np.random.randint(-1, 256, (point_num,)) return np.stack([x, y, z], axis=-1).astype(np.int32) + @golden_data_cache(__file__) def golden_encode(self, coords): point_num = coords.shape[0] res = np.zeros((point_num,), dtype=np.int32) diff --git a/tests/torch/test_points_in_box.py b/tests/torch/test_points_in_box.py index adadf303..f9d222f9 100644 --- a/tests/torch/test_points_in_box.py +++ b/tests/torch/test_points_in_box.py @@ -12,16 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. import unittest -import torch + import numpy as np +import torch import torch_npu - +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving.preprocess + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] +@golden_data_cache(__file__) +def cpu_gen_inputs(range_boxes, range_points, shape_boxes, shape_points): + boxes = np.random.uniform(range_boxes[0], range_boxes[1], shape_boxes).astype(np.float32) + boxes = torch.from_numpy(boxes) + points = np.random.uniform(range_points[0], range_points[1], shape_points).astype(np.float32) + points = torch.from_numpy(points) + + return boxes, points + + def lidar_to_local_coords_cpu(shift_x, shift_y, rz): cosa = torch.cos(-rz) sina = torch.sin(-rz) @@ -53,6 +66,7 @@ def check_pt_in_box3d_cpu(pt, box3d, idx): return in_flag +@golden_data_cache(__file__) def points_in_boxes_cpu_forward(boxes_tensor, pts_tensor, pts_indices_tensor): boxes_num = boxes_tensor.size(0) pts_num = pts_tensor.size(0) @@ -84,10 +98,7 @@ class TestPointsInBox(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `PointsInBox` is only supported on 910B, skip this ut!") def test_points_in_box_shape_randn(self, device="npu"): - boxes = np.random.uniform(0, 1, [1, 200, 7]).astype(np.float32) - boxes = torch.from_numpy(boxes) - points = np.random.uniform(0, 2.0, [1, 100, 3]).astype(np.float32) - points = torch.from_numpy(points) + boxes, points = cpu_gen_inputs([0, 1], [0, 2.0], [1, 200, 7], [1, 100, 3]) shape1 = points.shape batch_size = shape1[0] num_points = shape1[1] @@ -105,10 +116,7 @@ class TestPointsInBox(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `PointsInBox` is only supported on 910B, skip this ut!") def test_points_in_box_shape_large_boxes(self, device="npu"): - boxes = np.random.uniform(0, 1, [1, 2000, 7]).astype(np.float32) - boxes = torch.from_numpy(boxes) - points = np.random.uniform(0, 2.0, [1, 100, 3]).astype(np.float32) - points = torch.from_numpy(points) + boxes, points = cpu_gen_inputs([0, 1], [0, 2.0], [1, 2000, 7], [1, 100, 3]) shape1 = points.shape batch_size = shape1[0] num_points = shape1[1] @@ -126,10 +134,7 @@ class TestPointsInBox(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `PointsInBox` is only supported on 910B, skip this ut!") def test_points_in_box_shape_large_points(self, device="npu"): - boxes = np.random.uniform(0, 1, [1, 200, 7]).astype(np.float32) - boxes = torch.from_numpy(boxes) - points = np.random.uniform(0, 2.0, [1, 1500, 3]).astype(np.float32) - points = torch.from_numpy(points) + boxes, points = cpu_gen_inputs([0, 1], [0, 2.0], [1, 200, 7], [1, 1500, 3]) shape1 = points.shape batch_size = shape1[0] num_points = shape1[1] @@ -147,10 +152,7 @@ class TestPointsInBox(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `PointsInBox` is only supported on 910B, skip this ut!") def test_points_in_box_shape_large_batch(self, device="npu"): - boxes = np.random.uniform(0, 1, [2, 200, 7]).astype(np.float32) - boxes = torch.from_numpy(boxes) - points = np.random.uniform(0, 2.0, [2, 1500, 3]).astype(np.float32) - points = torch.from_numpy(points) + boxes, points = cpu_gen_inputs([0, 1], [0, 2.0], [2, 200, 7], [2, 1500, 3]) shape1 = points.shape batch_size = shape1[0] num_points = shape1[1] diff --git a/tests/torch/test_points_in_box_all.py b/tests/torch/test_points_in_box_all.py index b8dc29f5..c85cff91 100644 --- a/tests/torch/test_points_in_box_all.py +++ b/tests/torch/test_points_in_box_all.py @@ -12,15 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. import unittest -import torch + import numpy as np +import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving.preprocess + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] +@golden_data_cache(__file__) +def cpu_gen_inputs(range_boxes, range_points, shape_boxes, shape_points): + boxes = np.random.uniform(range_boxes[0], range_boxes[1], shape_boxes).astype(np.float32) + boxes = torch.from_numpy(boxes) + points = np.random.uniform(range_points[0], range_points[1], shape_points).astype(np.float32) + points = torch.from_numpy(points) + + return boxes, points + + +@golden_data_cache(__file__) def lidar_to_local_coords_cpu(shift_x, shift_y, rz): cosa = torch.cos(-rz) sina = torch.sin(-rz) @@ -30,6 +45,7 @@ def lidar_to_local_coords_cpu(shift_x, shift_y, rz): return local_x, local_y +@golden_data_cache(__file__) def points_in_boxes_all_cpu_forward(boxes, pts): cx, cy, cz, x_size, y_size, z_size, rz = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4], boxes[:, 5], boxes[:, 6] x, y, z = pts[:, 0], pts[:, 1], pts[:, 2] @@ -60,10 +76,7 @@ def points_in_boxes_all_cpu_forward(boxes, pts): class TestPointsInBoxAll(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `PointsInBoxAll` is only supported on 910B, skip this ut!") def test_points_in_box_shape_randn(self, device="npu"): - boxes = np.random.uniform(0, 1, [1, 100, 7]).astype(np.float32) - boxes = torch.from_numpy(boxes) - points = np.random.uniform(0, 2.0, [1, 100, 3]).astype(np.float32) - points = torch.from_numpy(points) + boxes, points = cpu_gen_inputs([0, 1], [0, 2.0], [1, 100, 7], [1, 100, 3]) shape1 = points.shape batch_size = shape1[0] num_points = shape1[1] @@ -77,10 +90,7 @@ class TestPointsInBoxAll(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `PointsInBoxAll` is only supported on 910B, skip this ut!") def test_points_in_box_shape_large_boxes(self, device="npu"): - boxes = np.random.uniform(0, 1, [1, 10000, 7]).astype(np.float32) - boxes = torch.from_numpy(boxes) - points = np.random.uniform(0, 2.0, [1, 100, 3]).astype(np.float32) - points = torch.from_numpy(points) + boxes, points = cpu_gen_inputs([0, 1], [0, 2.0], [1, 10000, 7], [1, 100, 3]) shape1 = points.shape batch_size = shape1[0] num_points = shape1[1] @@ -94,10 +104,7 @@ class TestPointsInBoxAll(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `PointsInBoxAll` is only supported on 910B, skip this ut!") def test_points_in_box_shape_large_points(self, device="npu"): - boxes = np.random.uniform(0, 1, [1, 100, 7]).astype(np.float32) - boxes = torch.from_numpy(boxes) - points = np.random.uniform(0, 2.0, [1, 10000, 3]).astype(np.float32) - points = torch.from_numpy(points) + boxes, points = cpu_gen_inputs([0, 1], [0, 2.0], [1, 100, 7], [1, 10000, 3]) shape1 = points.shape batch_size = shape1[0] num_points = shape1[1] @@ -111,10 +118,7 @@ class TestPointsInBoxAll(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `PointsInBoxAll` is only supported on 910B, skip this ut!") def test_points_in_box_shape_large_boxes_points(self, device="npu"): - boxes = np.random.uniform(0, 1, [1, 10000, 7]).astype(np.float32) - boxes = torch.from_numpy(boxes) - points = np.random.uniform(0, 2.0, [1, 10000, 3]).astype(np.float32) - points = torch.from_numpy(points) + boxes, points = cpu_gen_inputs([0, 1], [0, 2.0], [1, 10000, 7], [1, 10000, 3]) shape1 = points.shape batch_size = shape1[0] num_points = shape1[1] @@ -128,10 +132,7 @@ class TestPointsInBoxAll(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `PointsInBoxAll` is only supported on 910B, skip this ut!") def test_points_in_box_shape_large_batch(self, device="npu"): - boxes = np.random.uniform(0, 1, [100, 100, 7]).astype(np.float32) - boxes = torch.from_numpy(boxes) - points = np.random.uniform(0, 2.0, [100, 100, 3]).astype(np.float32) - points = torch.from_numpy(points) + boxes, points = cpu_gen_inputs([0, 1], [0, 2.0], [100, 100, 7], [100, 100, 3]) shape1 = points.shape batch_size = shape1[0] num_points = shape1[1] @@ -145,10 +146,7 @@ class TestPointsInBoxAll(TestCase): @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `PointsInBoxAll` is only supported on 910B, skip this ut!") def test_points_in_box_shape_random_shape(self, device="npu"): - boxes = np.random.uniform(-5, 5, [214, 192, 7]).astype(np.float32) - boxes = torch.from_numpy(boxes) - points = np.random.uniform(-5, 5, [214, 371, 3]).astype(np.float32) - points = torch.from_numpy(points) + boxes, points = cpu_gen_inputs([-5, 5], [-5, 5], [214, 192, 7], [214, 371, 3]) shape1 = points.shape batch_size = shape1[0] num_points = shape1[1] diff --git a/tests/torch/test_roi_align_rotated.py b/tests/torch/test_roi_align_rotated.py index 918e659f..9c0c18b6 100644 --- a/tests/torch/test_roi_align_rotated.py +++ b/tests/torch/test_roi_align_rotated.py @@ -1,25 +1,29 @@ """ Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. """ -import unittest +import copy import math -from typing import List +import unittest from functools import reduce +from typing import List -import copy import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving import mx_driving.detection + torch.npu.config.allow_internal_format = False torch_npu.npu.set_compile_mode(jit_compile=False) DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] EPS = 1e-8 +@golden_data_cache(__file__) def cpu_roi_align_rotated_grad(input_array, rois, grad_outputs, args_dict): spatial_scale, sampling_ratio, pooled_height, pooled_width, aligned, clockwise = args_dict.values() bs, c, h, w = input_array.shape @@ -147,6 +151,7 @@ def bilinear_interpolate_grad(height, width, y, x): return bilinear_args +@golden_data_cache(__file__) def cpu_roi_align_rotated(input_array, rois, args_dict): spatial_scale, sampling_ratio, pooled_height, pooled_width, aligned, clockwise = args_dict.values() N, C, H, W = input_array.shape @@ -304,10 +309,12 @@ class TestRoiAlignedRotated(TestCase): return output_1.cpu(), output_2.cpu(), grad_1.cpu(), grad_2.cpu() + @golden_data_cache(__file__) def generate_features(self, feature_shape): features = torch.rand(feature_shape) return features + @golden_data_cache(__file__) def generate_rois(self, roi_shape, feature_shape, spatial_scale): num_boxes = roi_shape[0] rois = torch.Tensor(6, num_boxes) @@ -320,6 +327,7 @@ class TestRoiAlignedRotated(TestCase): return rois.transpose(0, 1).contiguous() + @golden_data_cache(__file__) def generate_grad(self, roi_shape, feature_shape, pooled_height, pooled_width): num_boxes = roi_shape[0] channels = feature_shape[1] diff --git a/tests/torch/test_roiaware_pool3d.py b/tests/torch/test_roiaware_pool3d.py index 94fb179e..07a525b8 100644 --- a/tests/torch/test_roiaware_pool3d.py +++ b/tests/torch/test_roiaware_pool3d.py @@ -1,15 +1,21 @@ import math import unittest + +import numpy as np import torch import torch_npu -import numpy as np +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving import mx_driving.detection + + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] class TestRoIAwarePool3dGrad(TestCase): + @golden_data_cache(__file__) def roiaware_pool3d_cpu(self, rois, pts, pts_feature, out, max_pts_per_voxel, pool_method, dtype): # cast if (dtype == np.float16): @@ -84,7 +90,8 @@ class TestRoIAwarePool3dGrad(TestCase): return 1, local_x, local_y else: return 0, local_x, local_y - + + @golden_data_cache(__file__) def roiaware_pool3d_golden(self, rois, pts, pts_feature, out, max_pts_per_voxel, mode): num_rois = rois.shape[0] num_channels = pts_feature.shape[-1] @@ -168,6 +175,7 @@ class TestRoIAwarePool3dGrad(TestCase): return pooled_features + @golden_data_cache(__file__) def gen_input_data(self, boxes_num, out_size, channels, npoints, dtype): xyz_coor = np.random.uniform(-1, 1, size=(boxes_num, 3)).astype(dtype) xyz_size_num = np.random.uniform(1, 50, size=(1, 3)).astype(dtype) diff --git a/tests/torch/test_roiaware_pool3d_grad.py b/tests/torch/test_roiaware_pool3d_grad.py index 31576854..466f5076 100644 --- a/tests/torch/test_roiaware_pool3d_grad.py +++ b/tests/torch/test_roiaware_pool3d_grad.py @@ -1,15 +1,20 @@ import unittest -import torch + import numpy as np +import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving._C import mx_driving.detection + + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] class TestRoIAwarePool3dGrad(TestCase): - + @golden_data_cache(__file__) def roiaware_pool3d_grad_cpu(self, pts_idx_of_voxels, argmax, grad_out, npoints, pool_method): channels = grad_out.shape[-1] @@ -70,6 +75,7 @@ class TestRoIAwarePool3dGrad(TestCase): pts_idx = pts_idx_of_voxels[b, ox, oy, oz, i] grad_in[pts_idx, :] += grad_out[b, ox, oy, oz, :] / max(total_pts, 1.0) + @golden_data_cache(__file__) def gen_input_data(self, pts_idx_of_voxels_shape, channels, npoints, dtype): boxes_num, out_x, out_y, out_z, max_pts_per_voxel = pts_idx_of_voxels_shape grad_out = np.random.uniform(-5, 5, (boxes_num, out_x, out_y, out_z, channels)).astype(dtype) @@ -81,6 +87,7 @@ class TestRoIAwarePool3dGrad(TestCase): pts_idx_of_voxels = torch.from_numpy(pts_idx_of_voxels) return argmax, grad_out, pts_idx_of_voxels + @golden_data_cache(__file__) def gen_pts_idx_of_voxels(self, pts_idx_of_voxels_shape, npoints): boxes_num, out_x, out_y, out_z, max_pts_per_voxel = pts_idx_of_voxels_shape pts_idx_of_voxels = np.zeros((boxes_num, out_x, out_y, out_z, max_pts_per_voxel - 1)).astype("int32") diff --git a/tests/torch/test_roipoint_pool3d.py b/tests/torch/test_roipoint_pool3d.py index cca726f6..2c645d5d 100644 --- a/tests/torch/test_roipoint_pool3d.py +++ b/tests/torch/test_roipoint_pool3d.py @@ -17,6 +17,7 @@ import unittest import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests from mx_driving import roipoint_pool3d @@ -29,6 +30,7 @@ DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] # float16[-14,16], float32[-126,128], float64[-1022,1024], int16[0,15], int32[0,31], int64[0,63] # random_value(-7, 8, (1, 2, 3), np.float32, True, True, False, False) # pylint: disable=too-many-arguments,huawei-too-many-arguments +@golden_data_cache(__file__) def random_value( min_log, max_log, size, dtype=np.float32, nega_flag=True, zero_flag=True, inf_flag=False, nan_flag=False ): @@ -78,6 +80,7 @@ def check_point_in_box3d(point, box3d): return in_flag +@golden_data_cache(__file__) def roipoint_pool3d_forward(num_sampled_points, points, point_features, boxes3d, pooled_features): point_num = points.shape[0] # N feature_len = point_features.shape[1] # C @@ -109,6 +112,7 @@ def roipoint_pool3d_forward(num_sampled_points, points, point_features, boxes3d, return 0 +@golden_data_cache(__file__) def cpu_roipoint_pool3d(num_sampled_points, points, point_features, boxes3d): # B=batch_size; N=point_num; M=boxes_num; C=feature_len; num = num_sampled_points batch_size = points.shape[0] # B diff --git a/tests/torch/test_rotated_iou.py b/tests/torch/test_rotated_iou.py index 8d73b32b..ce0a13c4 100644 --- a/tests/torch/test_rotated_iou.py +++ b/tests/torch/test_rotated_iou.py @@ -1,14 +1,18 @@ import unittest -from math import cos, sin, fabs, atan2, pi -from typing import List from collections import namedtuple +from math import atan2, cos, fabs, pi, sin +from typing import List + import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving.detection from mx_driving import npu_rotated_iou + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] EPS = 1e-5 @@ -204,6 +208,7 @@ def box_overlap(box_a: List[float], box_b: List[float]): return inter / union +@golden_data_cache(__file__) def cpu_rotated_iou(boxes_a: List[List[float]], boxes_b: List[List[float]]): boxes_a_num = boxes_a.shape[0] boxes_b_num = boxes_b.shape[0] @@ -251,7 +256,8 @@ class TestNpuRotatedIou(TestCase): test_results.append((cpu_results, npu_results)) return test_results - def gen_inputs(self, shape, dtype): + @golden_data_cache(__file__) + def cpu_gen_inputs(self, shape, dtype): boxes_a_num, boxes_b_num = shape boxes_a = np.zeros((boxes_a_num, 5)) boxes_b = np.zeros((boxes_b_num, 5)) @@ -273,6 +279,11 @@ class TestNpuRotatedIou(TestCase): boxes_a_cpu = boxes_a.astype(np.float32) boxes_b_cpu = boxes_b.astype(np.float32) + + return boxes_a_cpu, boxes_b_cpu + + def gen_inputs(self, shape, dtype): + boxes_a_cpu, boxes_b_cpu = self.cpu_gen_inputs(shape, dtype) boxes_a_npu = torch.from_numpy(boxes_a_cpu).npu().unsqueeze(0) boxes_b_npu = torch.from_numpy(boxes_b_cpu).npu().unsqueeze(0) diff --git a/tests/torch/test_scatter_max.py b/tests/torch/test_scatter_max.py index 3671531b..f5e13582 100644 --- a/tests/torch/test_scatter_max.py +++ b/tests/torch/test_scatter_max.py @@ -1,15 +1,14 @@ -import torch import numpy as np -import torch_scatter - +import torch import torch_npu -from torch_npu.testing.testcase import TestCase, run_tests +import torch_scatter from torch_npu.testing.common_utils import create_common_tensor +from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving.common class TestScatterMaxWithArgmax(TestCase): - def cpu_op_exec(self, updates, indices): updates.requires_grad = True diff --git a/tests/torch/test_scatter_mean.py b/tests/torch/test_scatter_mean.py index 21468395..cc1af678 100644 --- a/tests/torch/test_scatter_mean.py +++ b/tests/torch/test_scatter_mean.py @@ -1,14 +1,25 @@ -import torch import numpy as np -import torch_scatter - +import torch import torch_npu -from torch_npu.testing.testcase import TestCase, run_tests +import torch_scatter +from data_cache import golden_data_cache from torch_npu.testing.common_utils import create_common_tensor +from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving.common +@golden_data_cache(__file__) +def cpu_gen_inputs(src_shape, index_shape, index_max): + cpu_src = np.random.uniform(0, 100, size=src_shape).astype(np.float32) + cpu_index = np.random.uniform(0, index_max, size=index_shape).astype(np.int32) + cpu_src = torch.from_numpy(cpu_src) + cpu_index = torch.from_numpy(cpu_index) + return cpu_src, cpu_index + + class TestScatterMeanWithArgmax(TestCase): + @golden_data_cache(__file__) def cpu_op_exec(self, src, index, out=None, dim=0, dim_size=None): src.requires_grad = True out = torch_scatter.scatter_mean(src, index.long(), out=out, dim=dim, dim_size=dim_size) @@ -40,8 +51,9 @@ class TestScatterMeanWithArgmax(TestCase): index_shape = input_info[1] index_max = input_info[2] for dim in range(len(index_shape)): - cpu_src, npu_src = create_common_tensor(["float32", 2, src_shape], 0, 100) - cpu_index, npu_index = create_common_tensor(["int32", 2, index_shape], 0, index_max) + cpu_src, cpu_index = cpu_gen_inputs(src_shape, index_shape, index_max) + npu_src, npu_index = cpu_src.npu(), cpu_index.npu() + cpu_output, cpu_grad_in = self.cpu_op_exec(cpu_src, cpu_index.long(), dim=dim) npu_output, npu_grad_in = self.npu_op_exec(npu_src, npu_index, None, dim) @@ -60,8 +72,8 @@ class TestScatterMeanWithArgmax(TestCase): index_shape = input_info[1] index_max = input_info[2] for dim in range(len(index_shape)): - cpu_src, npu_src = create_common_tensor(["float32", 2, src_shape], 0, 100) - cpu_index, npu_index = create_common_tensor(["int32", 2, index_shape], 0, index_max) + cpu_src, cpu_index = cpu_gen_inputs(src_shape, index_shape, index_max) + npu_src, npu_index = cpu_src.npu(), cpu_index.npu() cpu_output, cpu_grad_in = self.cpu_op_exec(cpu_src, cpu_index.long(), dim=dim) npu_output, npu_grad_in = self.npu_op_exec(npu_src, npu_index, None, dim) @@ -81,8 +93,8 @@ class TestScatterMeanWithArgmax(TestCase): out_shape = input_info[2] index_max = input_info[3] for dim in range(len(index_shape)): - cpu_src, npu_src = create_common_tensor(["float32", 2, src_shape], 0, 100) - cpu_index, npu_index = create_common_tensor(["int32", 2, index_shape], 0, index_max) + cpu_src, cpu_index = cpu_gen_inputs(src_shape, index_shape, index_max) + npu_src, npu_index = cpu_src.npu(), cpu_index.npu() cpu_output, cpu_grad_in = self.cpu_op_exec(cpu_src, cpu_index.long(), dim=dim) npu_output, npu_grad_in = self.npu_op_exec(npu_src, npu_index, None, dim) @@ -104,8 +116,8 @@ class TestScatterMeanWithArgmax(TestCase): out_shape = input_info[2] dim = input_info[3] - cpu_src, npu_src = create_common_tensor(["float32", 2, src_shape], 0, 100) - cpu_index, npu_index = create_common_tensor(["int32", 2, index_shape], 0, out_shape[dim]) + cpu_src, cpu_index = cpu_gen_inputs(src_shape, index_shape, out_shape[dim]) + npu_src, npu_index = cpu_src.npu(), cpu_index.npu() cpu_out, npu_out = create_common_tensor(["float32", 2, out_shape], 0, 100) cpu_output, cpu_grad_in = self.cpu_op_exec(cpu_src, cpu_index.long(), out=cpu_out, dim=dim) npu_output, npu_grad_in = self.npu_op_exec(npu_src, npu_index, out=npu_out, dim=dim) @@ -126,8 +138,8 @@ class TestScatterMeanWithArgmax(TestCase): out_shape = input_info[2] dim_size = input_info[3] for dim in range(len(index_shape)): - cpu_src, npu_src = create_common_tensor(["float32", 2, src_shape], 0, 100) - cpu_index, npu_index = create_common_tensor(["int32", 2, index_shape], 0, dim_size) + cpu_src, cpu_index = cpu_gen_inputs(src_shape, index_shape, dim_size) + npu_src, npu_index = cpu_src.npu(), cpu_index.npu() cpu_out, npu_out = create_common_tensor(["float32", 2, out_shape], 0, 100) cpu_output, cpu_grad_in = self.cpu_op_exec(cpu_src, cpu_index.long(), out=None, dim=dim, dim_size=dim_size) npu_output, npu_grad_in = self.npu_op_exec(npu_src, npu_index, out=None, dim=dim, dim_size=dim_size) diff --git a/tests/torch/test_three_interpolate.py b/tests/torch/test_three_interpolate.py index 3e9a7e88..f70f3da3 100644 --- a/tests/torch/test_three_interpolate.py +++ b/tests/torch/test_three_interpolate.py @@ -1,12 +1,13 @@ -import torch import numpy as np - +import torch +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving class TestThreeinterpolate(TestCase): - + @golden_data_cache(__file__) def cpu_op_exec(self, feat, idx, wt): bs, cs, ms = feat.shape ns = idx.shape[1] diff --git a/tests/torch/test_three_nn.py b/tests/torch/test_three_nn.py index 2b8857ca..0ec7e571 100644 --- a/tests/torch/test_three_nn.py +++ b/tests/torch/test_three_nn.py @@ -1,11 +1,21 @@ -import torch import numpy as np +import torch +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving import mx_driving.common +@golden_data_cache(__file__) +def cpu_gen_inputs(batch, N, npoint): + source = np.ones((batch, N, 3)).astype(np.float32) + target = np.zeros((batch, npoint, 3)).astype(np.float32) + return source, target + + class TestThreeNN(TestCase): + @golden_data_cache(__file__) def cpu_op_exec(self, batch, npoint, @@ -35,8 +45,7 @@ class TestThreeNN(TestCase): batch = 1 npoint = 1 N = 200 - source = np.ones((batch, N, 3)).astype(np.float32) - target = np.zeros((batch, npoint, 3)).astype(np.float32) + source, target = cpu_gen_inputs(batch, N, npoint) expected_dist, expected_idx = self.cpu_op_exec(batch, npoint, source, target) dist, idx = mx_driving.three_nn(torch.from_numpy(target).npu(), torch.from_numpy(source).npu()) @@ -263,8 +272,7 @@ class TestThreeNN(TestCase): N = 3 npoint = 1 - source = np.ones((batch, N, 3)).astype(np.float32) - target = np.zeros((batch, npoint, 3)).astype(np.float32) + source, target = cpu_gen_inputs(batch, N, npoint) expected_dist, expected_idx = self.cpu_op_exec(batch, npoint, source, target) dist, idx = mx_driving.three_nn(torch.from_numpy(target).npu(), torch.from_numpy(source).npu()) @@ -281,8 +289,7 @@ class TestThreeNN(TestCase): N = 12 npoint = 21 - source = np.ones((batch, N, 3)).astype(np.float32) - target = np.zeros((batch, npoint, 3)).astype(np.float32) + source, target = cpu_gen_inputs(batch, N, npoint) expected_dist, expected_idx = self.cpu_op_exec(batch, npoint, source, target) dist, idx = mx_driving.three_nn(torch.from_numpy(target).npu(), torch.from_numpy(source).npu()) diff --git a/tests/torch/test_unique_voxel.py b/tests/torch/test_unique_voxel.py index c56dc275..87171137 100644 --- a/tests/torch/test_unique_voxel.py +++ b/tests/torch/test_unique_voxel.py @@ -3,9 +3,12 @@ import unittest import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] @@ -14,10 +17,12 @@ class TestUniqueVoxel(TestCase): np.random.seed(seed) point_nums = [1, 7, 6134, 99999] + @golden_data_cache(__file__) def gen(self, point_num): x = np.random.randint(0, 1024, (point_num,)) return x.astype(np.int32) + @golden_data_cache(__file__) def golden_unique(self, voxels): res = np.unique(voxels) return res.shape[0], np.sort(res) @@ -27,12 +32,14 @@ class TestUniqueVoxel(TestCase): cnt, uni_vox, _, _, _ = mx_driving.unique_voxel(voxels_npu) return cnt, uni_vox.cpu().numpy() + @golden_data_cache(__file__) def gen_integration(self, point_num): x = np.random.randint(0, 256, (point_num,)) y = np.random.randint(0, 256, (point_num,)) z = np.random.randint(0, 256, (point_num,)) return np.stack([x, y, z], axis=-1).astype(np.int32) + @golden_data_cache(__file__) def golden_integration(self, coords): point_num = coords.shape[0] res = np.zeros((point_num,), dtype=np.int32) diff --git a/tests/torch/test_vec_pool_backward.py b/tests/torch/test_vec_pool_backward.py index ce128a21..12e723eb 100644 --- a/tests/torch/test_vec_pool_backward.py +++ b/tests/torch/test_vec_pool_backward.py @@ -1,18 +1,36 @@ import unittest -import torch -import numpy as np +import numpy as np +import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving._C +np.random.seed(2024) DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] -class TestVecPoolGrad(TestCase): - np.random.seed(2024) +@golden_data_cache(__file__) +def cpu_gen_inputs(args): + m, c_out, n, c_in, num_total_grids, num_max_sum_points = args + np_grad_new_features = np.random.rand(m, c_out).astype(np.float32) + np_point_cnt_of_grid = np.random.randint(1, 8, (m, num_total_grids)).astype(np.int32) + np_grouped_idxs = np.column_stack(( + np.random.randint(0, n, num_max_sum_points), + np.random.randint(0, m, num_max_sum_points), + np.random.randint(0, num_total_grids, num_max_sum_points) + )).astype(np.int32) + np_grad_support_features = np.zeros((n, c_in)).astype(np.float32) + + return np_grad_new_features, np_point_cnt_of_grid, np_grouped_idxs, np_grad_support_features + + +class TestVecPoolGrad(TestCase): + @golden_data_cache(__file__) def golden_vec_pool_backward(self, grad_new_features, point_cnt_of_grid, grouped_idxs, grad_support_features): num_c_out = grad_new_features.shape[1] num_total_grids = point_cnt_of_grid.shape[1] @@ -44,15 +62,7 @@ class TestVecPoolGrad(TestCase): for args in args_list: m, c_out, n, c_in, num_total_grids, num_max_sum_points = args - - np_grad_new_features = np.random.rand(m, c_out).astype(np.float32) - np_point_cnt_of_grid = np.random.randint(1, 8, (m, num_total_grids)).astype(np.int32) - np_grouped_idxs = np.column_stack(( - np.random.randint(0, n, num_max_sum_points), - np.random.randint(0, m, num_max_sum_points), - np.random.randint(0, num_total_grids, num_max_sum_points) - )).astype(np.int32) - np_grad_support_features = np.zeros((n, c_in)).astype(np.float32) + np_grad_new_features, np_point_cnt_of_grid, np_grouped_idxs, np_grad_support_features = cpu_gen_inputs(args) torch_grad_new_features = torch.from_numpy(np_grad_new_features).npu() torch_point_cnt_of_grid = torch.from_numpy(np_point_cnt_of_grid).npu() diff --git a/tests/torch/test_voxel_pooling_train.py b/tests/torch/test_voxel_pooling_train.py index 19cc7e69..1eed3360 100644 --- a/tests/torch/test_voxel_pooling_train.py +++ b/tests/torch/test_voxel_pooling_train.py @@ -4,15 +4,18 @@ import unittest import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests import mx_driving.point from mx_driving import npu_voxel_pooling_train + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] # pylint: disable=too-many-arguments,huawei-too-many-arguments +@golden_data_cache(__file__) def voxel_pooling_train_cpu_forward( batch_size, num_points, num_channels, num_voxel_x, num_voxel_y, num_voxel_z, geom_xyz, input_features ): @@ -42,6 +45,7 @@ def voxel_pooling_train_cpu_forward( return pos_memo, output_features.permute(0, 3, 1, 2) +@golden_data_cache(__file__) def voxel_pooling_train_cpu_backward(pos, result_cpu, grad_features): features_shape = grad_features.shape mask = (pos != -1)[..., 0] @@ -81,6 +85,7 @@ class TestVoxelPoolingTrain(TestCase): return result1, result2, grad_features_npu1, grad_features_npu2 + @golden_data_cache(__file__) def gen_data(self, geom_shape, feature_shape, coeff, batch_size, num_channels, dtype): geom_xyz = torch.rand(geom_shape) * coeff geom_xyz = geom_xyz.reshape(batch_size, -1, 3) diff --git a/tests/torch/test_voxel_to_point.py b/tests/torch/test_voxel_to_point.py index 35b21bff..990e7f74 100644 --- a/tests/torch/test_voxel_to_point.py +++ b/tests/torch/test_voxel_to_point.py @@ -3,9 +3,12 @@ import unittest import numpy as np import torch import torch_npu +from data_cache import golden_data_cache from torch_npu.testing.testcase import TestCase, run_tests + import mx_driving._C + DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] @@ -14,10 +17,12 @@ class TestVoxelToPoint(TestCase): np.random.seed(seed) point_nums = [1, 7, 6134, 99999] + @golden_data_cache(__file__) def gen(self, point_num): x = np.random.randint(0, 10240, (point_num,)) return x.astype(np.int32) + @golden_data_cache(__file__) def golden_decode(self, voxels): point_num = voxels.shape[0] res = np.zeros((point_num, 3), dtype=np.int32) -- Gitee