diff --git a/mindspeed_llm/tasks/models/transformer/qwen3_next_gated_deltanet_attention.py b/mindspeed_llm/tasks/models/transformer/qwen3_next_gated_deltanet_attention.py
index 313ed00aa678bc452d4b6fcf2b7bf05fb9f9ef3d..97d1a4452d510e3a89d068486f1ff7c3595d5afd 100644
--- a/mindspeed_llm/tasks/models/transformer/qwen3_next_gated_deltanet_attention.py
+++ b/mindspeed_llm/tasks/models/transformer/qwen3_next_gated_deltanet_attention.py
@@ -74,6 +74,12 @@ def torch_causal_conv1d_update(
     return out
 
 
+def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
+    """This function is intended to align with the l2norm implementation in the FLA library."""
+    inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
+    return x * inv_norm
+
+
 def torch_chunk_gated_delta_rule(
     query,
     key,
@@ -87,8 +93,8 @@ def torch_chunk_gated_delta_rule(
 ):
     initial_dtype = query.dtype
     if use_qk_l2norm_in_kernel:
-        query = F.normalize(query, p=2, dim=-1)
-        key = F.normalize(key, p=2, dim=-1)
+        query = l2norm(query, dim=-1, eps=1e-6)
+        key = l2norm(key, dim=-1, eps=1e-6)
     query, key, value, beta, g = [
         x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
     ]
@@ -160,8 +166,8 @@ def torch_recurrent_gated_delta_rule(
 ):
     initial_dtype = query.dtype
     if use_qk_l2norm_in_kernel:
-        query = F.normalize(query, p=2, dim=-1)
-        key = F.normalize(key, p=2, dim=-1)
+        query = l2norm(query, dim=-1, eps=1e-6)
+        key = l2norm(key, dim=-1, eps=1e-6)
     query, key, value, beta, g = [
         x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
     ]
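
Note on the change (a minimal standalone sketch, not part of the patch): torch.nn.functional.normalize computes x / max(||x||, eps) with eps defaulting to 1e-12, whereas the FLA-style l2norm added above folds eps into the inverse square root, x * rsqrt(sum(x^2) + eps) with eps = 1e-6. The two formulations agree for typical activation magnitudes but diverge for near-zero vectors, so matching the FLA formulation exactly keeps this PyTorch reference path numerically consistent with the kernel it mirrors:

    import torch
    import torch.nn.functional as F

    def l2norm(x, dim=-1, eps=1e-6):
        # FLA-style: eps sits inside the rsqrt, not as a clamp on the norm.
        return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)

    x = torch.tensor([[3e-4, 4e-4]])      # near-zero vector, ||x|| = 5e-4
    print(F.normalize(x, p=2, dim=-1))    # tensor([[0.6000, 0.8000]])
    print(l2norm(x))                      # tensor([[0.2683, 0.3578]]) -- eps damps the output

    y = torch.tensor([[3.0, 4.0]])        # typical magnitude: the two agree to ~1e-7
    print((F.normalize(y, p=2, dim=-1) - l2norm(y)).abs().max())

Folding eps into the rsqrt also reduces the op to a single multiply by a fused reciprocal square root, which is how such kernels typically implement it.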