diff --git a/ci/access_control_test.py b/ci/access_control_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7685657aaf3df9d7a758c99f7698b35cb55e8a86 --- /dev/null +++ b/ci/access_control_test.py @@ -0,0 +1,340 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -*- coding: UTF-8 -*- + +import os +import re +import sys +import subprocess +import threading +import queue +from abc import ABCMeta, abstractmethod +from pathlib import Path +import warnings + +BASE_DIR = Path(__file__).absolute().parent.parent +TEST_DIR = BASE_DIR / 'tests' + + +def check_path_owner_consistent(path: str): + """ + Function Description: + check whether the path belong to process owner + Parameter: + path: the path to check + Exception Description: + when invalid path, prompt the user + """ + + if not os.path.exists(path): + msg = f"The path does not exist: {path}" + raise RuntimeError(msg) + if os.stat(path).st_uid != os.getuid(): + warnings.warn(f"Warning: The {path} owner does not match the current user.") + + +def check_directory_path_readable(path): + """ + Function Description: + check whether the path is writable + Parameter: + path: the path to check + Exception Description: + when invalid data throw exception + """ + check_path_owner_consistent(path) + if os.path.islink(path): + msg = f"Invalid path is a soft chain: {path}" + raise RuntimeError(msg) + if not os.access(path, os.R_OK): + msg = f"The path permission check failed: {path}" + raise RuntimeError(msg) + + +class AccurateTest(metaclass=ABCMeta): + @abstractmethod + def identify(self, modify_file): + """ + 该接口提供代码对应的UT的路径信息 + """ + raise Exception("abstract method. Subclasses should implement it.") + + @staticmethod + def find_ut_by_regex(regex): + ut_files = [] + cmd = "find {} -name {}".format(str(TEST_DIR), regex) + status, output = subprocess.getstatusoutput(cmd) + if status or not output: + pass # 对于找不到的暂时不作处理 + else: + files = output.split('\n') + for ut_file in files: + if ut_file.endswith(".py"): + ut_files.append(ut_file) + return ut_files + + +class OpStrategy(AccurateTest): + """ + 通过对适配层的代码的识别 + """ + + def identify(self, modify_file): + """ + 通过对于算子实现文件的文件名解析获取其单元测试的名字,比如: + BinaryCrossEntropyWithLogitsBackwardKernelNpu.cpp + 针对这个文件,先识别关键字BinaryCrossEntropyWithLogitsBackward + 然后,获取其正则表达式*binary*cross*entropy*with*logits*backward*识别到符合要求的测试用例 + 具体方法:通过大写字母切分关键字,再识别包含所有这些关键字的测试文件名。 + """ + + filename = Path(modify_file).name + if filename.find('KernelNpu') >= 0: + feature_line = filename.split('KernelNpu')[0] + features = re.findall('[A-Z][^A-Z]*', feature_line) + regex = '*' + '*'.join([f"{feature.lower()}" for feature in features]) + '*' + return AccurateTest.find_ut_by_regex(regex) + return [] + + +class DirectoryStrategy(AccurateTest): + """ + Determine whether the modified files are test cases + """ + + def identify(self, modify_file): + is_test_file = str(Path(modify_file).parts[0]) == "test" \ + and re.match("test_(.+).py", Path(modify_file).name) + return [(str(BASE_DIR / modify_file))] if is_test_file else [] + + +class CoreTestStrategy(AccurateTest): + """ + Determine whether the core tests should be runned + """ + block_list = ['test', 'docs'] + core_test_cases = [str(i) for i in (BASE_DIR / 'test/test_npu').rglob('test_*.py')] + + def identify(self, modify_file): + modified_module = str(Path(modify_file).parts[0]) + if modified_module not in self.block_list: + return self.core_test_cases + return [] + + +class CopyOptStrategy(AccurateTest): + """ + 通过识别非连续转连续的测试用例 + """ + + def identify(self, modify_file): + if modify_file.find('contiguous') > 0: + regex = '*contiguous*' + return AccurateTest.find_ut_by_regex(regex) + return [] + + +class DirectoryMappingStrategy(AccurateTest): + """ + Map the modified files to the corresponding test cases + """ + mapping_list = { + 'contrib': 'test/test_contrib', + 'cpp_extension': 'test/test_cpp_extension', + 'distributed': 'test/test_distributed', + 'fx': 'test/test_fx', + 'hooks': 'test/test_hooks', + 'optim': 'test/test_optim', + 'profiler': 'test/test_profiler', + 'onnx': 'test/test_onnx', + 'utils': 'test/test_utils', + 'testing': 'test/test_testing.py', + } + + def identify(self, modify_file): + current_all_ut_path = [] + if str(Path(modify_file).parts[0]) == 'torch_npu': + mapped_ut_path = [] + module_name = str(Path(modify_file).parts[1]) + if module_name == 'csrc': + module_name = str(Path(modify_file).parts[2]) + if module_name in self.mapping_list: + mapped_ut_path.append(self.mapping_list[module_name]) + file_name = str(Path(modify_file).stem) + if file_name in self.mapping_list: + mapped_ut_path.append(self.mapping_list[file_name]) + + for mapped_path in mapped_ut_path: + if Path.is_file(BASE_DIR / mapped_path): + current_all_ut_path.append(str(BASE_DIR / mapped_path)) + else: + current_all_ut_path += [str(i) for i in (BASE_DIR / mapped_path).rglob('test_*.py')] + return current_all_ut_path + + +class TestMgr(): + def __init__(self): + self.modify_files = [] + self.test_files = { + 'ut_files': [], + 'op_ut_files': [] + } + + def load(self, modify_files): + check_directory_path_readable(modify_files) + with open(modify_files) as f: + for line in f: + line = line.strip() + self.modify_files.append(line) + + def analyze(self): + # determine whether the modification is about hostapi + def is_hostapi_enabled(modify_file): + if str(Path(modify_file).parent.name) == 'op_api': + os.environ['HOSTAPI_ENABLED'] = 'ON' + + for modify_file in self.modify_files: + is_hostapi_enabled(modify_file) + self.test_files['ut_files'] += DirectoryStrategy().identify(modify_file) + self.test_files['ut_files'] += CopyOptStrategy().identify(modify_file) + self.test_files['ut_files'] += OpStrategy().identify(modify_file) + # self.test_files['op_ut_files'] += OpStrategy().identify(modify_file) + # self.test_files['ut_files'] += DirectoryMappingStrategy().identify(modify_file) + self.test_files['ut_files'] += CoreTestStrategy().identify(modify_file) + unique_files = sorted(set(self.test_files['ut_files'])) + + exist_ut_file = [ + changed_file + for changed_file in unique_files + if Path(changed_file).exists() + ] + self.test_files['ut_files'] = exist_ut_file + + def get_test_files(self): + return self.test_files + + def print_modify_files(self): + print("modify files:") + for modify_file in self.modify_files: + print(modify_file) + + def print_ut_files(self): + print("ut files:") + for ut_file in self.test_files['ut_files']: + print(ut_file) + + def print_op_ut_files(self): + print("op ut files:") + for op_ut_file in self.test_files['op_ut_files']: + print(op_ut_file) + + +def exec_ut(files): + """ + 执行单元测试文件,其中存在失败,则标识异常并打印相关信息 + """ + + def get_op_name(ut_file): + return ut_file.split('/')[-1].split('.')[0].lstrip('test_') + + def get_ut_name(ut_file): + return str(Path(ut_file).relative_to(TEST_DIR))[:-3] + + def get_ut_cmd(ut_type, ut_file): + cmd = [sys.executable, "run_test.py", "-v", "-i"] + if ut_type == "op_ut_files": + return cmd + ["test_ops", "--", "-k", get_op_name(ut_file)] + return cmd + [get_ut_name(ut_file)] + + def wait_thread(process, event_timer): + process.wait() + event_timer.set() + + def enqueue_output(out, log_queue): + for line in iter(out.readline, b''): + log_queue.put(line.decode('utf-8')) + out.close() + return + + def start_thread(fn, *args): + stdout_t = threading.Thread(target=fn, args=args) + stdout_t.daemon = True + stdout_t.start() + + def print_subprocess_log(log_queue): + while (not log_queue.empty()): + print((log_queue.get()).strip()) + + def run_cmd_with_timeout(cmd): + os.chdir(str(TEST_DIR)) + stdout_queue = queue.Queue() + event_timer = threading.Event() + + p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.STDOUT, stdout=subprocess.PIPE) + start_thread(wait_thread, p, event_timer) + start_thread(enqueue_output, p.stdout, stdout_queue) + + try: + event_timer.wait(2000) + ret = p.poll() + if ret: + print_subprocess_log(stdout_queue) + if not event_timer.is_set(): + ret = 1 + p.kill() + p.terminate() + print("Timeout: Command '{}' timed out after 2000 seconds".format(" ".join(cmd))) + print_subprocess_log(stdout_queue) + except Exception as err: + ret = 1 + print(err) + return ret + + def run_tests(files): + exec_infos = [] + has_failed = 0 + for ut_type, ut_files in files.items(): + for ut_file in ut_files: + cmd = get_ut_cmd(ut_type, ut_file) + ut_info = " ".join(cmd[4:]).replace(" -- -k", "") + ret = run_cmd_with_timeout(cmd) + if ret: + has_failed = ret + exec_infos.append("exec ut {} failed.".format(ut_info)) + else: + exec_infos.append("exec ut {} success.".format(ut_info)) + return has_failed, exec_infos + + ret_status, exec_infos = run_tests(files) + + print("***** Total result:") + for exec_info in exec_infos: + print(exec_info) + return ret_status + + +if __name__ == "__main__": + cur_modify_files = str(BASE_DIR / 'modify_files.txt') + test_mgr = TestMgr() + test_mgr.load(cur_modify_files) + test_mgr.analyze() + cur_test_files = test_mgr.get_test_files() + + test_mgr.print_modify_files() + test_mgr.print_ut_files() + test_mgr.print_op_ut_files() + + ret_ut = exec_ut(cur_test_files) + sys.exit(ret_ut) diff --git a/build.sh b/ci/build.sh similarity index 99% rename from build.sh rename to ci/build.sh index f2c01a9ea39056c1d4f064aa1d6f69a7c81a1f2e..868e70dd728823dd38c1098312a550ba8e9c1207 100644 --- a/build.sh +++ b/ci/build.sh @@ -83,6 +83,7 @@ function main() check_python_version + cd ${CUR_DIR}/.. python"${PY_VERSION}" setup.py build bdist_wheel if [ $? != 0 ]; then echo "Failed to compile the wheel file. Please check the source code by yourself." diff --git a/common/modules.py b/common/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..5d8b41335fda11ab57bb98b63126962ac785b155 --- /dev/null +++ b/common/modules.py @@ -0,0 +1,32 @@ +import torch +from torch.autograd import Function +from torch.nn import Module + +import torch_npu +import ads + + +class ScatterMaxFunction(Function): + @staticmethod + def forward(ctx, updates, indices, out=None): + func = ads.npu_scatter_max + out, argmax = func(updates, indices, out) + ctx.save_for_backward(argmax, updates) + return out, argmax + + @staticmethod + def backward(ctx, grad_output, grad_argmax): + argmax, updates = ctx.saved_tensors + + device = argmax.device + grad_updates_index0 = argmax.unsqueeze(-1) + grad_updates_index1 = torch.tile(torch.arange(0, argmax.shape[1]), argmax.shape[0:1:1]).reshape(argmax.shape).unsqueeze(-1).to(device) + grad_updates_indices = torch.concat((grad_updates_index0, grad_updates_index1), -1).to(device) + grad_updates_indices_uss = grad_updates_indices[..., 0] * grad_updates_indices.shape[1] + grad_updates_indices[..., 1] + num_segments = torch.tensor(updates.shape[0] * updates.shape[1]).to(device) + + grad = ads.npu_scatter_max_backward(grad_output, grad_updates_indices_uss, num_segments) + + return grad.reshape(updates.shape), None, None + +npuscattermax = ScatterMaxFunction.apply diff --git a/common/ops/csrc/ScatterMaxKernelNpu.cpp b/common/ops/csrc/ScatterMaxKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c4e943842d775364a5ba471deae2589945cf1f48 --- /dev/null +++ b/common/ops/csrc/ScatterMaxKernelNpu.cpp @@ -0,0 +1,59 @@ +#include + +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" + +using namespace at; +using namespace std; + +using torch::autograd::Function; +using torch::autograd::AutogradContext; +using tensor_list = std::vector; + +std::tuple npu_scatter_max( + const at::Tensor& updates, + const at::Tensor& indices, + c10::optional out) +{ + auto sizes = updates.sizes().vec(); + + sizes[0] = indices.max().item().toLong() + 1; + + at::Tensor result = out.value_or(at::zeros(sizes, updates.options().dtype(at::kFloat))); + at::Tensor argmax = at_npu::native::OpPreparation::ApplyTensor(result, result.options().dtype(at::kInt)); + + at_npu::native::OpCommand cmd; + cmd.Name("ScatterMaxWithArgmax") + .Input(result) + .Input(indices) + .Input(updates) + .Output(result) + .Output(argmax) + .Run(); + + return std::tie(result, argmax); +} + +at::Tensor npu_scatter_max_backward(const at::Tensor& x, const at::Tensor& segment_ids, const at::Tensor& num_segments) +{ + c10::SmallVector output_size; + + auto num_segments_value = num_segments.item().toLong(); + output_size.push_back(num_segments_value); + + auto x_sizes = x.sizes(); + auto segment_ids_dims = segment_ids.dim(); + + copy(x_sizes.begin() + segment_ids_dims, x_sizes.end(), std::back_inserter(output_size)); + + at::Tensor out = at_npu::native::OpPreparation::ApplyTensor(x, output_size); + at_npu::native::OpCommand cmd; + cmd.Name("UnsortedSegmentSum") + .Input(x) + .Input(segment_ids) + .Input(num_segments) + .Output(out) + .Run(); + return out; +} diff --git a/common/ops/csrc/functions.h b/common/ops/csrc/functions.h new file mode 100644 index 0000000000000000000000000000000000000000..9371de3324b6719d0c798665fa24dde51c1088f6 --- /dev/null +++ b/common/ops/csrc/functions.h @@ -0,0 +1,5 @@ +#include +#include + +std::tuple npu_scatter_max(const at::Tensor& updates, const at::Tensor& indices, c10::optional out); +at::Tensor npu_scatter_max_backward(const at::Tensor& x, const at::Tensor& segment_ids, const at::Tensor& num_segments); diff --git a/common/ops/pybind.cpp b/common/ops/pybind.cpp index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..f5ea46465c94e61586aff388207a280dbecd6103 100644 --- a/common/ops/pybind.cpp +++ b/common/ops/pybind.cpp @@ -0,0 +1,7 @@ +#include +#include "csrc/functions.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("npu_scatter_max", &npu_scatter_max); + m.def("npu_scatter_max_backward", &npu_scatter_max_backward); +} diff --git a/cpp_extension.py b/cpp_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..e0460247614cc40d59adf6b5a2887faa6d837ca7 --- /dev/null +++ b/cpp_extension.py @@ -0,0 +1,76 @@ +# Copyright (c) 2022 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + +import setuptools +import torch +import torch.utils.cpp_extension as TorchExtension + +try: + import torch_npu + PYTORCH_NPU_INSTALL_PATH = os.path.dirname(os.path.realpath(torch_npu.__file__)) +except: + PYTORCH_NPU_INSTALL_PATH = "/opt/_internal/cpython-3.7.5/lib/python3.7/site-packages/torch_npu/" + + +def NpuExtension(name, sources, *args, **kwargs): + r''' + Creates a :class:`setuptools.Extension` for C++. + + Convenience method that creates a :class:`setuptools.Extension` with the + bare minimum (but often sufficient) arguments to build a C++ extension. + + All arguments are forwarded to the :class:`setuptools.Extension` + constructor. + + Example: + >>> from setuptools import setup + >>> from torch_npu.utils.cpp_extension import NpuExtension + >>> setup( + name='extension', + ext_modules=[ + NpuExtension( + name='extension', + sources=['extension.cpp'], + extra_compile_args=['-g']), + ], + cmdclass={ + 'build_ext': BuildExtension + }) + ''' + + torch_npu_dir = PYTORCH_NPU_INSTALL_PATH + include_dirs = kwargs.get('include_dirs', []) + include_dirs.append(os.path.join(torch_npu_dir, 'include')) + include_dirs += TorchExtension.include_paths() + kwargs['include_dirs'] = include_dirs + + library_dirs = kwargs.get('library_dirs', []) + library_dirs.append(os.path.join(torch_npu_dir, 'lib')) + library_dirs += TorchExtension.library_paths() + kwargs['library_dirs'] = library_dirs + + libraries = kwargs.get('libraries', []) + libraries.append('c10') + libraries.append('torch') + libraries.append('torch_cpu') + libraries.append('torch_python') + libraries.append('torch_npu') + kwargs['libraries'] = libraries + + kwargs['language'] = 'c++' + return setuptools.Extension(name, sources, *args, **kwargs) diff --git a/setup.py b/setup.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..03a7c2400bea3d34488bd405e39375ec86b4ca9a 100644 --- a/setup.py +++ b/setup.py @@ -0,0 +1,28 @@ +import os +import glob +from setuptools import setup, find_packages +# from torch_npu.utils import cpp_extension +from torch.utils.cpp_extension import BuildExtension +import cpp_extension + +source_file = [] +source_file += glob.glob(os.path.join("./common/ops/csrc/", "*.cpp")) +source_file += glob.glob(os.path.join("./common/ops/", "*.cpp")) + +exts = [] +ext1 = cpp_extension.NpuExtension( + name="ads", + sources=source_file, +) + +exts.append(ext1) +setup( + name="ads", + version="1.0", + description='Cpp Extension Include ascend_accelerator', + keywords='ads', + ext_modules=exts, + author='Ascend Contributors', + cmdclass={"build_ext": BuildExtension}, + packages=find_packages() +) diff --git a/tests/run_test.py b/tests/run_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5fe2c9134100e29afd38fa94394e290c9e04387c --- /dev/null +++ b/tests/run_test.py @@ -0,0 +1,183 @@ +import argparse +import pathlib +import os +import sys +import signal +import math +from datetime import datetime, timezone +from typing import Optional, List + +import torch +from torch.testing._internal.common_utils import shell + +import torch_npu + +REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent + +# https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python +SIGNALS_TO_NAMES_DICT = { + getattr(signal, n): n for n in dir(signal) if + n.startswith("SIG") and "_" not in n +} + + +def print_to_stderr(message): + print(message, file=sys.stderr) + + +def discover_tests( + base_dir: Optional[pathlib.Path] = None, + blocklisted_patterns: Optional[List[str]] = None, + blocklisted_tests: Optional[List[str]] = None, + extra_tests: Optional[List[str]] = None) -> List[str]: + """ + Searches for all python files starting with test_ excluding one specified by patterns + """ + + def skip_test_p(name: str) -> bool: + rc = False + if blocklisted_patterns is not None: + rc |= any( + name.startswith(pattern) for pattern in blocklisted_patterns) + if blocklisted_tests is not None: + rc |= name in blocklisted_tests + return rc + + cwd = pathlib.Path( + __file__).resolve().parent if base_dir is None else base_dir + all_py_files = list(cwd.glob('**/test_*.py')) + rc = [str(fname.relative_to(cwd))[:-3] for fname in all_py_files] + rc = [test for test in rc if not skip_test_p(test)] + if extra_tests is not None: + rc += extra_tests + return sorted(rc) + + +def parse_test_module(test): + return pathlib.Path(test).parts[0] + + +TESTS = discover_tests( + blocklisted_patterns=[], + blocklisted_tests=[], + extra_tests=[] +) + +TESTS_MODULE = list(set([parse_test_module(test) for test in TESTS])) + +TEST_CHOICES = TESTS + TESTS_MODULE + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run the PyTorch unit test suite", + epilog="where TESTS is any of: {}".format(", ".join(TESTS)), + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "-v", + "--verbose", + action="count", + default=0, + help="print verbose information and test-by-test results", + ) + parser.add_argument( + "-i", + "--include", + nargs="+", + choices=TEST_CHOICES, + default=TESTS, + metavar="TESTS", + help="select a set of tests to include (defaults to ALL tests)." + " tests must be a part of the TESTS list defined in run_test.py", + ) + parser.add_argument( + "additional_unittest_args", + nargs="*", + help="additional arguments passed through to unittest, e.g., " + "python run_test.py -i sparse -- TestSparse.test_factory_size_check", + ) + return parser.parse_args() + + +def get_selected_tests(options): + selected_tests = [] + if options.include: + for item in options.include: + selected_tests.extend( + list(filter(lambda test_name: item == test_name \ + or ( + item in TESTS_MODULE and test_name.startswith( + item)), TESTS))) + else: + selected_tests = TESTS + return selected_tests + + +def run_test(test, test_directory, options): + unittest_args = options.additional_unittest_args.copy() + + if options.verbose: + unittest_args.append("-v") + # get python cmd. + executable = [sys.executable] + + # Can't call `python -m unittest test_*` here because it doesn't run code + # in `if __name__ == '__main__': `. So call `python test_*.py` instead. + argv = [test + ".py"] + unittest_args + + command = executable + argv + print_to_stderr( + "Executing {} ... [{}]".format(command, datetime.now(tz=timezone.utc))) + return shell(command, test_directory) + + +def run_test_module(test: str, test_directory: str, options) -> Optional[str]: + print_to_stderr( + "Running {} ... [{}]".format(test, datetime.now(tz=timezone.utc))) + + return_code = run_test(test, test_directory, options) + if not (isinstance(return_code, int) and not isinstance(return_code, bool)): + raise TypeError("Return code should be an integer") + if return_code == 0: + return None + + message = f"exec ut {test} failed!" + if return_code < 0: + # subprocess.Popen returns the child process' exit signal as + # return code -N, where N is the signal number. + signal_name = SIGNALS_TO_NAMES_DICT[-return_code] + message += f" Received signal: {signal_name}" + return message + + +def main(): + options = parse_args() + test_directory = os.path.join(REPO_ROOT, "tests") + selected_tests = get_selected_tests(options) + + if options.verbose: + print_to_stderr("Selected tests: {}".format(", ".join(selected_tests))) + + has_failed = False + failure_msgs = [] + + for test in selected_tests: + err_msg = run_test_module(test, test_directory, options) + + if err_msg is None: + continue + has_failed = True + failure_msgs.append(err_msg) + print_to_stderr(err_msg) + + if has_failed: + for err in failure_msgs: + print_to_stderr(err) + return False + return True + + +if __name__ == "__main__": + if not main(): + sys.exit(1) \ No newline at end of file diff --git a/tests/test_scatter_max.py b/tests/test_scatter_max.py new file mode 100644 index 0000000000000000000000000000000000000000..6b96f27a1e132d77e23981d9da9095cc833353d9 --- /dev/null +++ b/tests/test_scatter_max.py @@ -0,0 +1,75 @@ +import torch +import numpy as np +import torch_scatter + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor +from common import modules + + +class TestScatterMaxWithArgmax(TestCase): + + def cpu_op_exec(self, updates, indices): + updates.requires_grad = True + + output, output_argmax = torch_scatter.scatter_max(updates, indices.long(), dim=0) + output.backward(torch.ones_like(output)) + + output_grad = updates.grad + output_grad = output_grad.detach().numpy() + output = output.detach().numpy() + output_argmax = output_argmax.to(torch.int32).numpy() + + return output, output_argmax, output_grad + + def npu_op_exec(self, updates, indices): + updates.requires_grad = True + + output, output_argmax = modules.npuscattermax(updates, indices) + output.backward(torch.ones_like(output)) + + output_grad = updates.grad.cpu() + output_grad = output_grad.detach().numpy() + output = output.cpu() + output = output.detach().numpy() + output_argmax = output_argmax.cpu().numpy() + + return output, output_argmax, output_grad + + def test_scatter_max_with_argmax_1(self): + shape_updates = (262144, 16) + shape_indices = (262144, 1) + cpu_updates_input, npu_updates_input = create_common_tensor(["float32", 2, shape_updates], 0, 262144) + cpu_indices_input, npu_indices_input = create_common_tensor(["int32", 2, shape_indices], 0, 262144) + cpu_output = self.cpu_op_exec(cpu_updates_input, cpu_indices_input) + npu_output = self.npu_op_exec(npu_updates_input, npu_indices_input) + self.assertRtolEqual(cpu_output[0], npu_output[0]) + self.assertRtolEqual(cpu_output[1], npu_output[1]) + self.assertRtolEqual(cpu_output[2], npu_output[2]) + + def test_scatter_max_with_argmax_2(self): + shape_updates = (78848, 16) + shape_indices = (78848, 1) + cpu_updates_input, npu_updates_input = create_common_tensor(["float32", 2, shape_updates], 0, 78848) + cpu_indices_input, npu_indices_input = create_common_tensor(["int32", 2, shape_indices], 0, 78848) + cpu_output = self.cpu_op_exec(cpu_updates_input, cpu_indices_input) + npu_output = self.npu_op_exec(npu_updates_input, npu_indices_input) + self.assertRtolEqual(cpu_output[0], npu_output[0]) + self.assertRtolEqual(cpu_output[1], npu_output[1]) + self.assertRtolEqual(cpu_output[2], npu_output[2]) + + def test_scatter_max_with_argmax_3(self): + shape_updates = (1024, 16) + shape_indices = (1024, 1) + cpu_updates_input, npu_updates_input = create_common_tensor(["float32", 2, shape_updates], 0, 100) + cpu_indices_input, npu_indices_input = create_common_tensor(["int32", 2, shape_indices], 0, 100) + cpu_output = self.cpu_op_exec(cpu_updates_input, cpu_indices_input) + npu_output = self.npu_op_exec(npu_updates_input, npu_indices_input) + self.assertRtolEqual(cpu_output[0], npu_output[0]) + self.assertRtolEqual(cpu_output[1], npu_output[1]) + self.assertRtolEqual(cpu_output[2], npu_output[2]) + + +if __name__ == "__main__": + run_tests()