From 14cb3c82f5c4cdba49ae7da564f47260d7e45d98 Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Fri, 25 Jul 2025 09:45:36 +0800 Subject: [PATCH 1/3] reshape_and_cache asd --- .../ms_kernels_internal/internal_helper.cc | 11 - .../ms_kernels_internal/internal_helper.h | 4 - .../pyboost/internal_pyboost_runner.cc | 13 +- .../pyboost/internal_pyboost_runner.h | 3 +- .../ms_kernels_internal/reshape_and_cache.cc | 63 ++- tests/st/test_custom_reshape_and_cache.py | 482 ++++++++++++++---- 6 files changed, 445 insertions(+), 131 deletions(-) diff --git a/ccsrc/base/ms_kernels_internal/internal_helper.cc b/ccsrc/base/ms_kernels_internal/internal_helper.cc index a37baf4..c4ca74a 100644 --- a/ccsrc/base/ms_kernels_internal/internal_helper.cc +++ b/ccsrc/base/ms_kernels_internal/internal_helper.cc @@ -30,17 +30,6 @@ #include namespace ms_custom_ops { -std::string TransInternalOpName(const std::string &ms_op_name) { - auto internal_name = - InternalNameMapper::GetInstance().GetInternalName(ms_op_name); - if (internal_name.empty()) { - MS_LOG(EXCEPTION) - << "Op " << ms_op_name - << " is supported in Internal, but the name is not mapped"; - } - return internal_name; -} - InternalNameMapper &InternalNameMapper::GetInstance() { static InternalNameMapper name_mammer; return name_mammer; diff --git a/ccsrc/base/ms_kernels_internal/internal_helper.h b/ccsrc/base/ms_kernels_internal/internal_helper.h index 0f89ef2..cb173fc 100644 --- a/ccsrc/base/ms_kernels_internal/internal_helper.h +++ b/ccsrc/base/ms_kernels_internal/internal_helper.h @@ -36,8 +36,6 @@ inline internal::ShapeInfo TransInternalShape(const ShapeVector &shape) { return internal_shape; } -std::string TransInternalOpName(const std::string &ms_op_name); - bool CheckDefaultSupportFormat(const std::string &format); internal::DataType TransInternalDataType(TypeId ms_type); @@ -73,8 +71,6 @@ class InternalNameRegistrar { public: InternalNameRegistrar(const std::string &ms_name, const std::string &internal_name) { - std::cout << "InternalNameRegistrar : ms_name: " << ms_name - << ", internal_name: " << internal_name << std::endl; InternalNameMapper::GetInstance().Insert(ms_name, internal_name); } ~InternalNameRegistrar() = default; diff --git a/ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.cc b/ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.cc index 4dff6c8..c2ada7c 100644 --- a/ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.cc +++ b/ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.cc @@ -27,10 +27,7 @@ void InternalPyboostRunner::GetOrCreateKernel(const TensorList &inputs, MS_LOG(DEBUG) << "Internal Op [" << this->op_name() << "] hit cache"; } else { MS_LOG(DEBUG) << "Internal Op [" << this->op_name() << "] miss cache"; - if (!IsInternalDtypeSupport(inputs, outputs)) { - MS_LOG(EXCEPTION) << "Input dtype is not supported for internal op [" - << this->op_name() << "]"; - } + TransDataType(inputs, outputs); UpdateArgImmutableInfo(&inputs_ii_, inputs, true); UpdateArgImmutableInfo(&outputs_ii_, outputs); internal_op_ = CreateKernel(inputs_ii_, outputs_ii_); @@ -70,8 +67,8 @@ size_t InternalPyboostRunner::CalcWorkspace() { 0); } -bool InternalPyboostRunner::IsInternalDtypeSupport( - const TensorList &ms_inputs, const TensorList &ms_outputs) { +void InternalPyboostRunner::TransDataType(const TensorList &ms_inputs, + const TensorList &ms_outputs) { internal_inputs_dtype_.resize(ms_inputs.size()); internal_outputs_dtype_.resize(ms_outputs.size()); @@ -92,10 +89,6 @@ bool 
InternalPyboostRunner::IsInternalDtypeSupport( internal_outputs_dtype_[i] = TransInternalDataType(ms_outputs[i].data_type()); } - - return mindspore::internal::IsInternalKernelDtypesSupported( - TransInternalOpName(this->op_name()), internal_inputs_dtype_, - internal_outputs_dtype_); } TilingCacheItemPtr InternalPyboostRunner::GetOrGenerateTiling() { diff --git a/ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.h b/ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.h index a56dfc5..1c1fa6a 100644 --- a/ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.h +++ b/ccsrc/base/ms_kernels_internal/pyboost/internal_pyboost_runner.h @@ -64,8 +64,7 @@ protected: virtual bool UpdateParam() { return true; } protected: - bool IsInternalDtypeSupport(const TensorList &ms_inputs, - const TensorList &ms_outputs); + void TransDataType(const TensorList &ms_inputs, const TensorList &ms_outputs); TilingCacheItemPtr GetOrGenerateTiling(); virtual internal::InternalOpPtr diff --git a/ccsrc/ops/ms_kernels_internal/reshape_and_cache.cc b/ccsrc/ops/ms_kernels_internal/reshape_and_cache.cc index 764fddd..e9d1226 100644 --- a/ccsrc/ops/ms_kernels_internal/reshape_and_cache.cc +++ b/ccsrc/ops/ms_kernels_internal/reshape_and_cache.cc @@ -62,7 +62,7 @@ constexpr size_t kInputHeadNumIndex = 5; constexpr size_t kOutputIndex = 0; class CustomReshapeAndCache : public InternalKernelMod { public: - CustomReshapeAndCache() : InternalKernelMod() {} + CustomReshapeAndCache() : InternalKernelMod(), skip_execution_(false) {} ~CustomReshapeAndCache() = default; void InitKernelInputsOutputsIndex() override { @@ -71,15 +71,64 @@ public: kernel_outputs_index_ = {kOutputIndex}; } + int Resize(const std::vector &inputs, + const std::vector &outputs) override { + // Check if any input has shape containing 0 + for (const auto &input : inputs) { + if (input == nullptr) + continue; + auto shape = input->GetShapeVector(); + for (const auto &dim : shape) { + if (dim == 0) { + MS_LOG(INFO) << "ReshapeAndCache: Skipping execution due to zero " + "dimension in input shape: " + << shape; + skip_execution_ = true; + return KernelMod::Resize(inputs, outputs); // Skip execution + } + } + } + + skip_execution_ = false; + // Call base class implementation + return InternalKernelMod::Resize(inputs, outputs); + } + + bool Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, + void *stream_ptr) override { + // Skip execution if flag is set + if (skip_execution_) { + return true; // Skip execution, return success + } + + // Call base class implementation + return InternalKernelMod::Launch(inputs, workspace, outputs, stream_ptr); + } + protected: internal::InternalOpPtr CreateKernel(const internal::InputsImmutableInfoList &inputs, const internal::OutputsImmutableInfoList &outputs, const std::vector &ms_inputs, const std::vector &ms_outputs) override { - return internal::CreateReshapeAndCacheOp( - inputs, outputs, internal::kInternalReshapeAndCacheOpName); + internal::ReshapeAndCacheParam param; + auto head_num = ms_inputs.at(internal::kIndex5); + if (head_num->dtype_id() == TypeId::kNumberTypeInt64) { + param.head_num = + static_cast(head_num->GetValue().value()); + } else { + MS_LOG(EXCEPTION) + << "ReshapeAndCache [head_num]'s dtype wrong, expect int64, but got: " + << head_num->dtype_id(); + } + return internal::CreateAsdReshapeAndCacheOp( + inputs, outputs, param, internal::kInternalAsdReshapeAndCacheOpName); } + +private: + bool skip_execution_; // Flag to skip execution 
when shape contains 0 }; } // namespace ms_custom_ops @@ -104,15 +153,17 @@ protected: internal::InternalOpPtr CreateKernel(const internal::InputsImmutableInfoList &inputs, const internal::OutputsImmutableInfoList &outputs) override { - return internal::CreateReshapeAndCacheOp( - inputs, outputs, internal::kInternalReshapeAndCacheOpName); + internal::ReshapeAndCacheParam param; + param.head_num = this->head_num_; + return internal::CreateAsdReshapeAndCacheOp( + inputs, outputs, param, internal::kInternalAsdReshapeAndCacheOpName); } private: int32_t head_num_{0}; }; MS_KERNELS_INTERNAL_NAME_REG(ReshapeAndCache, - internal::kInternalReshapeAndCacheOpName); + internal::kInternalAsdReshapeAndCacheOpName); } // namespace ms::pynative namespace ms_custom_ops { diff --git a/tests/st/test_custom_reshape_and_cache.py b/tests/st/test_custom_reshape_and_cache.py index 5ca5307..2598d79 100644 --- a/tests/st/test_custom_reshape_and_cache.py +++ b/tests/st/test_custom_reshape_and_cache.py @@ -14,98 +14,154 @@ # ============================================================================ """ tests_custom_pyboost_ascend """ +# Standard library imports +from enum import Enum + +# Third-party imports import numpy as np -import mindspore as ms -from mindspore.ops import CustomOpBuilder, ModuleWrapper -from mindspore import Tensor, context, Parameter, ops import pytest + +# MindSpore imports +import mindspore as ms +from mindspore import Tensor, context, Parameter, ops, nn +from mindspore.common.api import jit +from mindspore.common.np_dtype import bfloat16 + +# Local imports import ms_custom_ops -@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) -@pytest.mark.parametrize('np_dtype', [np.float16]) -@pytest.mark.parametrize('kv_dim', [3]) -def test_custom_reshape_and_cache(exec_mode, np_dtype, kv_dim): - ms.set_device("Ascend") - ms.set_context(mode=exec_mode) - - class MyNet(ms.nn.Cell): - def __init__(self): - super().__init__() - - def construct(self, key, value, key_cache, value_cache, slot_mapping, head_num): - return ms_custom_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, head_num) - - num_slots = 50 - slot_size = 128 - b = 13 - s = 32 - n = 40 - d = 128 - - def create_nd_inputs(dtype=np.float16, kv_dim=3): - cache_shape = (num_slots, slot_size, n, d) - if kv_dim == 2: - update_shape = (b * s, n * d) - num_tokens = update_shape[0] - elif kv_dim == 3: - update_shape = (b, s, n * d) - num_tokens = update_shape[0] * update_shape[1] - else: - raise Exception( - "Key's dim should be 2 or 3, but got {0}".format(kv_dim)) - - if dtype == np.int8: - key_update = np.random.randint(low=-128, high=127, - size=update_shape, - dtype=np.int8) - value_update = np.random.randint(low=-128, high=127, - size=update_shape, - dtype=np.int8) - key_cache = np.random.randint(low=-128, high=127, +num_slots = 20 +slot_size = 64 +b = 13 +s = 3 +n = 16 +d = 32 + + +class ReshapeAndCacheAllNz(nn.Cell): + def __init__(self): + super().__init__() + + @jit + def construct(self, key, value, key_cache, value_cache, slot_map, head_num=0): + out = ms_custom_ops.reshape_and_cache( + key, value, key_cache, value_cache, slot_map, head_num) + return out + + +class ReshapeAndCacheKey(nn.Cell): + def __init__(self): + super().__init__() + + def construct(self, key, key_cache, slot_map): + out = ms_custom_ops.reshape_and_cache( + key, key_cache=key_cache, slot_mapping=slot_map) + return out + + +class ReshapeAndCacheAllNd(nn.Cell): + def __init__(self): + super().__init__() + + def 
construct(self, key, value=None, key_cache=None, value_cache=None, slot_map=None): + out = ms_custom_ops.reshape_and_cache( + key, value, key_cache, value_cache, slot_map) + return out + + +def create_ms_inputs(np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="", exec_mode=context.GRAPH_MODE): + """ + create inputs + """ + ms_key = Tensor(np_k) + ms_value = Tensor(np_v) + if exec_mode == context.GRAPH_MODE: + ms_key_cache = Parameter(Tensor(np_k_cache), storage_format=format, name="key_cache") + ms_value_cache = Parameter(Tensor(np_v_cache), storage_format=format, name="value_cache") + else: + ms_key_cache = Tensor(np_k_cache) + ms_value_cache = Tensor(np_v_cache) + ms_slot_map = Tensor(np_slot_map) + return ms_key, ms_value, ms_key_cache, ms_value_cache, ms_slot_map + + +# =============================== +# test nd format +# =============================== +def create_nd_inputs(dtype=np.float16, kv_dim=3): + """ + create_nd_inputs + """ + cache_shape = (num_slots, slot_size, n, d) + if kv_dim == 2: + update_shape = (b * s, n * d) + num_tokens = update_shape[0] + elif kv_dim == 3: + update_shape = (b, s, n * d) + num_tokens = update_shape[0] * update_shape[1] + else: + raise Exception( + "Key's dim should be 2 or 3, but got {0}".format(kv_dim)) + + if dtype == np.int8: + key_update = np.random.randint(low=-128, high=127, + size=update_shape, + dtype=np.int8) + value_update = np.random.randint(low=-128, high=127, + size=update_shape, + dtype=np.int8) + key_cache = np.random.randint(low=-128, high=127, + size=cache_shape, + dtype=np.int8) + value_cache = np.random.randint(low=-128, high=127, size=cache_shape, dtype=np.int8) - value_cache = np.random.randint(low=-128, high=127, - size=cache_shape, - dtype=np.int8) - else: - key_update = np.random.rand(*update_shape).astype(dtype) - value_update = np.random.rand(*update_shape).astype(dtype) - key_cache = np.random.rand(*cache_shape).astype(dtype) - value_cache = np.random.rand(*cache_shape).astype(dtype) - - slot_map = np.random.choice(np.arange(num_tokens), num_tokens, - replace=False).astype(np.int32) - - return key_update, value_update, key_cache, value_cache, slot_map - - def nd_inference(key, value, key_cache, value_cache, slot_map): - key_tmp = key.copy() - value_tmp = value.copy() - key_cache_ans = key_cache.copy() - value_cache_ans = value_cache.copy() - head = key_cache.shape[2] - head_dim = key_cache.shape[3] - key_tmp = key_tmp.reshape(-1, head, head_dim) - value_tmp = value_tmp.reshape(-1, head, head_dim) - for i, slot in enumerate(slot_map): - slot_idx = slot // key_cache.shape[1] - slot_offset = slot % key_cache.shape[1] - key_cache_ans[slot_idx][slot_offset] = key_tmp[i] - value_cache_ans[slot_idx][slot_offset] = value_tmp[i] - return key_cache_ans, value_cache_ans - - def create_ms_inputs(np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="", exec_mode=context.GRAPH_MODE): - ms_key = Tensor(np_k) - ms_value = Tensor(np_v) - if exec_mode == context.GRAPH_MODE: - ms_key_cache = Parameter(Tensor(np_k_cache), storage_format=format, name="key_cache") - ms_value_cache = Parameter(Tensor(np_v_cache), storage_format=format, name="value_cache") - else: - ms_key_cache = Tensor(np_k_cache) - ms_value_cache = Tensor(np_v_cache) - ms_slot_map = Tensor(np_slot_map) - return ms_key, ms_value, ms_key_cache, ms_value_cache, ms_slot_map - + else: + key_update = np.random.rand(*update_shape).astype(dtype) + value_update = np.random.rand(*update_shape).astype(dtype) + key_cache = np.random.rand(*cache_shape).astype(dtype) + value_cache 
= np.random.rand(*cache_shape).astype(dtype) + + slot_map = np.random.choice(np.arange(num_tokens), num_tokens, + replace=False).astype(np.int32) + + return key_update, value_update, key_cache, value_cache, slot_map + + +def nd_inference(key, value, key_cache, value_cache, slot_map): + """ + nd_inference + """ + key_tmp = key.copy() + value_tmp = value.copy() + key_cache_ans = key_cache.copy() + value_cache_ans = value_cache.copy() + head = key_cache.shape[2] + head_dim = key_cache.shape[3] + key_tmp = key_tmp.reshape(-1, head, head_dim) + value_tmp = value_tmp.reshape(-1, head, head_dim) + for i, slot in enumerate(slot_map): + slot_idx = slot // key_cache.shape[1] + slot_offset = slot % key_cache.shape[1] + key_cache_ans[slot_idx][slot_offset] = key_tmp[i] + value_cache_ans[slot_idx][slot_offset] = value_tmp[i] + return key_cache_ans, value_cache_ans + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize('np_dtype', [np.float16, np.int8, bfloat16]) +@pytest.mark.parametrize('kv_dim', [2, 3]) +def test_reshape_and_cache_nd_key_value(np_dtype, kv_dim): + """ + Feature: Test ReshapeAndCache. + Description: Test float16 inputs. + Expectation: Assert that results are consistent with numpy. + """ + context.set_context(device_target="Ascend", mode=context.GRAPH_MODE) + net = ReshapeAndCacheAllNz() + np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nd_inputs( np_dtype, kv_dim) np_k_cache_out, np_v_cache_out = nd_inference( @@ -113,18 +169,248 @@ def test_custom_reshape_and_cache(exec_mode, np_dtype, kv_dim): ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( np_k, np_v, np_k_cache, np_v_cache, np_slot_map) - _ = MyNet()(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, n) + _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map) + + if np_dtype == bfloat16: + assert np.allclose(ms_k_cache.float().asnumpy(), + np_k_cache_out.astype(np.float32), 0.001, 0.001) + assert np.allclose(ms_v_cache.float().asnumpy(), + np_v_cache_out.astype(np.float32), 0.001, 0.001) + else: + assert np.allclose(ms_k_cache.asnumpy(), np_k_cache_out, 0.001, 0.001) + assert np.allclose(ms_v_cache.asnumpy(), np_v_cache_out, 0.001, 0.001) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize('np_dtype', [np.float16, np.int8, bfloat16]) +@pytest.mark.parametrize('kv_dim', [2, 3]) +def test_reshape_and_cache_nd_key(np_dtype, kv_dim): + """ + Feature: Test ReshapeAndCache. + Description: Test float16 inputs. + Expectation: Assert that results are consistent with numpy. 
+ """ + context.set_context(device_target="Ascend", mode=context.GRAPH_MODE) + context.set_context(jit_config={"jit_level": "O0"}) + net = ReshapeAndCacheKey() + + np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nd_inputs( + np_dtype, kv_dim) + np_k_cache_out, np_v_cache_out = nd_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + _ = net(ms_k, key_cache=ms_k_cache, slot_map=ms_slot_map) + + if np_dtype == bfloat16: + assert np.allclose(ms_k_cache.float().asnumpy(), + np_k_cache_out.astype(np.float32), 0.001, 0.001) + else: + assert np.allclose(ms_k_cache.asnumpy(), np_k_cache_out, 0.001, 0.001) + + +# @pytest.mark.level0 +# @pytest.mark.platform_ascend910b +# @pytest.mark.env_onecard +# @pytest.mark.parametrize('np_dtype', [np.float16, np.int8, bfloat16]) +# @pytest.mark.parametrize('kv_dim', [2, 3]) +# def test_reshape_and_cache_expect_fail(np_dtype, kv_dim): +# """ +# Feature: Test ReshapeAndCache. +# Description: Test float16 inputs. +# Expectation: Assert that results are consistent with numpy. +# """ +# context.set_context(device_target="Ascend", mode=context.GRAPH_MODE) +# context.set_context(jit_config={"jit_level": "O0"}) +# net = ReshapeAndCacheAllNd() + +# np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nd_inputs( +# np_dtype, kv_dim) +# np_k_cache_out, np_v_cache_out = nd_inference( +# np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + +# ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( +# np_k, np_v, np_k_cache, np_v_cache, np_slot_map) +# with pytest.raises(Exception): +# _ = net(ms_k, ms_v, key_cache=ms_k_cache, slot_map=ms_slot_map) +# with pytest.raises(Exception): +# _ = net(ms_k, key_cache=ms_k_cache, +# value_cache=ms_v_cache, slot_map=ms_slot_map) + + +# =============================== +# test nz format +# =============================== +def create_nz_inputs(k_dtype=np.float16, v_dtype=np.float16, kv_dim=3): + """ + create_nz_inputs + """ + k_cache_shape = (num_slots, slot_size, n * d) + v_cache_shape = (num_slots, slot_size, n * d) + + if kv_dim == 2: + update_shape = (b * s, n * d) + num_tokens = update_shape[0] + elif kv_dim == 3: + update_shape = (b, s, n * d) + num_tokens = update_shape[0] * update_shape[1] + else: + raise Exception( + "Key's dim should be 2 or 3, but got {0}".format(kv_dim)) + + if k_dtype == np.int8: + key_update = np.random.randint(low=-128, high=127, + size=update_shape, + dtype=np.int8) + # key_cache = np.random.randint(low=-128, high=127, + # size=k_cache_shape, + # dtype=np.int8) + else: + key_update = np.random.rand(*update_shape).astype(k_dtype) + # key_cache = np.random.rand(*k_cache_shape).astype(k_dtype) + key_cache = np.zeros(k_cache_shape, dtype=k_dtype) + + if v_dtype == np.int8: + value_update = np.random.randint(low=-128, high=127, + size=update_shape, + dtype=np.int8) + # value_cache = np.random.randint(low=-128, high=127, + # size=v_cache_shape, + # dtype=np.int8) + else: + value_update = np.random.rand(*update_shape).astype(v_dtype) + # value_cache = np.random.rand(*v_cache_shape).astype(v_dtype) + value_cache = np.zeros(v_cache_shape, dtype=v_dtype) + + slot_map = np.random.choice(np.arange(num_tokens), num_tokens, + replace=False).astype(np.int32) + + return key_update, value_update, key_cache, value_cache, slot_map + + +def nz_inference(key, value, key_cache, value_cache, slot_map): + """ + nz_inference + """ + key_tmp = key.copy() + value_tmp = 
value.copy() + key_cache_ans = key_cache.copy() + value_cache_ans = value_cache.copy() + key_tmp = key_tmp.reshape(-1, key_cache.shape[2]) + value_tmp = value_tmp.reshape(-1, key_cache.shape[2]) + for i, slot in enumerate(slot_map): + slot_idx = slot // key_cache.shape[1] + slot_offset = slot % key_cache.shape[1] + key_cache_ans[slot_idx][slot_offset] = key_tmp[i] + value_cache_ans[slot_idx][slot_offset] = value_tmp[i] + return key_cache_ans, value_cache_ans + + +def get_nz_cached_slots(cache, slot_map): + ans = [] + tmp = [] + + print(f"=========cache shape: {cache.shape}") + + num_slots = cache.shape[0] + slot_size = cache.shape[1] + hidden_size = cache.shape[2] + + if cache.dtype == np.int8: + cache_shape = (num_slots, hidden_size // 32, slot_size, 32) + else: + cache_shape = (num_slots, hidden_size // 16, slot_size, 16) + cache = cache.reshape(cache_shape) + for i, slot in enumerate(slot_map): + if slot < 0: + continue + slot_idx = slot // slot_size + slot_offset = slot % slot_size + for j in range(cache.shape[1]): + tmp.append(cache[slot_idx][j][slot_offset]) + ans.append(np.concatenate(tmp, axis=0)) + ans = np.concatenate(ans) + return ans + + +def get_nd_cached_slots(cache, slot_map): + ans = [] + for slot in slot_map: + if slot < 0: + continue + slot_idx = slot // slot_size + slot_offset = slot % slot_size + ans.append(cache[slot_idx][slot_offset]) + ans = np.concatenate(ans) + return ans + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize('kv_dim', [2, 3]) +@pytest.mark.parametrize('k_dtype', [np.float16, bfloat16, np.int8]) +@pytest.mark.parametrize('v_dtype', [np.float16, bfloat16]) +@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +# @pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE]) +def test_reshape_and_cache_nz(k_dtype, v_dtype, kv_dim, exec_mode): + """ + Feature: Test ReshapeAndCache. + Description: Test float16 inputs. + Expectation: Assert that results are consistent with numpy. + """ + if (k_dtype == np.float16 and v_dtype != np.float16) or \ + (k_dtype == bfloat16 and v_dtype != bfloat16): + pytest.skip(f"Invalid combo: {k_dtype} -> {v_dtype}") + + # todo check why need infer_boost + context.set_context(mode=exec_mode, device_target="Ascend", + jit_config={"jit_level": "O0", "infer_boost": "on"}, + save_graphs=False, save_graphs_path="reshape_and_cache_graph_wrong") + net = ReshapeAndCacheAllNz() + + np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nz_inputs( + k_dtype, v_dtype, kv_dim) + np_k_cache_out, np_v_cache_out = nz_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + # FRACTAL_NZ + if exec_mode == context.GRAPH_MODE: + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="FRACTAL_NZ") + else: + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="") + acl_format = 29 + ms_k_cache = ops.auto_generate.format_cast(ms_k_cache, acl_format) + ms_v_cache = ops.auto_generate.format_cast(ms_v_cache, acl_format) + + _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, head_num=n) + + # attention: ms asnumpy() will cast nz to nd! 
+ if k_dtype == bfloat16: + ms_k_cache_np = ms_k_cache.asnumpy().astype(np.float32) + np_k_cache_out = np_k_cache_out.astype(np.float32) + else: + ms_k_cache_np = ms_k_cache.asnumpy() + + if v_dtype == bfloat16: + ms_v_cache_np = ms_v_cache.asnumpy().astype(np.float32) + np_v_cache_out = np_v_cache_out.astype(np.float32) + else: + ms_v_cache_np = ms_v_cache.asnumpy() + + ms_k_output = get_nd_cached_slots(ms_k_cache_np, np_slot_map) + golden_k_output = get_nd_cached_slots(np_k_cache_out, np_slot_map) - # if np_dtype == bfloat16: - # assert np.allclose(ms_k_cache.float().asnumpy(), - # np_k_cache_out.astype(np.float32), 0.001, 0.001) - # assert np.allclose(ms_v_cache.float().asnumpy(), - # np_v_cache_out.astype(np.float32), 0.001, 0.001) - # else: - # assert np.allclose(ms_k_cache.asnumpy(), np_k_cache_out, 0.001, 0.001) - # assert np.allclose(ms_v_cache.asnumpy(), np_v_cache_out, 0.001, 0.001) + print(f"ms out: {ms_k_output}") + print(f"np out: {golden_k_output}") - assert np.allclose(ms_k_cache.asnumpy(), np_k_cache_out, 0.001, 0.001) - assert np.allclose(ms_v_cache.asnumpy(), np_v_cache_out, 0.001, 0.001) + ms_v_output = get_nd_cached_slots(ms_v_cache_np, np_slot_map) + golden_v_output = get_nd_cached_slots(np_v_cache_out, np_slot_map) -# test_custom_reshape_and_cache(context.PYNATIVE_MODE, np.float16, 3) + assert np.allclose(ms_k_output, golden_k_output, 0.001, 0.001) + assert np.allclose(ms_v_output, golden_v_output, 0.001, 0.001) -- Gitee From 9c8f34c7598d12e9190f9adc0987fff659dddafe Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Tue, 29 Jul 2025 11:30:37 +0800 Subject: [PATCH 2/3] refactor test --- tests/st/test_custom_reshape_and_cache.py | 446 ++++++++++++---------- 1 file changed, 252 insertions(+), 194 deletions(-) diff --git a/tests/st/test_custom_reshape_and_cache.py b/tests/st/test_custom_reshape_and_cache.py index 2598d79..0a15325 100644 --- a/tests/st/test_custom_reshape_and_cache.py +++ b/tests/st/test_custom_reshape_and_cache.py @@ -16,6 +16,7 @@ # Standard library imports from enum import Enum +from typing import Tuple, Optional, Dict, Any # Third-party imports import numpy as np @@ -30,122 +31,223 @@ from mindspore.common.np_dtype import bfloat16 # Local imports import ms_custom_ops -num_slots = 20 -slot_size = 64 -b = 13 -s = 3 -n = 16 -d = 32 +# Global constants +NUM_SLOTS = 20 +SLOT_SIZE = 64 +BATCH_SIZE = 13 +SEQ_LEN = 3 +NUM_HEADS = 16 +HEAD_DIM = 32 -class ReshapeAndCacheAllNz(nn.Cell): +class CacheFormat(Enum): + """Cache format enumeration""" + ND = "nd" + NZ = "nz" + + +class DataType(Enum): + """Data type enumeration""" + FLOAT16 = np.float16 + BFLOAT16 = bfloat16 + INT8 = np.int8 + + +class ReshapeAndCacheAll(nn.Cell): + """Reshape and cache operation for NZ/ND format with all parameters""" + def __init__(self): super().__init__() @jit def construct(self, key, value, key_cache, value_cache, slot_map, head_num=0): - out = ms_custom_ops.reshape_and_cache( + return ms_custom_ops.reshape_and_cache( key, value, key_cache, value_cache, slot_map, head_num) - return out class ReshapeAndCacheKey(nn.Cell): + """Reshape and cache operation for NZ/ND format with key only""" + def __init__(self): super().__init__() def construct(self, key, key_cache, slot_map): - out = ms_custom_ops.reshape_and_cache( + return ms_custom_ops.reshape_and_cache( key, key_cache=key_cache, slot_mapping=slot_map) - return out -class ReshapeAndCacheAllNd(nn.Cell): - def __init__(self): - super().__init__() - - def construct(self, key, value=None, key_cache=None, value_cache=None, slot_map=None): 
- out = ms_custom_ops.reshape_and_cache( - key, value, key_cache, value_cache, slot_map) - return out +class MindSporeInputFactory: + """Factory for creating MindSpore inputs""" + + @staticmethod + def create_inputs(np_k: np.ndarray, np_v: np.ndarray, + np_k_cache: np.ndarray, np_v_cache: np.ndarray, + np_slot_map: np.ndarray, format: str = "", + exec_mode: context = context.GRAPH_MODE) -> Tuple[Tensor, ...]: + """Create MindSpore inputs""" + ms_key = Tensor(np_k) + ms_value = Tensor(np_v) + + if exec_mode == context.GRAPH_MODE: + ms_key_cache = Parameter(Tensor(np_k_cache), storage_format=format, name="key_cache") + ms_value_cache = Parameter(Tensor(np_v_cache), storage_format=format, name="value_cache") + else: + ms_key_cache = Tensor(np_k_cache) + ms_value_cache = Tensor(np_v_cache) + + ms_slot_map = Tensor(np_slot_map) + return ms_key, ms_value, ms_key_cache, ms_value_cache, ms_slot_map def create_ms_inputs(np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="", exec_mode=context.GRAPH_MODE): - """ - create inputs - """ - ms_key = Tensor(np_k) - ms_value = Tensor(np_v) - if exec_mode == context.GRAPH_MODE: - ms_key_cache = Parameter(Tensor(np_k_cache), storage_format=format, name="key_cache") - ms_value_cache = Parameter(Tensor(np_v_cache), storage_format=format, name="value_cache") - else: - ms_key_cache = Tensor(np_k_cache) - ms_value_cache = Tensor(np_v_cache) - ms_slot_map = Tensor(np_slot_map) - return ms_key, ms_value, ms_key_cache, ms_value_cache, ms_slot_map + """Legacy function for backward compatibility""" + return MindSporeInputFactory.create_inputs(np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format, exec_mode) + + +class TestResultVerifier: + """Verify test results""" + + @staticmethod + def verify_results(ms_cache: Tensor, np_cache: np.ndarray, + dtype: np.dtype, rtol: float = 0.001, atol: float = 0.001) -> None: + """Verify results with appropriate dtype handling""" + if dtype == bfloat16: + ms_cache_np = ms_cache.float().asnumpy() + np_cache = np_cache.astype(np.float32) + else: + ms_cache_np = ms_cache.asnumpy() + + assert np.allclose(ms_cache_np, np_cache, rtol=rtol, atol=atol) + + +class TestConfig: + """Test configuration""" + + def __init__(self, device_target: str = "Ascend", mode: context = context.GRAPH_MODE, + jit_config: Optional[Dict[str, Any]] = None): + self.device_target = device_target + self.mode = mode + self.jit_config = jit_config or {} + + def apply(self): + """Apply test configuration""" + context.set_context(device_target=self.device_target, mode=self.mode) + if self.jit_config: + context.set_context(jit_config=self.jit_config) # =============================== # test nd format # =============================== -def create_nd_inputs(dtype=np.float16, kv_dim=3): - """ - create_nd_inputs - """ - cache_shape = (num_slots, slot_size, n, d) - if kv_dim == 2: - update_shape = (b * s, n * d) - num_tokens = update_shape[0] - elif kv_dim == 3: - update_shape = (b, s, n * d) - num_tokens = update_shape[0] * update_shape[1] - else: - raise Exception( - "Key's dim should be 2 or 3, but got {0}".format(kv_dim)) - - if dtype == np.int8: - key_update = np.random.randint(low=-128, high=127, - size=update_shape, - dtype=np.int8) - value_update = np.random.randint(low=-128, high=127, - size=update_shape, - dtype=np.int8) - key_cache = np.random.randint(low=-128, high=127, - size=cache_shape, - dtype=np.int8) - value_cache = np.random.randint(low=-128, high=127, - size=cache_shape, - dtype=np.int8) - else: - key_update = 
np.random.rand(*update_shape).astype(dtype) - value_update = np.random.rand(*update_shape).astype(dtype) - key_cache = np.random.rand(*cache_shape).astype(dtype) - value_cache = np.random.rand(*cache_shape).astype(dtype) +class TestDataGenerator: + """Data generator for test inputs""" + + @staticmethod + def create_random_data(shape: Tuple[int, ...], dtype: np.dtype) -> np.ndarray: + """Create random data with specified shape and dtype""" + if dtype == np.int8: + return np.random.randint(low=-128, high=127, size=shape, dtype=np.int8) + else: + return np.random.rand(*shape).astype(dtype) + + @staticmethod + def create_slot_map(num_tokens: int) -> np.ndarray: + """Create slot mapping""" + return np.random.choice(np.arange(num_tokens), num_tokens, replace=False).astype(np.int32) + + @staticmethod + def get_update_shape(kv_dim: int) -> Tuple[Tuple[int, ...], int]: + """Get update shape and number of tokens based on dimension""" + if kv_dim == 2: + update_shape = (BATCH_SIZE * SEQ_LEN, NUM_HEADS * HEAD_DIM) + num_tokens = update_shape[0] + elif kv_dim == 3: + update_shape = (BATCH_SIZE, SEQ_LEN, NUM_HEADS * HEAD_DIM) + num_tokens = update_shape[0] * update_shape[1] + else: + raise ValueError(f"Key's dim should be 2 or 3, but got {kv_dim}") + return update_shape, num_tokens + + +class NDDataGenerator(TestDataGenerator): + """Data generator for ND format""" + + @staticmethod + def create_inputs(dtype: np.dtype, kv_dim: int) -> Tuple[np.ndarray, ...]: + """Create ND format inputs""" + cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS, HEAD_DIM) + update_shape, num_tokens = TestDataGenerator.get_update_shape(kv_dim) + + key_update = TestDataGenerator.create_random_data(update_shape, dtype) + value_update = TestDataGenerator.create_random_data(update_shape, dtype) + key_cache = TestDataGenerator.create_random_data(cache_shape, dtype) + value_cache = TestDataGenerator.create_random_data(cache_shape, dtype) + slot_map = TestDataGenerator.create_slot_map(num_tokens) + + return key_update, value_update, key_cache, value_cache, slot_map - slot_map = np.random.choice(np.arange(num_tokens), num_tokens, - replace=False).astype(np.int32) - return key_update, value_update, key_cache, value_cache, slot_map +def create_nd_inputs(dtype=np.float16, kv_dim=3): + """Legacy function for backward compatibility""" + return NDDataGenerator.create_inputs(dtype, kv_dim) + + +class InferenceEngine: + """Inference engine for different formats""" + + @staticmethod + def nd_inference(key: np.ndarray, value: np.ndarray, + key_cache: np.ndarray, value_cache: np.ndarray, + slot_map: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ND format inference""" + key_tmp = key.copy() + value_tmp = value.copy() + key_cache_ans = key_cache.copy() + value_cache_ans = value_cache.copy() + + head = key_cache.shape[2] + head_dim = key_cache.shape[3] + key_tmp = key_tmp.reshape(-1, head, head_dim) + value_tmp = value_tmp.reshape(-1, head, head_dim) + + for i, slot in enumerate(slot_map): + slot_idx = slot // key_cache.shape[1] + slot_offset = slot % key_cache.shape[1] + key_cache_ans[slot_idx][slot_offset] = key_tmp[i] + value_cache_ans[slot_idx][slot_offset] = value_tmp[i] + + return key_cache_ans, value_cache_ans + + @staticmethod + def nz_inference(key: np.ndarray, value: np.ndarray, + key_cache: np.ndarray, value_cache: np.ndarray, + slot_map: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """NZ format inference""" + key_tmp = key.copy() + value_tmp = value.copy() + key_cache_ans = key_cache.copy() + value_cache_ans = value_cache.copy() + 
+ key_tmp = key_tmp.reshape(-1, key_cache.shape[2]) + value_tmp = value_tmp.reshape(-1, key_cache.shape[2]) + + for i, slot in enumerate(slot_map): + slot_idx = slot // key_cache.shape[1] + slot_offset = slot % key_cache.shape[1] + key_cache_ans[slot_idx][slot_offset] = key_tmp[i] + value_cache_ans[slot_idx][slot_offset] = value_tmp[i] + + return key_cache_ans, value_cache_ans def nd_inference(key, value, key_cache, value_cache, slot_map): - """ - nd_inference - """ - key_tmp = key.copy() - value_tmp = value.copy() - key_cache_ans = key_cache.copy() - value_cache_ans = value_cache.copy() - head = key_cache.shape[2] - head_dim = key_cache.shape[3] - key_tmp = key_tmp.reshape(-1, head, head_dim) - value_tmp = value_tmp.reshape(-1, head, head_dim) - for i, slot in enumerate(slot_map): - slot_idx = slot // key_cache.shape[1] - slot_offset = slot % key_cache.shape[1] - key_cache_ans[slot_idx][slot_offset] = key_tmp[i] - value_cache_ans[slot_idx][slot_offset] = value_tmp[i] - return key_cache_ans, value_cache_ans + """Legacy function for backward compatibility""" + return InferenceEngine.nd_inference(key, value, key_cache, value_cache, slot_map) + + +def nz_inference(key, value, key_cache, value_cache, slot_map): + """Legacy function for backward compatibility""" + return InferenceEngine.nz_inference(key, value, key_cache, value_cache, slot_map) @pytest.mark.level0 @@ -156,11 +258,13 @@ def nd_inference(key, value, key_cache, value_cache, slot_map): def test_reshape_and_cache_nd_key_value(np_dtype, kv_dim): """ Feature: Test ReshapeAndCache. - Description: Test float16 inputs. + Description: Test ND format with key and value. Expectation: Assert that results are consistent with numpy. """ - context.set_context(device_target="Ascend", mode=context.GRAPH_MODE) - net = ReshapeAndCacheAllNz() + test_config = TestConfig(device_target="Ascend", mode=context.GRAPH_MODE) + test_config.apply() + + net = ReshapeAndCacheAll() np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nd_inputs( np_dtype, kv_dim) @@ -169,16 +273,11 @@ def test_reshape_and_cache_nd_key_value(np_dtype, kv_dim): ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + # Run test _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map) - - if np_dtype == bfloat16: - assert np.allclose(ms_k_cache.float().asnumpy(), - np_k_cache_out.astype(np.float32), 0.001, 0.001) - assert np.allclose(ms_v_cache.float().asnumpy(), - np_v_cache_out.astype(np.float32), 0.001, 0.001) - else: - assert np.allclose(ms_k_cache.asnumpy(), np_k_cache_out, 0.001, 0.001) - assert np.allclose(ms_v_cache.asnumpy(), np_v_cache_out, 0.001, 0.001) + TestResultVerifier.verify_results(ms_k_cache, np_k_cache_out, np_dtype) + TestResultVerifier.verify_results(ms_v_cache, np_v_cache_out, np_dtype) @pytest.mark.level0 @@ -189,27 +288,26 @@ def test_reshape_and_cache_nd_key_value(np_dtype, kv_dim): def test_reshape_and_cache_nd_key(np_dtype, kv_dim): """ Feature: Test ReshapeAndCache. - Description: Test float16 inputs. + Description: Test ND format with key only. Expectation: Assert that results are consistent with numpy. 
""" - context.set_context(device_target="Ascend", mode=context.GRAPH_MODE) - context.set_context(jit_config={"jit_level": "O0"}) + test_config = TestConfig(device_target="Ascend", mode=context.GRAPH_MODE, + jit_config={"jit_level": "O0"}) + test_config.apply() + net = ReshapeAndCacheKey() np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nd_inputs( np_dtype, kv_dim) - np_k_cache_out, np_v_cache_out = nd_inference( + np_k_cache_out, _ = nd_inference( np_k, np_v, np_k_cache, np_v_cache, np_slot_map) ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + # Run test _ = net(ms_k, key_cache=ms_k_cache, slot_map=ms_slot_map) - - if np_dtype == bfloat16: - assert np.allclose(ms_k_cache.float().asnumpy(), - np_k_cache_out.astype(np.float32), 0.001, 0.001) - else: - assert np.allclose(ms_k_cache.asnumpy(), np_k_cache_out, 0.001, 0.001) + TestResultVerifier.verify_results(ms_k_cache, np_k_cache_out, np_dtype) # @pytest.mark.level0 @@ -225,7 +323,7 @@ def test_reshape_and_cache_nd_key(np_dtype, kv_dim): # """ # context.set_context(device_target="Ascend", mode=context.GRAPH_MODE) # context.set_context(jit_config={"jit_level": "O0"}) -# net = ReshapeAndCacheAllNd() +# net = ReshapeAndCacheAll() # np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nd_inputs( # np_dtype, kv_dim) @@ -244,69 +342,28 @@ def test_reshape_and_cache_nd_key(np_dtype, kv_dim): # =============================== # test nz format # =============================== -def create_nz_inputs(k_dtype=np.float16, v_dtype=np.float16, kv_dim=3): - """ - create_nz_inputs - """ - k_cache_shape = (num_slots, slot_size, n * d) - v_cache_shape = (num_slots, slot_size, n * d) - - if kv_dim == 2: - update_shape = (b * s, n * d) - num_tokens = update_shape[0] - elif kv_dim == 3: - update_shape = (b, s, n * d) - num_tokens = update_shape[0] * update_shape[1] - else: - raise Exception( - "Key's dim should be 2 or 3, but got {0}".format(kv_dim)) - - if k_dtype == np.int8: - key_update = np.random.randint(low=-128, high=127, - size=update_shape, - dtype=np.int8) - # key_cache = np.random.randint(low=-128, high=127, - # size=k_cache_shape, - # dtype=np.int8) - else: - key_update = np.random.rand(*update_shape).astype(k_dtype) - # key_cache = np.random.rand(*k_cache_shape).astype(k_dtype) - key_cache = np.zeros(k_cache_shape, dtype=k_dtype) - - if v_dtype == np.int8: - value_update = np.random.randint(low=-128, high=127, - size=update_shape, - dtype=np.int8) - # value_cache = np.random.randint(low=-128, high=127, - # size=v_cache_shape, - # dtype=np.int8) - else: - value_update = np.random.rand(*update_shape).astype(v_dtype) - # value_cache = np.random.rand(*v_cache_shape).astype(v_dtype) - value_cache = np.zeros(v_cache_shape, dtype=v_dtype) - - slot_map = np.random.choice(np.arange(num_tokens), num_tokens, - replace=False).astype(np.int32) - - return key_update, value_update, key_cache, value_cache, slot_map +class NZDataGenerator(TestDataGenerator): + """Data generator for NZ format""" + + @staticmethod + def create_inputs(k_dtype: np.dtype, v_dtype: np.dtype, kv_dim: int) -> Tuple[np.ndarray, ...]: + """Create NZ format inputs""" + k_cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS * HEAD_DIM) + v_cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS * HEAD_DIM) + update_shape, num_tokens = TestDataGenerator.get_update_shape(kv_dim) + + key_update = TestDataGenerator.create_random_data(update_shape, k_dtype) + value_update = TestDataGenerator.create_random_data(update_shape, v_dtype) 
+ key_cache = np.zeros(k_cache_shape, dtype=k_dtype) + value_cache = np.zeros(v_cache_shape, dtype=v_dtype) + slot_map = TestDataGenerator.create_slot_map(num_tokens) + + return key_update, value_update, key_cache, value_cache, slot_map -def nz_inference(key, value, key_cache, value_cache, slot_map): - """ - nz_inference - """ - key_tmp = key.copy() - value_tmp = value.copy() - key_cache_ans = key_cache.copy() - value_cache_ans = value_cache.copy() - key_tmp = key_tmp.reshape(-1, key_cache.shape[2]) - value_tmp = value_tmp.reshape(-1, key_cache.shape[2]) - for i, slot in enumerate(slot_map): - slot_idx = slot // key_cache.shape[1] - slot_offset = slot % key_cache.shape[1] - key_cache_ans[slot_idx][slot_offset] = key_tmp[i] - value_cache_ans[slot_idx][slot_offset] = value_tmp[i] - return key_cache_ans, value_cache_ans +def create_nz_inputs(k_dtype=np.float16, v_dtype=np.float16, kv_dim=3): + """Legacy function for backward compatibility""" + return NZDataGenerator.create_inputs(k_dtype, v_dtype, kv_dim) def get_nz_cached_slots(cache, slot_map): @@ -341,8 +398,8 @@ def get_nd_cached_slots(cache, slot_map): for slot in slot_map: if slot < 0: continue - slot_idx = slot // slot_size - slot_offset = slot % slot_size + slot_idx = slot // SLOT_SIZE + slot_offset = slot % SLOT_SIZE ans.append(cache[slot_idx][slot_offset]) ans = np.concatenate(ans) return ans @@ -359,25 +416,27 @@ def get_nd_cached_slots(cache, slot_map): def test_reshape_and_cache_nz(k_dtype, v_dtype, kv_dim, exec_mode): """ Feature: Test ReshapeAndCache. - Description: Test float16 inputs. + Description: Test NZ format with key and value. Expectation: Assert that results are consistent with numpy. """ + # Skip invalid combinations if (k_dtype == np.float16 and v_dtype != np.float16) or \ (k_dtype == bfloat16 and v_dtype != bfloat16): pytest.skip(f"Invalid combo: {k_dtype} -> {v_dtype}") - - # todo check why need infer_boost - context.set_context(mode=exec_mode, device_target="Ascend", - jit_config={"jit_level": "O0", "infer_boost": "on"}, - save_graphs=False, save_graphs_path="reshape_and_cache_graph_wrong") - net = ReshapeAndCacheAllNz() + + # Setup context + jit_config = {"jit_level": "O0", "infer_boost": "on"} + test_config = TestConfig(device_target="Ascend", mode=exec_mode, jit_config=jit_config) + test_config.apply() + + net = ReshapeAndCacheAll() np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nz_inputs( k_dtype, v_dtype, kv_dim) np_k_cache_out, np_v_cache_out = nz_inference( np_k, np_v, np_k_cache, np_v_cache, np_slot_map) - # FRACTAL_NZ + # Create MindSpore inputs with appropriate format if exec_mode == context.GRAPH_MODE: ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="FRACTAL_NZ") @@ -388,29 +447,28 @@ def test_reshape_and_cache_nz(k_dtype, v_dtype, kv_dim, exec_mode): ms_k_cache = ops.auto_generate.format_cast(ms_k_cache, acl_format) ms_v_cache = ops.auto_generate.format_cast(ms_v_cache, acl_format) - _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, head_num=n) + _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, head_num=NUM_HEADS) - # attention: ms asnumpy() will cast nz to nd! 
+ # Extract and verify results + ms_k_cache_np = ms_k_cache.asnumpy() + ms_v_cache_np = ms_v_cache.asnumpy() + + # Handle bfloat16 conversion if k_dtype == bfloat16: - ms_k_cache_np = ms_k_cache.asnumpy().astype(np.float32) + ms_k_cache_np = ms_k_cache_np.astype(np.float32) np_k_cache_out = np_k_cache_out.astype(np.float32) - else: - ms_k_cache_np = ms_k_cache.asnumpy() - + if v_dtype == bfloat16: - ms_v_cache_np = ms_v_cache.asnumpy().astype(np.float32) + ms_v_cache_np = ms_v_cache_np.astype(np.float32) np_v_cache_out = np_v_cache_out.astype(np.float32) - else: - ms_v_cache_np = ms_v_cache.asnumpy() - + + # Extract cached slots for verification ms_k_output = get_nd_cached_slots(ms_k_cache_np, np_slot_map) golden_k_output = get_nd_cached_slots(np_k_cache_out, np_slot_map) - - print(f"ms out: {ms_k_output}") - print(f"np out: {golden_k_output}") - + ms_v_output = get_nd_cached_slots(ms_v_cache_np, np_slot_map) golden_v_output = get_nd_cached_slots(np_v_cache_out, np_slot_map) - + + # Verify results assert np.allclose(ms_k_output, golden_k_output, 0.001, 0.001) assert np.allclose(ms_v_output, golden_v_output, 0.001, 0.001) -- Gitee From 03d43d37a57e4289a5cb2fa91e5aed90288d88f5 Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Tue, 29 Jul 2025 11:48:02 +0800 Subject: [PATCH 3/3] eee --- tests/st/test_custom_reshape_and_cache.py | 302 ++++++++++++++++++---- 1 file changed, 254 insertions(+), 48 deletions(-) diff --git a/tests/st/test_custom_reshape_and_cache.py b/tests/st/test_custom_reshape_and_cache.py index 0a15325..7a8a006 100644 --- a/tests/st/test_custom_reshape_and_cache.py +++ b/tests/st/test_custom_reshape_and_cache.py @@ -37,7 +37,8 @@ SLOT_SIZE = 64 BATCH_SIZE = 13 SEQ_LEN = 3 NUM_HEADS = 16 -HEAD_DIM = 32 +K_HEAD_DIM = 32 +V_HEAD_DIM = 32 class CacheFormat(Enum): @@ -136,6 +137,25 @@ class TestConfig: context.set_context(jit_config=self.jit_config) +class DimensionTestHelper: + """Helper class for testing different dimension combinations""" + + @staticmethod + def run_with_dimensions(k_head_dim: int, v_head_dim: int, test_func): + """Run test with specified dimensions and restore original values""" + global K_HEAD_DIM, V_HEAD_DIM + original_k_head_dim = K_HEAD_DIM + original_v_head_dim = V_HEAD_DIM + + try: + K_HEAD_DIM = k_head_dim + V_HEAD_DIM = v_head_dim + test_func() + finally: + K_HEAD_DIM = original_k_head_dim + V_HEAD_DIM = original_v_head_dim + + # =============================== # test nd format # =============================== @@ -156,17 +176,25 @@ class TestDataGenerator: return np.random.choice(np.arange(num_tokens), num_tokens, replace=False).astype(np.int32) @staticmethod - def get_update_shape(kv_dim: int) -> Tuple[Tuple[int, ...], int]: - """Get update shape and number of tokens based on dimension""" + def get_update_shapes(kv_dim: int) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: + """Get update shapes for key and value, and number of tokens based on dimension""" if kv_dim == 2: - update_shape = (BATCH_SIZE * SEQ_LEN, NUM_HEADS * HEAD_DIM) - num_tokens = update_shape[0] + key_update_shape = (BATCH_SIZE * SEQ_LEN, NUM_HEADS * K_HEAD_DIM) + value_update_shape = (BATCH_SIZE * SEQ_LEN, NUM_HEADS * V_HEAD_DIM) + num_tokens = key_update_shape[0] elif kv_dim == 3: - update_shape = (BATCH_SIZE, SEQ_LEN, NUM_HEADS * HEAD_DIM) - num_tokens = update_shape[0] * update_shape[1] + key_update_shape = (BATCH_SIZE, SEQ_LEN, NUM_HEADS * K_HEAD_DIM) + value_update_shape = (BATCH_SIZE, SEQ_LEN, NUM_HEADS * V_HEAD_DIM) + num_tokens = key_update_shape[0] * key_update_shape[1] else: 
raise ValueError(f"Key's dim should be 2 or 3, but got {kv_dim}") - return update_shape, num_tokens + return key_update_shape, value_update_shape, num_tokens + + @staticmethod + def get_update_shape(kv_dim: int, is_key: bool = True) -> Tuple[Tuple[int, ...], int]: + """Legacy method for backward compatibility""" + key_shape, value_shape, num_tokens = TestDataGenerator.get_update_shapes(kv_dim) + return (key_shape if is_key else value_shape), num_tokens class NDDataGenerator(TestDataGenerator): @@ -175,13 +203,14 @@ class NDDataGenerator(TestDataGenerator): @staticmethod def create_inputs(dtype: np.dtype, kv_dim: int) -> Tuple[np.ndarray, ...]: """Create ND format inputs""" - cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS, HEAD_DIM) - update_shape, num_tokens = TestDataGenerator.get_update_shape(kv_dim) + key_cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS, K_HEAD_DIM) + value_cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS, V_HEAD_DIM) + key_update_shape, value_update_shape, num_tokens = TestDataGenerator.get_update_shapes(kv_dim) - key_update = TestDataGenerator.create_random_data(update_shape, dtype) - value_update = TestDataGenerator.create_random_data(update_shape, dtype) - key_cache = TestDataGenerator.create_random_data(cache_shape, dtype) - value_cache = TestDataGenerator.create_random_data(cache_shape, dtype) + key_update = TestDataGenerator.create_random_data(key_update_shape, dtype) + value_update = TestDataGenerator.create_random_data(value_update_shape, dtype) + key_cache = TestDataGenerator.create_random_data(key_cache_shape, dtype) + value_cache = TestDataGenerator.create_random_data(value_cache_shape, dtype) slot_map = TestDataGenerator.create_slot_map(num_tokens) return key_update, value_update, key_cache, value_cache, slot_map @@ -310,33 +339,139 @@ def test_reshape_and_cache_nd_key(np_dtype, kv_dim): TestResultVerifier.verify_results(ms_k_cache, np_k_cache_out, np_dtype) -# @pytest.mark.level0 -# @pytest.mark.platform_ascend910b -# @pytest.mark.env_onecard -# @pytest.mark.parametrize('np_dtype', [np.float16, np.int8, bfloat16]) -# @pytest.mark.parametrize('kv_dim', [2, 3]) -# def test_reshape_and_cache_expect_fail(np_dtype, kv_dim): -# """ -# Feature: Test ReshapeAndCache. -# Description: Test float16 inputs. -# Expectation: Assert that results are consistent with numpy. -# """ -# context.set_context(device_target="Ascend", mode=context.GRAPH_MODE) -# context.set_context(jit_config={"jit_level": "O0"}) -# net = ReshapeAndCacheAll() - -# np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nd_inputs( -# np_dtype, kv_dim) -# np_k_cache_out, np_v_cache_out = nd_inference( -# np_k, np_v, np_k_cache, np_v_cache, np_slot_map) - -# ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( -# np_k, np_v, np_k_cache, np_v_cache, np_slot_map) -# with pytest.raises(Exception): -# _ = net(ms_k, ms_v, key_cache=ms_k_cache, slot_map=ms_slot_map) -# with pytest.raises(Exception): -# _ = net(ms_k, key_cache=ms_k_cache, -# value_cache=ms_v_cache, slot_map=ms_slot_map) +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize('np_dtype', [np.float16, np.int8, bfloat16]) +@pytest.mark.parametrize('kv_dim', [2, 3]) +@pytest.mark.parametrize('k_head_dim', [32, 64, 128]) +@pytest.mark.parametrize('v_head_dim', [32, 64, 128]) +def test_reshape_and_cache_nd_key_value_different_dimensions(np_dtype, kv_dim, k_head_dim, v_head_dim): + """ + Feature: Test ReshapeAndCache. 
+ Description: Test ND format with different K_HEAD_DIM and V_HEAD_DIM combinations. + Expectation: Assert that results are consistent with numpy. + """ + def run_test(): + test_config = TestConfig(device_target="Ascend", mode=context.GRAPH_MODE) + test_config.apply() + + net = ReshapeAndCacheAll() + + np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nd_inputs( + np_dtype, kv_dim) + np_k_cache_out, np_v_cache_out = nd_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + # Run test + _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map) + TestResultVerifier.verify_results(ms_k_cache, np_k_cache_out, np_dtype) + TestResultVerifier.verify_results(ms_v_cache, np_v_cache_out, np_dtype) + + DimensionTestHelper.run_with_dimensions(k_head_dim, v_head_dim, run_test) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize('kv_dim', [2, 3]) +@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_reshape_and_cache_nz_different_key_value_dimensions(kv_dim, exec_mode): + """ + Feature: Test ReshapeAndCache. + Description: Test NZ format with significantly different K_HEAD_DIM and V_HEAD_DIM. + Expectation: Assert that results are consistent with numpy. + """ + def run_test(): + # Setup context + jit_config = {"jit_level": "O0", "infer_boost": "on"} + test_config = TestConfig(device_target="Ascend", mode=exec_mode, jit_config=jit_config) + test_config.apply() + + net = ReshapeAndCacheAll() + + np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nz_inputs( + np.float16, np.float16, kv_dim) + + # Verify that key and value have different shapes + assert np_k.shape != np_v.shape, f"Key and value should have different shapes: {np_k.shape} vs {np_v.shape}" + assert np_k_cache.shape != np_v_cache.shape, f"Key and value cache should have different shapes: {np_k_cache.shape} vs {np_v_cache.shape}" + + np_k_cache_out, np_v_cache_out = nz_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + # Create MindSpore inputs with appropriate format + if exec_mode == context.GRAPH_MODE: + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="FRACTAL_NZ") + else: + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="") + acl_format = 29 + ms_k_cache = ops.auto_generate.format_cast(ms_k_cache, acl_format) + ms_v_cache = ops.auto_generate.format_cast(ms_v_cache, acl_format) + + _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, head_num=NUM_HEADS) + + # Extract and verify results + ms_k_cache_np = ms_k_cache.asnumpy() + ms_v_cache_np = ms_v_cache.asnumpy() + + # Extract cached slots for verification + ms_k_output = get_nd_cached_slots(ms_k_cache_np, np_slot_map) + golden_k_output = get_nd_cached_slots(np_k_cache_out, np_slot_map) + + ms_v_output = get_nd_cached_slots(ms_v_cache_np, np_slot_map) + golden_v_output = get_nd_cached_slots(np_v_cache_out, np_slot_map) + + # Verify results + assert np.allclose(ms_k_output, golden_k_output, 0.001, 0.001) + assert np.allclose(ms_v_output, golden_v_output, 0.001, 0.001) + + # Test with very different dimensions + DimensionTestHelper.run_with_dimensions(96, 16, run_test) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard 
+@pytest.mark.parametrize('kv_dim', [2, 3]) +def test_reshape_and_cache_different_key_value_dimensions(kv_dim): + """ + Feature: Test ReshapeAndCache. + Description: Test with significantly different K_HEAD_DIM and V_HEAD_DIM. + Expectation: Assert that results are consistent with numpy. + """ + def run_test(): + test_config = TestConfig(device_target="Ascend", mode=context.GRAPH_MODE) + test_config.apply() + + net = ReshapeAndCacheAll() + + # Test with very different dimensions + np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nd_inputs( + np.float16, kv_dim) + + # Verify that key and value have different shapes + assert np_k.shape != np_v.shape, f"Key and value should have different shapes: {np_k.shape} vs {np_v.shape}" + assert np_k_cache.shape != np_v_cache.shape, f"Key and value cache should have different shapes: {np_k_cache.shape} vs {np_v_cache.shape}" + + np_k_cache_out, np_v_cache_out = nd_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + # Run test + _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map) + TestResultVerifier.verify_results(ms_k_cache, np_k_cache_out, np.float16) + TestResultVerifier.verify_results(ms_v_cache, np_v_cache_out, np.float16) + + # Test with very different dimensions + DimensionTestHelper.run_with_dimensions(128, 32, run_test) # =============================== @@ -348,12 +483,12 @@ class NZDataGenerator(TestDataGenerator): @staticmethod def create_inputs(k_dtype: np.dtype, v_dtype: np.dtype, kv_dim: int) -> Tuple[np.ndarray, ...]: """Create NZ format inputs""" - k_cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS * HEAD_DIM) - v_cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS * HEAD_DIM) - update_shape, num_tokens = TestDataGenerator.get_update_shape(kv_dim) + k_cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS * K_HEAD_DIM) + v_cache_shape = (NUM_SLOTS, SLOT_SIZE, NUM_HEADS * V_HEAD_DIM) + key_update_shape, value_update_shape, num_tokens = TestDataGenerator.get_update_shapes(kv_dim) - key_update = TestDataGenerator.create_random_data(update_shape, k_dtype) - value_update = TestDataGenerator.create_random_data(update_shape, v_dtype) + key_update = TestDataGenerator.create_random_data(key_update_shape, k_dtype) + value_update = TestDataGenerator.create_random_data(value_update_shape, v_dtype) key_cache = np.zeros(k_cache_shape, dtype=k_dtype) value_cache = np.zeros(v_cache_shape, dtype=v_dtype) slot_map = TestDataGenerator.create_slot_map(num_tokens) @@ -370,8 +505,6 @@ def get_nz_cached_slots(cache, slot_map): ans = [] tmp = [] - print(f"=========cache shape: {cache.shape}") - num_slots = cache.shape[0] slot_size = cache.shape[1] hidden_size = cache.shape[2] @@ -472,3 +605,76 @@ def test_reshape_and_cache_nz(k_dtype, v_dtype, kv_dim, exec_mode): # Verify results assert np.allclose(ms_k_output, golden_k_output, 0.001, 0.001) assert np.allclose(ms_v_output, golden_v_output, 0.001, 0.001) + + +@pytest.mark.level0 +@pytest.mark.platform_ascend910b +@pytest.mark.env_onecard +@pytest.mark.parametrize('kv_dim', [2, 3]) +@pytest.mark.parametrize('k_dtype', [np.float16, bfloat16, np.int8]) +@pytest.mark.parametrize('v_dtype', [np.float16, bfloat16]) +@pytest.mark.parametrize('exec_mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +@pytest.mark.parametrize('k_head_dim', [32, 64, 128]) +@pytest.mark.parametrize('v_head_dim', [32, 64, 128]) +def test_reshape_and_cache_nz_different_dimensions(k_dtype, v_dtype, kv_dim, 
exec_mode, k_head_dim, v_head_dim): + """ + Feature: Test ReshapeAndCache. + Description: Test NZ format with different K_HEAD_DIM and V_HEAD_DIM combinations. + Expectation: Assert that results are consistent with numpy. + """ + # Skip invalid combinations + if (k_dtype == np.float16 and v_dtype != np.float16) or \ + (k_dtype == bfloat16 and v_dtype != bfloat16): + pytest.skip(f"Invalid combo: {k_dtype} -> {v_dtype}") + + def run_test(): + # Setup context + jit_config = {"jit_level": "O0", "infer_boost": "on"} + test_config = TestConfig(device_target="Ascend", mode=exec_mode, jit_config=jit_config) + test_config.apply() + + net = ReshapeAndCacheAll() + + np_k, np_v, np_k_cache, np_v_cache, np_slot_map = create_nz_inputs( + k_dtype, v_dtype, kv_dim) + np_k_cache_out, np_v_cache_out = nz_inference( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map) + + # Create MindSpore inputs with appropriate format + if exec_mode == context.GRAPH_MODE: + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="FRACTAL_NZ") + else: + ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map = create_ms_inputs( + np_k, np_v, np_k_cache, np_v_cache, np_slot_map, format="") + acl_format = 29 + ms_k_cache = ops.auto_generate.format_cast(ms_k_cache, acl_format) + ms_v_cache = ops.auto_generate.format_cast(ms_v_cache, acl_format) + + _ = net(ms_k, ms_v, ms_k_cache, ms_v_cache, ms_slot_map, head_num=NUM_HEADS) + + # Extract and verify results + ms_k_cache_np = ms_k_cache.asnumpy() + ms_v_cache_np = ms_v_cache.asnumpy() + + # Handle bfloat16 conversion + if k_dtype == bfloat16: + ms_k_cache_np = ms_k_cache_np.astype(np.float32) + np_k_cache_out = np_k_cache_out.astype(np.float32) + + if v_dtype == bfloat16: + ms_v_cache_np = ms_v_cache_np.astype(np.float32) + np_v_cache_out = np_v_cache_out.astype(np.float32) + + # Extract cached slots for verification + ms_k_output = get_nd_cached_slots(ms_k_cache_np, np_slot_map) + golden_k_output = get_nd_cached_slots(np_k_cache_out, np_slot_map) + + ms_v_output = get_nd_cached_slots(ms_v_cache_np, np_slot_map) + golden_v_output = get_nd_cached_slots(np_v_cache_out, np_slot_map) + + # Verify results + assert np.allclose(ms_k_output, golden_k_output, 0.001, 0.001) + assert np.allclose(ms_v_output, golden_v_output, 0.001, 0.001) + + DimensionTestHelper.run_with_dimensions(k_head_dim, v_head_dim, run_test) -- Gitee
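
Note for readers of this series: the behaviour the kernels and tests above exercise is a token-wise scatter of key/value updates into a paged KV cache. The following is a minimal NumPy sketch of that reference semantics (illustrative only, not the internal AsdReshapeAndCache kernel), assuming the ND cache layout (num_slots, slot_size, num_heads, head_dim) used by the nd_inference helper; negative slot_mapping entries are treated as padding and skipped, as in get_nd_cached_slots. The NZ path stores the cache as FRACTAL_NZ on device, and the tests rely on asnumpy() converting it back to ND before comparison.

    import numpy as np

    def reshape_and_cache_ref(key, value, key_cache, value_cache, slot_mapping):
        # key/value: (..., num_heads * head_dim), flattened to one row per token.
        # key_cache/value_cache: (num_slots, slot_size, num_heads, head_dim),
        # updated in place; slot_mapping: one flat slot index per token.
        k_heads, k_dim = key_cache.shape[2], key_cache.shape[3]
        v_heads, v_dim = value_cache.shape[2], value_cache.shape[3]
        slot_size = key_cache.shape[1]
        key = key.reshape(-1, k_heads, k_dim)
        value = value.reshape(-1, v_heads, v_dim)
        for i, slot in enumerate(slot_mapping):
            if slot < 0:  # padded token: nothing is written
                continue
            block, offset = divmod(int(slot), slot_size)
            key_cache[block, offset] = key[i]
            value_cache[block, offset] = value[i]
        return key_cache, value_cache

    # Usage with tiny, hypothetical shapes (2 slots of 4 tokens, 2 heads, dim 3):
    kc = np.zeros((2, 4, 2, 3), np.float16)
    vc = np.zeros((2, 4, 2, 3), np.float16)
    k = np.random.rand(5, 2 * 3).astype(np.float16)
    v = np.random.rand(5, 2 * 3).astype(np.float16)
    slots = np.array([0, 3, 4, 7, -1], np.int32)
    reshape_and_cache_ref(k, v, kc, vc, slots)
    assert np.allclose(kc[0, 3], k[1].reshape(2, 3))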