diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py
index 1490d3cf1189d1ea1a7d2e85012a1de8a9495c5f..ea92b00dd7818156ba2872b2a813ae77f55db87c 100644
--- a/vllm_mindspore/attention/layer.py
+++ b/vllm_mindspore/attention/layer.py
@@ -102,8 +102,6 @@ class Attention(nn.Cell):
         """
         output = query
         # ensure that the input tensors of reshape_and_cache is continuous
-        key = key.contiguous()
-        value = value.contiguous()
         cache_out = self.reshape_and_cache(key, value, key_cache, value_cache,
                                            slot_mapping)
         query = ops.depend(query, cache_out)
diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py
index 38118fbd232e35dc5f0fd52efe25c04f5e2620a9..01684ba12ad5d470c66acbf2b9b511504443b391 100644
--- a/vllm_mindspore/model_executor/layers/linear.py
+++ b/vllm_mindspore/model_executor/layers/linear.py
@@ -111,14 +111,20 @@ class UnquantizedLinearMethod(LinearMethodBase):
               layer: nn.Cell,
               x: Tensor,
               bias: Parameter = None) -> Tensor:
-        output_shape = x.shape[:-1] + (self.output_size_per_partition, )
-        x = x.view(-1, self.input_size_per_partition)
         x = self.matmul(x, layer.weight)
         if bias is not None:
             x = mint.add(x, bias)
-        x = x.view(output_shape)
         return x

+    def format_to_nz(self, param):
+        cast_weight = ops.auto_generate.format_cast(
+            param, FORMAT_TYPE['nz'])
+        param.set_data(cast_weight)
+
+    def process_weights_after_loading(self, layer):
+        if is_310p():
+            self.format_to_nz(layer.weight)
+

 class LinearBase(nn.Cell):
     """Base linear layer.
@@ -163,16 +169,6 @@ class LinearBase(nn.Cell):
     def construct(self, x: Tensor) -> Tensor:
         raise NotImplementedError

-    def format_to_nz(self, param, merge_count=1):
-        current_count = self.param_load_counts.get(param.name, 0) + 1
-        self.param_load_counts[param.name] = current_count
-
-        if current_count == merge_count:
-            cast_weight = ops.auto_generate.format_cast(
-                param, FORMAT_TYPE['nz'])
-            param.set_data(cast_weight)
-            del self.param_load_counts[param.name]
-

 class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.
@@ -296,8 +292,6 @@ class ColumnParallelLinear(LinearBase):

         assert param.shape == loaded_weight.shape
         param.set_data(ms.from_numpy(loaded_weight))
-        if is_310p() and param.name.endswith("weight"):
-            self.format_to_nz(param, merge_count=1)


 class MergedColumnParallelLinear(ColumnParallelLinear):
@@ -370,9 +364,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             loaded_weight = split_loaded_weight(loaded_weight, output_dim,
                                                 start_idx, shard_size)
             param.set_data(ms.from_numpy(loaded_weight))
-            if is_310p() and param.name.endswith("weight"):
-                loaded_shard_num = 1  # gating/hidden
-                self.format_to_nz(param, loaded_shard_num)
         else:
             current_shard_offset = 0
             shard_offsets = []
@@ -394,9 +385,6 @@
                 # assert loaded_weight.shape == (shard_size, param.shape[1])
                 param[shard_offset:shard_offset +
                       shard_size] = ms.from_numpy(loaded_weight)
-                if is_310p() and param.name.endswith("weight"):
-                    loaded_shard_num = 2  # gating/hidden
-                    self.format_to_nz(param, loaded_shard_num)


 class QKVParallelLinear(ColumnParallelLinear):
@@ -505,9 +493,6 @@ class QKVParallelLinear(ColumnParallelLinear):

             assert loaded_weight.shape == param.shape
             param.set_data(loaded_weight)
-            if is_310p() and param.name.endswith("weight"):
-                loaded_shard_num = 3  # q/k/v
-                self.format_to_nz(param, loaded_shard_num)
             return

         assert loaded_shard_id in ["q", "k", "v"]
@@ -539,9 +524,6 @@
         if param.name.endswith("bias"):
             assert loaded_weight.shape == (shard_size, )
             param[shard_offset:shard_offset + shard_size] = loaded_weight
-        if is_310p() and param.name.endswith("weight"):
-            loaded_shard_num = 3  # q/k/v
-            self.format_to_nz(param, loaded_shard_num)


 class RowParallelLinear(LinearBase):
@@ -676,5 +658,3 @@ class RowParallelLinear(LinearBase):

         assert param.shape == loaded_weight.shape
         param.set_data(ms.from_numpy(loaded_weight))
-        if is_310p() and param.name.endswith("weight"):
-            self.format_to_nz(param, merge_count=1)
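The linear.py hunks above drop the per-shard format_to_nz(param, merge_count=...) bookkeeping from the individual weight loaders and instead expose a single process_weights_after_loading hook on UnquantizedLinearMethod, so the 310P NZ cast runs once per layer after every shard has been written. Below is a minimal sketch, not part of this patch, of how such a hook is typically driven once loading finishes, assuming the usual quant_method attribute convention on each linear cell:

# Illustrative only: the traversal and the quant_method attribute lookup are
# assumptions for this sketch, not code introduced by the change above.
def process_all_weights_after_loading(model):
    # nn.Cell.cells_and_names() yields (name, cell) pairs for every sub-cell.
    for _, cell in model.cells_and_names():
        quant_method = getattr(cell, "quant_method", None)
        if quant_method is not None and hasattr(
                quant_method, "process_weights_after_loading"):
            quant_method.process_weights_after_loading(cell)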
diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py
index 57cf46b53221acc402bc12022c21d023a5651aa5..813ec1a5c2ee8dc475372ad10b2596c93c49d8fc 100644
--- a/vllm_mindspore/model_executor/layers/rotary_embedding.py
+++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py
@@ -307,16 +307,19 @@ class MRotaryEmbedding(RotaryEmbedding):
         ######################################################################
         num_tokens = positions.shape[-1]
         cos_sin = self.cos_sin_cache[positions]
-        cos, sin = ops.chunk(cos_sin, 2, axis=-1)
         if positions.ndim == 2:
-            cos_l = mint.split(cos, self.mrope_section, dim=-1)
-            sin_l = mint.split(sin, self.mrope_section, dim=-1)
-            cos, sin = (), ()
-            for i in range(len(self.mrope_section)):  # type: ignore[arg-type]
-                cos += (cos_l[i][i], )
-                sin += (sin_l[i][i], )
-            cos = mint.cat(cos, dim=-1)
-            sin = mint.cat(sin, dim=-1)
+            cos_sin_total = ()
+            cos_sin = cos_sin.reshape(len(self.mrope_section), -1,
+                                      2, self.rotary_dim // 2)
+            cos_sin_l = mint.split(cos_sin, self.mrope_section, dim = -1)
+            for i in range(len(self.mrope_section)):
+                cos_sin_total += (cos_sin_l[i][i],)
+            cos_sin_total = mint.cat(cos_sin_total, dim = -1)
+            cos, sin = ops.chunk(cos_sin_total, 2, axis=-2)
+            cos = cos.squeeze(-2)
+            sin = sin.squeeze(-2)
+        else:
+            cos, sin = ops.chunk(cos_sin, 2, axis=-1)

         query_shape = query.shape
         query = query.view(num_tokens, -1, self.head_size)
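The rewritten MRotaryEmbedding hunk above gathers the mrope sections from the interleaved [cos | sin] cache layout in a single reshape/split/gather pass and only separates cos from sin at the end, instead of chunking first and gathering twice. The following self-contained NumPy sketch shows that gather; mrope_section, rotary_dim and num_tokens are made-up toy values, and the old-layout computation is kept only as a reference for comparison:

import numpy as np

mrope_section = [2, 1, 1]            # per-axis widths (t/h/w), toy values
rotary_dim = 2 * sum(mrope_section)  # 8 in this toy example
num_tokens = 5

# Stand-in for cos_sin_cache[positions] with positions of shape (3, num_tokens):
# each row stores the cos half and the sin half concatenated on the last axis.
cos_sin = np.random.rand(3, num_tokens, rotary_dim).astype(np.float32)
bounds = np.cumsum(mrope_section)[:-1]

# Reference (old layout): chunk into cos/sin first, then split each by
# mrope_section and take section i from positions row i.
cos_half, sin_half = np.split(cos_sin, 2, axis=-1)
cos_parts = np.split(cos_half, bounds, axis=-1)
sin_parts = np.split(sin_half, bounds, axis=-1)
cos_ref = np.concatenate([cos_parts[i][i] for i in range(len(mrope_section))],
                         axis=-1)
sin_ref = np.concatenate([sin_parts[i][i] for i in range(len(mrope_section))],
                         axis=-1)

# New layout: one reshape to (..., 2, rotary_dim // 2), a single split/gather,
# and only then separate cos from sin.
reshaped = cos_sin.reshape(len(mrope_section), -1, 2, rotary_dim // 2)
parts = np.split(reshaped, bounds, axis=-1)
total = np.concatenate([parts[i][i] for i in range(len(mrope_section))],
                       axis=-1)
cos_new, sin_new = total[:, 0, :], total[:, 1, :]

assert np.allclose(cos_ref, cos_new) and np.allclose(sin_ref, sin_new)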
diff --git a/vllm_mindspore/model_executor/models/glm4_1v.py b/vllm_mindspore/model_executor/models/glm4_1v.py
index 279381d92eb1711da15e3fa55a7119867abcc9ba..eaedb8da34ce4297e0308dff48b4bb92214bb537 100644
--- a/vllm_mindspore/model_executor/models/glm4_1v.py
+++ b/vllm_mindspore/model_executor/models/glm4_1v.py
@@ -1201,10 +1201,12 @@ class Glm4vForConditionalGeneration(NativeModel, SupportsMultiModal):
         # self.language_model = MindFormersForCausalLM(vllm_config=vllm_config,
         #     prefix=maybe_prefix(prefix, "language_model"))
         self.model = self.wrap_model.model
-        self.model.embed_tokens.construct = ms.jit(
-            function=self.model.embed_tokens, jit_level='O0')
+        self.model.embed_tokens._set_jit_graph_name("prefill")
+        self.model.embed_tokens.phase = "prefill"
         dyn_input_ids = ms.Tensor(shape=[None], dtype=ms.int32)
         self.model.embed_tokens.set_inputs(dyn_input_ids)
+        self.model.embed_tokens.construct = ms.jit(
+            function=self.model.embed_tokens, jit_level='O0')
         self.lm_head = self.wrap_model.lm_head
         self.common_preprocess(vllm_config, prefix)

diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py
index 64973e5b9dc0d41903cff85b216680afcc76287c..0f56be044cbd67ea79cf1ec0d8a2db05320168da 100644
--- a/vllm_mindspore/v1/worker/gpu_model_runner.py
+++ b/vllm_mindspore/v1/worker/gpu_model_runner.py
@@ -147,8 +147,7 @@ def _prepare_inputs(
     self.seq_lens[num_reqs:].fill_(0)
     # Note: pad query_start_loc to be non-decreasing, as kernels
     # like FlashAttention requires that
-    self.query_start_loc[num_reqs + 1:].fill_(
-        self.query_start_loc_cpu[num_reqs].item())
+    self.query_start_loc[num_reqs + 1:] = self.query_start_loc_cpu[num_reqs]
     # vllm-mindspore begin
     query_start_loc = ms.from_numpy(
         self.query_start_loc_np[:num_reqs + 1])
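On the last hunk: the surrounding comment explains that query_start_loc must stay non-decreasing past the active requests, so the tail is filled with the last valid prefix sum; the change only swaps the fill_(... .item()) call for a direct slice assignment. A toy NumPy illustration of that padding, with made-up per-request token counts:

import numpy as np

max_reqs, num_reqs = 6, 3
num_tokens_per_req = [4, 2, 7]  # made-up values
query_start_loc = np.zeros(max_reqs + 1, dtype=np.int32)
query_start_loc[1:num_reqs + 1] = np.cumsum(num_tokens_per_req)
# Pad the tail with the last valid prefix sum so the array stays
# non-decreasing, which kernels such as FlashAttention expect.
query_start_loc[num_reqs + 1:] = query_start_loc[num_reqs]
print(query_start_loc)  # [ 0  4  6 13 13 13 13]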