From 1276ffa1e236fb5b93ebcd6892f4ab2fde7e914f Mon Sep 17 00:00:00 2001
From: guozhihua
Date: Fri, 26 Sep 2025 11:20:47 +0800
Subject: [PATCH] [pytorch][model] change l2norm in qwen3_next for hf

---
 .../qwen3_next_gated_deltanet_attention.py    | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/mindspeed_llm/tasks/models/transformer/qwen3_next_gated_deltanet_attention.py b/mindspeed_llm/tasks/models/transformer/qwen3_next_gated_deltanet_attention.py
index 313ed00aa..97d1a4452 100644
--- a/mindspeed_llm/tasks/models/transformer/qwen3_next_gated_deltanet_attention.py
+++ b/mindspeed_llm/tasks/models/transformer/qwen3_next_gated_deltanet_attention.py
@@ -74,6 +74,12 @@ def torch_causal_conv1d_update(
     return out
 
 
+def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
+    """This function is intended to align with the l2norm implementation in the FLA library."""
+    inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
+    return x * inv_norm
+
+
 def torch_chunk_gated_delta_rule(
     query,
     key,
@@ -87,8 +93,8 @@ def torch_chunk_gated_delta_rule(
 ):
     initial_dtype = query.dtype
     if use_qk_l2norm_in_kernel:
-        query = F.normalize(query, p=2, dim=-1)
-        key = F.normalize(key, p=2, dim=-1)
+        query = l2norm(query, dim=-1, eps=1e-6)
+        key = l2norm(key, dim=-1, eps=1e-6)
     query, key, value, beta, g = [
         x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
     ]
@@ -160,8 +166,8 @@ def torch_recurrent_gated_delta_rule(
 ):
     initial_dtype = query.dtype
     if use_qk_l2norm_in_kernel:
-        query = F.normalize(query, p=2, dim=-1)
-        key = F.normalize(key, p=2, dim=-1)
+        query = l2norm(query, dim=-1, eps=1e-6)
+        key = l2norm(key, dim=-1, eps=1e-6)
     query, key, value, beta, g = [
         x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
     ]
-- 
Gitee
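
Note (reviewer addition, not part of the patch): the replacement changes the epsilon semantics, not just the call site. torch.nn.functional.normalize computes x / max(||x||_2, eps) with a default eps of 1e-12, while the FLA-style l2norm added above computes x * rsqrt(sum(x^2) + eps) with eps = 1e-6, so the two outputs differ slightly on all inputs and most visibly on near-zero vectors. A minimal standalone sketch to observe the gap follows; it re-declares l2norm for self-containment rather than importing it from the patched module.

import torch
import torch.nn.functional as F

def l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6):
    # FLA-style: eps is added to the squared sum inside the rsqrt,
    # matching the helper introduced by the patch above.
    inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
    return x * inv_norm

x = torch.randn(2, 4, 8, dtype=torch.float32)
# F.normalize clamps the norm from below (x / max(||x||, 1e-12)) instead of
# adding eps to the squared sum, so the two disagree by roughly eps-level
# amounts on typical inputs.
print((F.normalize(x, p=2, dim=-1) - l2norm(x)).abs().max())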