From 4a2c1acbf03e120859929a524ae9dbaae81d3f3e Mon Sep 17 00:00:00 2001 From: z00804007 Date: Fri, 18 Jul 2025 18:29:39 +0800 Subject: [PATCH] optimize cp --- mindspeed_rl/config_cls/megatron_config.py | 6 ++ .../models/base/base_training_engine.py | 57 ++++++++++------- mindspeed_rl/models/loss/base_loss_func.py | 64 +++++++++---------- .../models/loss/dapo_actor_loss_func.py | 4 +- .../models/loss/grpo_actor_loss_func.py | 4 +- .../models/loss/ppo_actor_loss_func.py | 5 +- .../models/loss/reference_loss_func.py | 2 +- mindspeed_rl/models/rollout/vllm_engine.py | 1 + mindspeed_rl/utils/context_parallel.py | 32 +++++----- mindspeed_rl/utils/pad_process.py | 35 ++++++++++ mindspeed_rl/utils/remove_padding.py | 27 ++++++-- 11 files changed, 154 insertions(+), 83 deletions(-) diff --git a/mindspeed_rl/config_cls/megatron_config.py b/mindspeed_rl/config_cls/megatron_config.py index e53b1ab1..50e6670f 100644 --- a/mindspeed_rl/config_cls/megatron_config.py +++ b/mindspeed_rl/config_cls/megatron_config.py @@ -204,6 +204,8 @@ class MegatronConfig(BaseConfig): noop_layers: noop layers string cp_attention_mask_type: attention mask type in cp reset_attention_mask: reset attention_mask in cp + use_cp_send_recv_overlap: To support send receive overlap in cp, suggest to be true + use_fused_ring_attention_update: use fused_ring_attention_update in cp, suggest to be true use_ascend_coc: switch to open CoC feature (default: False) coc_mode: 0=original, 1=rewrite, 2=coc default coc_parallel_num: number of parallel in CoC features (default: 1) @@ -376,6 +378,8 @@ class MegatronConfig(BaseConfig): self.noop_layers = None self.cp_attention_mask_type = 'causal' self.reset_attention_mask = False + self.use_cp_send_recv_overlap = False + self.use_fused_ring_attention_update = False self.dpo_loss_type = 'sigmoid' self.use_ascend_coc = False @@ -391,3 +395,5 @@ class MegatronConfig(BaseConfig): self.gemm_gradient_accumulation_fusion = False self.update(training_config, model_config) + + self.pad_to_multiple_of = self.tensor_model_parallel_size * self.context_parallel_size diff --git a/mindspeed_rl/models/base/base_training_engine.py b/mindspeed_rl/models/base/base_training_engine.py index 044fd36c..a964fac3 100644 --- a/mindspeed_rl/models/base/base_training_engine.py +++ b/mindspeed_rl/models/base/base_training_engine.py @@ -16,7 +16,7 @@ from mindspeed_rl.utils.utils import ( from mindspeed_rl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx from mindspeed_rl.utils.remove_padding import preprocess_packed_seqs, postprocess_packed_seqs from mindspeed_rl.utils.compute import get_parallel_state -from mindspeed_rl.utils.context_parallel import get_batch_on_this_cp_rank, get_ring_degree, get_output_allgather_cp_with_pack, get_output_allgather_cp_without_pack +from mindspeed_rl.utils.context_parallel import get_batch_on_this_cp_rank, get_ring_degree from mindspeed_rl.utils.utils import is_multimodal @@ -177,11 +177,6 @@ class BaseTrainingEngine(ABC): elif self.use_remove_padding: input_ids, position_ids, process_batch, seqlens_in_batch, cu_seqlens_padded, index = self._get_forward_batch_info(batch_iter) output = model(input_ids=input_ids, attention_mask=None, position_ids=position_ids) - if post_process: - if cp_size > 1: - output = get_output_allgather_cp_with_pack(output, cp_size, index) - if self.megatron_config.cp_attention_mask_type == 'causal': - cu_seqlens_padded *= get_ring_degree(self.megatron_config) output.div_(self.temperature) return output, partial(self.loss_func.compute_loss, batch=process_batch, @@ -189,21 +184,21 @@ class BaseTrainingEngine(ABC): use_remove_padding=self.use_remove_padding, seqlens_in_batch=seqlens_in_batch, cu_seqlens_padded=cu_seqlens_padded, + seq_len=seq_len, use_dynamic_bsz=self.use_dynamic_bsz, - actual_micro_batch_size=batch_size / n_micro_batch) + actual_micro_batch_size=batch_size / n_micro_batch, + index=index) else: input_ids, position_ids, process_batch, index = self._get_forward_batch_info(batch_iter) output = model(input_ids=input_ids, attention_mask=None, position_ids=position_ids) - if post_process: - if cp_size > 1: - output = get_output_allgather_cp_without_pack(output, cp_size, index) output.div_(self.temperature) return output, partial(self.loss_func.compute_loss, batch=process_batch, forward_only=forward_only, use_dynamic_bsz=self.use_dynamic_bsz, - actual_micro_batch_size=batch_size / n_micro_batch) + actual_micro_batch_size=batch_size / n_micro_batch, + index=index) # batch should be a list of batches inside micro-batches losses_reduced = self.forward_backward_func( @@ -233,21 +228,33 @@ class BaseTrainingEngine(ABC): """ return {} - def _get_batch_data_with_cp(self, input_ids, position_ids): + def _get_batch_data_with_cp(self, batch, input_ids, position_ids, labels): batch_for_cp = { 'input_ids': input_ids, - 'position_ids': position_ids + 'position_ids': position_ids, + 'labels': labels } batch_cp, index = get_batch_on_this_cp_rank(self.megatron_config, batch_for_cp, self.get_actual_seq_len()) - input_ids = batch_cp['input_ids'] - position_ids = batch_cp['position_ids'] - self.set_position_ids(position_ids.transpose(0, 1).contiguous()) - return input_ids, position_ids, index + batch['input_ids'] = batch_cp['input_ids'] + batch['position_ids'] = batch_cp['position_ids'] + batch['labels'] = batch_cp['labels'] + self.set_position_ids(batch['position_ids'].transpose(0, 1).contiguous()) + + return batch['input_ids'], batch['position_ids'], batch, index def _get_forward_batch_info(self, batch_iter): batch = next(batch_iter) input_ids = batch['input_ids'] + + # generate a labels tensor based on input_id. Remove the first token along the sequence dimension, and append a token with value 0 at the end. + # This is done to align the data and enable subsequent log probability (logP) calculation + labels = batch['input_ids'] + labels = labels[:, 1:] + tmp_add = torch.zeros(labels.size(0), 1, dtype=labels.dtype, device=labels.device) + labels = torch.cat((labels, tmp_add), dim=1) + batch['labels'] = labels + attention_mask_1d = generate_mask(input_ids, batch['prompt_length'] + batch['response_length']).to( input_ids.device) cp_size = get_parallel_state().get_context_parallel_world_size() @@ -259,15 +266,17 @@ class BaseTrainingEngine(ABC): else: multi = tp_size * cp_size - input_ids, position_ids, seqlens_in_batch, cu_seqlens_padded = preprocess_packed_seqs( - input_ids=input_ids, attention_mask_1d=attention_mask_1d, tp_size=multi) - + input_ids, position_ids, labels, seqlens_in_batch, cu_seqlens_padded = preprocess_packed_seqs( + input_ids=input_ids, labels=labels, attention_mask_1d=attention_mask_1d, tp_size=multi) + batch['labels'] = labels + + cu_seqlens_padded_ring = cu_seqlens_padded if self.megatron_config.cp_attention_mask_type == 'causal': - cu_seqlens_padded /= get_ring_degree(self.megatron_config) - self.set_actual_seq_len(cu_seqlens_padded.tolist()) + cu_seqlens_padded_ring = (cu_seqlens_padded / get_ring_degree(self.megatron_config)).to(torch.int) + self.set_actual_seq_len(cu_seqlens_padded_ring.tolist()) if cp_size > 1: - input_ids, position_ids, index = self._get_batch_data_with_cp(input_ids, position_ids) + input_ids, position_ids, batch, index = self._get_batch_data_with_cp(batch, input_ids, position_ids, labels) return input_ids, position_ids, batch, seqlens_in_batch, cu_seqlens_padded, index @@ -275,7 +284,7 @@ class BaseTrainingEngine(ABC): position_ids = torch.tensor(generate_position_ids(input_ids)).to(input_ids.device) if cp_size > 1: - input_ids, position_ids, index = self._get_batch_data_with_cp(input_ids, position_ids) + input_ids, position_ids, batch, index = self._get_batch_data_with_cp(batch, input_ids, position_ids, labels) return input_ids, position_ids, batch, index diff --git a/mindspeed_rl/models/loss/base_loss_func.py b/mindspeed_rl/models/loss/base_loss_func.py index 77df2821..2464a488 100644 --- a/mindspeed_rl/models/loss/base_loss_func.py +++ b/mindspeed_rl/models/loss/base_loss_func.py @@ -6,7 +6,10 @@ from typing import Dict, Tuple import torch from mindspeed_rl.utils.compute import compute_log_probs, vocab_parallel_entropy -from mindspeed_rl.utils.pad_process import truncate_middle_and_pad +from mindspeed_rl.utils.pad_process import truncate_prompt_and_pad +from mindspeed_rl.utils.context_parallel import get_tensor_allgather_cp_without_pack, get_tensor_allgather_cp_with_pack +from mindspeed_rl.utils.compute import get_parallel_state +from mindspeed_rl.utils.remove_padding import postprocess_packed_seqs class BaseLossFunc(ABC): @@ -36,44 +39,41 @@ class BaseLossFunc(ABC): pass @staticmethod - def _get_compute_log_probs_input(output: torch.Tensor, batch: Dict[str, torch.Tensor]): - if 'responses' not in batch: - raise ValueError("The responses is None") + def _get_log_probs_remove_prompt_pad(logprob: torch.Tensor, batch: Dict[str, torch.Tensor]): responses = batch['responses'] truncate_lengths = torch.cat([batch['prompt_length'], batch['prompt_length'] + batch['response_length']], dim=1) - 1 - logits = truncate_middle_and_pad(responses, output, truncate_lengths) - return responses, logits + logprob = truncate_prompt_and_pad(responses, logprob, truncate_lengths) + return logprob - def compute_log_probs(self, output, batch: Dict[str, torch.Tensor], update=False, **kwargs): + def compute_log_probs(self, output, batch: Dict[str, torch.Tensor], skip_entropy=True, **kwargs): use_remove_padding = kwargs.get('use_remove_padding', False) + index = kwargs.get('index', None) + labels = batch['labels'] + + log_probs = compute_log_probs(output, labels) + + cp_size = get_parallel_state().get_context_parallel_world_size() + if use_remove_padding: + log_probs_allgather = get_tensor_allgather_cp_with_pack(log_probs, cp_size, index) seqlens_in_batch = kwargs.get('seqlens_in_batch', None) cu_seqlens_padded = kwargs.get('cu_seqlens_padded', None) - batch_size = seqlens_in_batch.shape[0] - log_probs_list = [] - entropy_list = [] - for i in range(batch_size): - start = cu_seqlens_padded[i].item() - length = seqlens_in_batch[i].item() - single_output = output[0, start:start + length].unsqueeze(0) # [1, length, vocab_size] - single_batch = {key: value[i].unsqueeze(0) for key, value in batch.items()} - response, logits = self._get_compute_log_probs_input(single_output, single_batch) - single_log_probs = compute_log_probs(logits, response) - log_probs_list.append(single_log_probs) - if update: - single_entropy = vocab_parallel_entropy(logits) - entropy_list.append(single_entropy) - log_probs = torch.cat(log_probs_list, dim=0) - if update: - entropy = torch.cat(entropy_list, dim=0) - return log_probs, entropy + seq_len = batch['responses'].shape[-1] + log_probs = postprocess_packed_seqs(log_probs_allgather, seqlens_in_batch, cu_seqlens_padded, seq_len, prompt_length=batch['prompt_length']) + if not skip_entropy: + entropy = vocab_parallel_entropy(output) + entropy = postprocess_packed_seqs(entropy, seqlens_in_batch, cu_seqlens_padded, seq_len, prompt_length=batch['prompt_length']) else: - return log_probs + entropy = torch.zeros_like(log_probs) + + return log_probs, entropy + else: - responses, logits = self._get_compute_log_probs_input(output, batch) - log_probs = compute_log_probs(logits, responses) - if update: - entropy = vocab_parallel_entropy(logits) - return log_probs, entropy + log_probs_allgather = get_tensor_allgather_cp_without_pack(log_probs, cp_size, index) + log_probs = self._get_log_probs_remove_prompt_pad(log_probs_allgather, batch) + if not skip_entropy: + entropy = vocab_parallel_entropy(output) else: - return log_probs \ No newline at end of file + entropy = torch.zeros_like(log_probs) + + return log_probs, entropy \ No newline at end of file diff --git a/mindspeed_rl/models/loss/dapo_actor_loss_func.py b/mindspeed_rl/models/loss/dapo_actor_loss_func.py index 0f5fffeb..948ddae5 100644 --- a/mindspeed_rl/models/loss/dapo_actor_loss_func.py +++ b/mindspeed_rl/models/loss/dapo_actor_loss_func.py @@ -58,9 +58,9 @@ class DAPOActorLossFunc(BaseLossFunc): """ # compute log probs if forward_only: - log_probs = super().compute_log_probs(output=output, batch=batch, **kwargs) + log_probs, _ = super().compute_log_probs(output=output, batch=batch, **kwargs) return log_probs - log_probs, entropy = super().compute_log_probs(output=output, batch=batch, update=True, **kwargs) + log_probs, entropy = super().compute_log_probs(output=output, batch=batch, skip_entropy=(self.entropy_coeff == 0), **kwargs) response_mask, old_log_prob, advantages = self._get_policy_loss_input(batch=batch) # compute policy loss diff --git a/mindspeed_rl/models/loss/grpo_actor_loss_func.py b/mindspeed_rl/models/loss/grpo_actor_loss_func.py index 171a952c..fa0f4a58 100644 --- a/mindspeed_rl/models/loss/grpo_actor_loss_func.py +++ b/mindspeed_rl/models/loss/grpo_actor_loss_func.py @@ -54,9 +54,9 @@ class GRPOActorLossFunc(BaseLossFunc): """ # compute log probs if forward_only: - log_probs = super().compute_log_probs(output=output, batch=batch, **kwargs) + log_probs, _ = super().compute_log_probs(output=output, batch=batch, **kwargs) return log_probs - log_probs, entropy = super().compute_log_probs(output=output, batch=batch, update=True, **kwargs) + log_probs, entropy = super().compute_log_probs(output=output, batch=batch, skip_entropy=(self.entropy_coeff == 0), **kwargs) response_mask, old_log_prob, advantages, ref_log_prob = self._get_policy_loss_input(batch=batch) # compute policy loss diff --git a/mindspeed_rl/models/loss/ppo_actor_loss_func.py b/mindspeed_rl/models/loss/ppo_actor_loss_func.py index 54205658..db8bb61a 100644 --- a/mindspeed_rl/models/loss/ppo_actor_loss_func.py +++ b/mindspeed_rl/models/loss/ppo_actor_loss_func.py @@ -51,8 +51,9 @@ class PPOActorLossFunc(BaseLossFunc): """ # compute log probs if forward_only: - return super().compute_log_probs(output=output, batch=batch, **kwargs) - log_probs, entropy = super().compute_log_probs(output=output, batch=batch, update=True, **kwargs) + log_probs, _ = super().compute_log_probs(output=output, batch=batch, **kwargs) + return log_probs + log_probs, entropy = super().compute_log_probs(output=output, batch=batch, skip_entropy=(self.entropy_coeff == 0), **kwargs) response_mask, old_log_prob, advantages, ref_log_prob = self._get_policy_loss_input(batch=batch) diff --git a/mindspeed_rl/models/loss/reference_loss_func.py b/mindspeed_rl/models/loss/reference_loss_func.py index d5455730..c4332c1b 100644 --- a/mindspeed_rl/models/loss/reference_loss_func.py +++ b/mindspeed_rl/models/loss/reference_loss_func.py @@ -19,7 +19,7 @@ class ReferenceLossFunc(BaseLossFunc): non_loss_data=True, **kwargs) -> Tuple[torch.Tensor, Dict]: # compute log probs - log_probs = super().compute_log_probs(output=output, batch=batch, **kwargs) + log_probs, _ = super().compute_log_probs(output=output, batch=batch, **kwargs) if forward_only: return log_probs return None diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py index 73fe1af4..d3bfbb15 100644 --- a/mindspeed_rl/models/rollout/vllm_engine.py +++ b/mindspeed_rl/models/rollout/vllm_engine.py @@ -325,6 +325,7 @@ class VLLMInferEngine(BaseInferEngine): self.free_cache_engine() return outs + @torch.no_grad() def async_generate_sequences(self, idx_list, indexes, n_samples_per_prompt=None, **kwargs): with self.update_sampling_params(**kwargs): diff --git a/mindspeed_rl/utils/context_parallel.py b/mindspeed_rl/utils/context_parallel.py index f9f29aea..833c11c0 100644 --- a/mindspeed_rl/utils/context_parallel.py +++ b/mindspeed_rl/utils/context_parallel.py @@ -116,36 +116,36 @@ def get_ring_degree(megatron_config): return ring_degree -def allgather_output_cp_group(output_orig, cp_size): - output_list = [torch.empty_like(output_orig) for _ in range(cp_size)] - torch.distributed.all_gather(output_list, output_orig.detach(), group=get_parallel_state().get_context_parallel_group()) - output_list[get_parallel_state().get_context_parallel_rank()] = output_orig +def allgather_tensor_cp_group(cp_tensor, cp_size): + output_list = [torch.empty_like(cp_tensor) for _ in range(cp_size)] + torch.distributed.all_gather(output_list, cp_tensor.detach(), group=get_parallel_state().get_context_parallel_group()) + output_list[get_parallel_state().get_context_parallel_rank()] = cp_tensor output_all_cp = torch.cat(output_list, dim=1) return output_all_cp -def get_output_allgather_cp_with_pack(output_orig, cp_size, index): - # output allgather - output_all_cp = allgather_output_cp_group(output_orig, cp_size) +def get_tensor_allgather_cp_with_pack(cp_tensor, cp_size, index): + # cp_tensor allgather + output_all_cp = allgather_tensor_cp_group(cp_tensor, cp_size) # when use ring cp, the index is not none. Need to restore the output order based on the index. if index is not None: - # Step1 index allgather + # index allgather index_list = [torch.empty_like(index) for _ in range(cp_size)] torch.distributed.all_gather(index_list, index, group=get_parallel_state().get_context_parallel_group()) - index_all_cp = torch.cat(index_list, dim=0) - index_expand = index_all_cp.view(1, output_all_cp.shape[1], 1) - # Step2 use scatter to restore the output order based on the index - output_order_restored = torch.zeros_like(output_all_cp) - output_order_restored.scatter_(1, index_expand.expand(-1, -1, output_all_cp.shape[2]), output_all_cp) + index_all_cp = torch.cat(index_list, dim=0).cpu().numpy() + index_all_cp_argsort = np.argsort(index_all_cp) + + output_order_restored = output_all_cp[:, index_all_cp_argsort] output = output_order_restored else: output = output_all_cp + return output -def get_output_allgather_cp_without_pack(output_orig, cp_size, index): - # output allgather - output_all_cp = allgather_output_cp_group(output_orig, cp_size) +def get_tensor_allgather_cp_without_pack(cp_tensor, cp_size, index): + # cp_tensor allgather + output_all_cp = allgather_tensor_cp_group(cp_tensor, cp_size) if index is not None: # Step1 get index and argsort it for select index_list = [] diff --git a/mindspeed_rl/utils/pad_process.py b/mindspeed_rl/utils/pad_process.py index 33115231..645650e2 100644 --- a/mindspeed_rl/utils/pad_process.py +++ b/mindspeed_rl/utils/pad_process.py @@ -106,6 +106,41 @@ def truncate_middle_and_pad(responses, input_tensor, truncate_lengths, pad_value return output_tensor +def truncate_prompt_and_pad(responses, input_tensor, truncate_lengths, pad_value=0.0): + """ + input_tensor: Tensor of shape (mbs, seq_len) + truncate_lengths: Tensor of shape (mbs, 2), where truncate_lengths[i, 0] is the start index to keep, + and truncate_lengths[i, 1] is the end index to keep (exclusive). + pad_value: Value to use for padding (default is 0.0) + """ + + mbs, seq_len = input_tensor.shape + + # Ensure truncate_lengths is within valid range + truncate_lengths = torch.clamp(truncate_lengths, 0, seq_len) + + # Calculate the new lengths after truncation + new_lengths = truncate_lengths[:, 1] - truncate_lengths[:, 0] # (mbs,) + + # Find the maximum length after truncation + max_new_len = responses.shape[-1] + + # Initialize the output tensor with padding values + output_tensor = torch.full((mbs, max_new_len), pad_value, dtype=input_tensor.dtype, + device=input_tensor.device) + + # Fill the output tensor with truncated values + for i in range(mbs): + start_idx = truncate_lengths[i, 0].item() # Start index to keep + end_idx = truncate_lengths[i, 1].item() # End index to keep (exclusive) + new_len = new_lengths[i].item() # New length after truncation + + # Copy the middle part of the row to the output tensor + output_tensor[i, :new_len] = input_tensor[i, start_idx:end_idx] + + return output_tensor + + def truncate_rows(tensor, index_tensor, left_pad=False): """ tensor: 二维 Tensor,形状为 (mbs, seq_len) diff --git a/mindspeed_rl/utils/remove_padding.py b/mindspeed_rl/utils/remove_padding.py index d730474f..fc3fac4f 100644 --- a/mindspeed_rl/utils/remove_padding.py +++ b/mindspeed_rl/utils/remove_padding.py @@ -9,6 +9,7 @@ import torch_npu def preprocess_packed_seqs( input_ids: torch.Tensor, + labels: torch.Tensor, attention_mask_1d: torch.Tensor, tp_size: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -69,10 +70,18 @@ def preprocess_packed_seqs( position_ids_packed[start:end] = torch.arange( end - start, dtype=torch.int32, device=input_ids.device ) + + labels_packed = torch.zeros(pack_length, dtype=input_ids.dtype, device=input_ids.device) + # Copy valid tokens sequentially + for i in range(batch_size): + start = cu_seqlens_padded[i].item() + length = seqlens_in_batch[i].item() + labels_packed[start:start + length] = labels[i, :length] return ( input_ids_packed.unsqueeze(0), position_ids_packed.unsqueeze(0), + labels_packed.unsqueeze(0), seqlens_in_batch, cu_seqlens_padded ) @@ -82,17 +91,20 @@ def postprocess_packed_seqs( output: torch.Tensor, seqlens_in_batch: torch.Tensor, cu_seqlens_padded: torch.Tensor, - seq_len: int + seq_len: int, + prompt_length: torch.Tensor = None ) -> torch.Tensor: """ Unpacks a packed output tensor back into the original batch shape, restoring padding. + Optionally truncates the beginning of each sequence based on prompt_length. Parameters: output (torch.Tensor): Packed tensor of shape (1, pack_length, ...), typically the model output. seqlens_in_batch (torch.Tensor): 1D int32 tensor of original sequence lengths, shape (batch_size,). cu_seqlens_padded (torch.Tensor): 1D int32 tensor of cumulative padded lengths, shape (batch_size+1,). - batch_size (int): Original batch size. seq_len (int): Maximum sequence length (including padding) for the output reconstruction. + prompt_length (torch.Tensor, optional): 1D tensor specifying the length to truncate from the beginning of each sequence. + If None, no truncation is applied. Default is None. Returns: output_new (torch.Tensor): Tensor of shape (batch_size, seq_len, ...), with original outputs @@ -101,7 +113,6 @@ def postprocess_packed_seqs( Raises: ValueError: If output tensor does not have expected batch dimension of 1. """ - if output.shape[0] != 1: raise ValueError("Expected output tensor to have shape[0] == 1 (packed batch dimension)") @@ -109,9 +120,17 @@ def postprocess_packed_seqs( batch_size = seqlens_in_batch.shape[0] full_shape = [batch_size, seq_len] + list(output.shape[2:]) output_new = torch.zeros(full_shape, dtype=output.dtype, device=output.device) + for i in range(batch_size): start = cu_seqlens_padded[i].item() length = seqlens_in_batch[i].item() + + if prompt_length is not None: + trunc_length = prompt_length[i].item() - 1 + if trunc_length < length: + length -= trunc_length + 1 + start += trunc_length + output_new[i, :length] = output[0, start:start + length] - return output_new + return output_new -- Gitee