diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py index 0465a1e006e821f06c4ae5d3e6dc7e3181d100af..674f30ed476fb60c32c8a765ac82dc92b566ce44 100644 --- a/vllm_mindspore/model_executor/layers/rotary_embedding.py +++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py @@ -356,7 +356,12 @@ class MRotaryEmbedding(RotaryEmbedding): interleaved [THTHWHTHW...TT], preserving frequency continuity. """ x = ops.transpose(x, (1, 0, 2)) - x = mint.flatten(x, start_dim=1) + # In the IR, mint.flatten becomes ShapeCalc + Reshape, and + # ShapeCalc is a CPU op, so we read the shape and call reshape + # explicitly so that aclgraph can be enabled. + # See https://gitee.com/mindspore/mindspore/issues/IDBWDX for details. + t, _, _ = x.shape + x = ops.reshape(x, (t, -1)) x_t = mint.index_select(x, -1, self.rope_select_index) return x_t @@ -367,7 +372,12 @@ class MRotaryEmbedding(RotaryEmbedding): non-interleaved [TTTHHHWWW]. """ x = ops.transpose(x, (1, 0, 2)) - x = mint.flatten(x, start_dim=1) + # In the IR, mint.flatten becomes ShapeCalc + Reshape, and + # ShapeCalc is a CPU op, so we read the shape and call reshape + # explicitly so that aclgraph can be enabled. + # See https://gitee.com/mindspore/mindspore/issues/IDBWDX for details. + t, _, _ = x.shape + x = ops.reshape(x, (t, -1)) x_t = mint.index_select(x, -1, self.rope_select_index) return x_t