From 9e3500aacd8fadae0f1e1af39443cbf836137991 Mon Sep 17 00:00:00 2001
From: wangzw1022
Date: Tue, 16 Jul 2024 15:58:20 +0800
Subject: [PATCH 1/3] Adapt to the new fused ROPE operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mindspeed/core/fusions/rotary_pos_embedding.py               | 4 ++--
 .../model_tests/perf_model/llama2/pretrain_llama2_70b_32k.sh | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mindspeed/core/fusions/rotary_pos_embedding.py b/mindspeed/core/fusions/rotary_pos_embedding.py
index e3d4a7ca..c3c3ab27 100644
--- a/mindspeed/core/fusions/rotary_pos_embedding.py
+++ b/mindspeed/core/fusions/rotary_pos_embedding.py
@@ -3,7 +3,7 @@ import torch_npu
 from torch import Tensor
 from functools import wraps
 from megatron.training import get_args
-
+from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding
 
 def apply_fused_rotary_pos_emb_bshd_wrapper(fn):
     @wraps(fn)
@@ -14,7 +14,7 @@ def apply_fused_rotary_pos_emb_bshd_wrapper(fn):
             t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
             cos_ = torch.cos(freqs).to(t.dtype)
             sin_ = torch.sin(freqs).to(t.dtype)
-            t = torch_npu.npu_rotary_mul(t, cos_, sin_).to(t.dtype)
+            t = npu_rotary_position_embedding(t.contiguous(), cos_, sin_).to(t.dtype)
             return torch.cat((t, t_pass), dim=-1)
         return fn(t, freqs, rotary_interleaved)
 
diff --git a/tests_extend/model_tests/perf_model/llama2/pretrain_llama2_70b_32k.sh b/tests_extend/model_tests/perf_model/llama2/pretrain_llama2_70b_32k.sh
index 9c03ca89..515cddbc 100644
--- a/tests_extend/model_tests/perf_model/llama2/pretrain_llama2_70b_32k.sh
+++ b/tests_extend/model_tests/perf_model/llama2/pretrain_llama2_70b_32k.sh
@@ -34,6 +34,7 @@ GPT_ARGS="
     --context-parallel-algo megatron_cp_algo \
     --use-ascend-mc2 \
     --reuse-fp32-param \
+    --recompute-activation-function \
     --use-fused-rotary-pos-emb \
     --use-fused-swiglu \
     --use-fused-rmsnorm \
-- Gitee

From 254d2549e057ffa6e28e78914e5c3da78094ab5a Mon Sep 17 00:00:00 2001
From: wangzw1022
Date: Tue, 16 Jul 2024 16:31:01 +0800
Subject: [PATCH 2/3] Adapt to the new fused ROPE operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mindspeed/core/fusions/rotary_pos_embedding.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mindspeed/core/fusions/rotary_pos_embedding.py b/mindspeed/core/fusions/rotary_pos_embedding.py
index c3c3ab27..e165260e 100644
--- a/mindspeed/core/fusions/rotary_pos_embedding.py
+++ b/mindspeed/core/fusions/rotary_pos_embedding.py
@@ -5,6 +5,7 @@ from functools import wraps
 from megatron.training import get_args
 from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding
+
 
 def apply_fused_rotary_pos_emb_bshd_wrapper(fn):
     @wraps(fn)
     def wrapper(t: Tensor, freqs: Tensor, rotary_interleaved: bool = False) -> Tensor:
-- Gitee

From 1bd7da77829e539a93ece8631b75e4730061806f Mon Sep 17 00:00:00 2001
From: wangzw1022
Date: Tue, 16 Jul 2024 20:42:44 +0800
Subject: [PATCH 3/3] Adapt to the new fused ROPE operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../unit_tests/ops/cann/test_npu_rotary_pos_embedding.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests_extend/unit_tests/ops/cann/test_npu_rotary_pos_embedding.py b/tests_extend/unit_tests/ops/cann/test_npu_rotary_pos_embedding.py
index cd8a9453..6cb30dad 100644
--- a/tests_extend/unit_tests/ops/cann/test_npu_rotary_pos_embedding.py
+++ b/tests_extend/unit_tests/ops/cann/test_npu_rotary_pos_embedding.py
@@ -21,7 +21,8 @@ def create_test_args(use_fused_rotary_pos_emb=False):
 class TestNpuRotaryEmbedding(DistributedTest):
     world_size = 1
 
-    @pytest.mark.skipif(DEVICE_NAME != 'Ascend910B', reason='device type is not supported, skip this UT!')
+    # @pytest.mark.skipif(DEVICE_NAME != 'Ascend910B', reason='device type is not supported, skip this UT!')
+    @pytest.mark.skip(reason="Not found RotaryPositionEmbedding operator in the environment.")
     @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
     def test_npu_rotary_pos_embedding(self, dtype):
         t_ori = torch.rand(2, 2, 5, 128).npu().to(dtype)
-- Gitee
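
Note (not part of the patch series): below is a minimal sketch of how the patched
apply_fused_rotary_pos_emb_bshd_wrapper in mindspeed/core/fusions/rotary_pos_embedding.py
reads once PATCH 1/3 and PATCH 2/3 are applied. Only the lines visible in the hunks above
come from the diffs; the use_fused_rotary_pos_emb guard and the rot_dim computation are
assumptions added for illustration and may differ from the actual file.

from functools import wraps

import torch
from torch import Tensor

from megatron.training import get_args
from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding


def apply_fused_rotary_pos_emb_bshd_wrapper(fn):
    @wraps(fn)
    def wrapper(t: Tensor, freqs: Tensor, rotary_interleaved: bool = False) -> Tensor:
        args = get_args()
        # Assumed guard (not visible in the hunks): only take the fused path when
        # --use-fused-rotary-pos-emb is enabled.
        if args.use_fused_rotary_pos_emb:
            # Assumed: rotate only the leading rot_dim channels; the rest pass through.
            rot_dim = freqs.shape[-1]
            t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
            cos_ = torch.cos(freqs).to(t.dtype)
            sin_ = torch.sin(freqs).to(t.dtype)
            # Fused call introduced by PATCH 1/3; note the added .contiguous() on the
            # input, replacing the former torch_npu.npu_rotary_mul(t, cos_, sin_).
            t = npu_rotary_position_embedding(t.contiguous(), cos_, sin_).to(t.dtype)
            return torch.cat((t, t_pass), dim=-1)
        # Otherwise fall back to the original, unfused Megatron implementation.
        return fn(t, freqs, rotary_interleaved)

    return wrapper

For reference, the unit test touched by PATCH 3/3 drives this path with a (2, 2, 5, 128)
input in float16 and bfloat16 on the NPU, but it remains skipped until the
RotaryPositionEmbedding operator is available in the test environment.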