diff --git a/mint_qwen3_next_delta_rule_ut.py b/mint_qwen3_next_delta_rule_ut.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5c1c21ae32e89f65fd81a41abef2c5fa7e6c576
--- /dev/null
+++ b/mint_qwen3_next_delta_rule_ut.py
@@ -0,0 +1,212 @@
+# coding=utf-8
+"""
+Standalone unit test comparing the PyTorch and MindSpore mint implementations of
+the recurrent gated delta rule used by Qwen3-Next.
+
+Run:
+  - As a script: python mint_qwen3_next_delta_rule_ut.py
+  - With pytest: pytest -q mint_qwen3_next_delta_rule_ut.py
+"""
+
+from __future__ import annotations
+
+import sys
+import numpy as np
+
+
+def _torch_impl_available():
+    try:
+        import torch  # noqa: F401
+        return True
+    except Exception:
+        return False
+
+
+def _mint_available():
+    try:
+        import mindspore  # noqa: F401
+        from mindspore import mint  # noqa: F401
+        return True
+    except Exception:
+        return False
+
+
+def torch_recurrent_gated_delta_rule(
+    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
+):
+    import torch
+    import torch.nn.functional as F
+
+    initial_dtype = query.dtype
+    if use_qk_l2norm_in_kernel:
+        query = F.normalize(query, p=2, dim=-1)
+        key = F.normalize(key, p=2, dim=-1)
+    # Inputs are (b, s, h, d); transpose(1, 2) moves heads forward: (b, h, s, d).
+    query, key, value, beta, g = [
+        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
+    ]
+
+    batch_size, num_heads, sequence_length, k_head_dim = key.shape
+    v_head_dim = value.shape[-1]
+    scale = 1 / (query.shape[-1] ** 0.5)
+    query = query * scale
+
+    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
+    last_recurrent_state = (
+        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
+        if initial_state is None
+        else initial_state.to(value)
+    )
+
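+    # The loop below is the gated delta rule recurrence over time steps t,
+    # with memory state S of shape (b, h, d_k, d_v):
+    #   S_t   = g_t * S_{t-1}                  (decay the memory)
+    #   delta = beta_t * (v_t - S_t^T k_t)     (delta-rule error correction)
+    #   S_t   = S_t + k_t delta^T              (rank-1 write keyed by k_t)
+    #   o_t   = S_t^T q_t                      (read-out with the scaled query)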
+    for i in range(sequence_length):
+        q_t = query[:, :, i]
+        k_t = key[:, :, i]
+        v_t = value[:, :, i]
+        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
+        beta_t = beta[:, :, i].unsqueeze(-1)
+
+        last_recurrent_state = last_recurrent_state * g_t
+        kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
+        delta = (v_t - kv_mem) * beta_t
+        last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
+        core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
+
+    if not output_final_state:
+        last_recurrent_state = None
+    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+    return core_attn_out, last_recurrent_state
+
+
+def mint_recurrent_gated_delta_rule(
+    query,
+    key,
+    value,
+    g,
+    beta,
+    initial_state=None,
+    output_final_state=False,
+    use_qk_l2norm_in_kernel=False,
+):
+    import mindspore as ms
+    from mindspore import mint
+
+    initial_dtype = query.dtype
+    if use_qk_l2norm_in_kernel:
+        query = mint.nn.functional.normalize(query, p=2.0, dim=-1)
+        key = mint.nn.functional.normalize(key, p=2.0, dim=-1)
+
+    # Mirror the torch reference: inputs are (b, s, h, d); transpose(1, 2) yields (b, h, s, d),
+    # so the recurrence below indexes time steps along dim 2.
+    query = mint.transpose(query, 1, 2).astype(ms.float32)
+    key = mint.transpose(key, 1, 2).astype(ms.float32)
+    value = mint.transpose(value, 1, 2).astype(ms.float32)
+    beta = mint.transpose(beta, 1, 2).astype(ms.float32)
+    g = mint.transpose(g, 1, 2).astype(ms.float32)
+
+    batch_size, num_heads, sequence_length, k_head_dim = key.shape
+    v_head_dim = value.shape[-1]
+
+    scale = 1.0 / (query.shape[-1] ** 0.5)
+    query = query * scale
+
+    core_attn_out_list = []  # per-step outputs of shape (b, h, v)
+    if initial_state is None:
+        last_recurrent_state = mint.zeros((batch_size, num_heads, k_head_dim, v_head_dim), dtype=value.dtype)
+    else:
+        last_recurrent_state = initial_state.astype(ms.float32)
+
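+    # Same recurrence as the torch reference above, written with mint ops.
+    # Per-step shapes: q_t/k_t (b, h, d_k), v_t/kv_mem/out_t (b, h, d_v),
+    # last_recurrent_state (b, h, d_k, d_v).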
+    for i in range(sequence_length):
+        q_t = query[:, :, i]
+        k_t = key[:, :, i]
+        v_t = value[:, :, i]
+        g_t = mint.unsqueeze(mint.unsqueeze(mint.exp(g[:, :, i]), -1), -1)
+        beta_t = mint.unsqueeze(beta[:, :, i], -1)
+
+        last_recurrent_state = last_recurrent_state * g_t
+        kv_mem = mint.sum(last_recurrent_state * mint.unsqueeze(k_t, -1), dim=-2)
+        delta = (v_t - kv_mem) * beta_t
+        last_recurrent_state = last_recurrent_state + mint.unsqueeze(k_t, -1) * mint.unsqueeze(delta, -2)
+        out_t = mint.sum(last_recurrent_state * mint.unsqueeze(q_t, -1), dim=-2)
+        core_attn_out_list.append(out_t)
+
+    core_attn_out = mint.stack(core_attn_out_list, dim=2)  # (b, h, s, v)
+    core_attn_out = mint.transpose(core_attn_out, 1, 2)  # match the torch output layout: (b, s, h, v)
+
+    if not output_final_state:
+        last_recurrent_state = None
+
+    # cast back to the input dtype
+    core_attn_out = core_attn_out.astype(initial_dtype)
+    return core_attn_out, last_recurrent_state
+
+
+def _gen_inputs(batch_size, seq_len, num_heads, k_head_dim, v_head_dim, with_init):
+    import torch
+
+    torch.manual_seed(0)
+    query = torch.randn(batch_size, seq_len, num_heads, k_head_dim, dtype=torch.float32)
+    key = torch.randn(batch_size, seq_len, num_heads, k_head_dim, dtype=torch.float32)
+    value = torch.randn(batch_size, seq_len, num_heads, v_head_dim, dtype=torch.float32)
+    g = torch.randn(batch_size, seq_len, num_heads, dtype=torch.float32)
+    beta = torch.randn(batch_size, seq_len, num_heads, dtype=torch.float32)
+    # The recurrent state is per head: (batch, num_heads, k_head_dim, v_head_dim).
+    init = (
+        torch.randn(batch_size, num_heads, k_head_dim, v_head_dim, dtype=torch.float32)
+        if with_init
+        else None
+    )
+    return query, key, value, g, beta, init
+
+
+def _to_ms(x_torch):
+    from mindspore import Tensor as MSTensor
+    from mindspore import dtype as mstype
+
+    return MSTensor(x_torch.detach().cpu().numpy(), dtype=mstype.float32)
+
+
+def run_equivalence_once(
+    batch_size=2, seq_len=3, num_heads=3, k_head_dim=4, v_head_dim=6, use_norm=False, with_init=False
+):
+    if not _torch_impl_available():
+        raise RuntimeError("PyTorch is required for the reference implementation.")
+    if not _mint_available():
+        raise RuntimeError("MindSpore mint is required for the mint implementation.")
+
+    query, key, value, g, beta, init = _gen_inputs(
+        batch_size, seq_len, num_heads, k_head_dim, v_head_dim, with_init
+    )
+
+    out_torch, state_torch = torch_recurrent_gated_delta_rule(
+        query, key, value, g, beta, init, True, use_norm
+    )
+
+    out_ms, state_ms = mint_recurrent_gated_delta_rule(
+        _to_ms(query),
+        _to_ms(key),
+        _to_ms(value),
+        _to_ms(g),
+        _to_ms(beta),
+        None if init is None else _to_ms(init),
+        True,
+        use_norm,
+    )
+
+    np.testing.assert_allclose(out_torch.detach().cpu().numpy(), out_ms.asnumpy(), rtol=1e-5, atol=1e-6)
+    if state_ms is None:
+        raise AssertionError("Expected a final state when output_final_state=True")
+    np.testing.assert_allclose(state_torch.detach().cpu().numpy(), state_ms.asnumpy(), rtol=1e-5, atol=1e-6)
+
+
+def test_equivalence():  # pytest will pick this up
+    for use_norm in (False, True):
+        for with_init in (False, True):
+            run_equivalence_once(use_norm=use_norm, with_init=with_init)
+
+
+if __name__ == "__main__":
+    try:
+        test_equivalence()
+    except Exception as e:
+        print("[FAILED]", e)
+        sys.exit(1)
+    print("[PASSED] mint implementation matches torch reference")
+    sys.exit(0)