16 Star 433 Fork 291

Ascend/MindSpeed-RL
暂停

加入 Gitee
与超过 1400万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
remove_padding.py 5.83 KB
一键复制 编辑 原始数据 按行查看 历史
zhoubeirong 提交于 2025-07-30 11:49 +08:00 . !488【optimize CP】
# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
from typing import Tuple
import numpy as np
import torch
import torch_npu
def preprocess_packed_seqs(
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    attention_mask_1d: torch.Tensor,
    tp_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Packs variable-length sequences from a batch into a single contiguous tensor for efficient processing.

    For every sample, the first ``seqlens_in_batch[i]`` entries of its row are copied into the
    packed tensor, so valid tokens are expected to be left-aligned (padding on the right).
    Each sample occupies a block whose length is rounded up to a multiple of ``tp_size``;
    the extra positions in a block are left as zeros.

    Parameters:
        input_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing token IDs.
        labels (torch.Tensor): Per-token labels, indexed the same way as input_ids
            (row i, first seqlens_in_batch[i] columns are copied).
        attention_mask_1d (torch.Tensor): Binary mask tensor of shape (batch_size, seq_len) where
            each entry indicates valid token positions (1) vs padding (0). dtype should be torch.int or torch.bool.
        tp_size (int): Alignment factor for packing; sequences are padded so that their lengths
            are multiples of this size. Must be >= 1.
    Returns:
        input_ids_packed (torch.Tensor): Tensor of shape (1, pack_length) with all valid tokens packed sequentially.
        position_ids_packed (torch.Tensor): int32 tensor of shape (1, pack_length) containing positional
            indices within each padded sequence block.
        labels_packed (torch.Tensor): Tensor of shape (1, pack_length) with the labels packed at the same
            offsets as the tokens. Its dtype is input_ids.dtype.
        seqlens_in_batch (torch.Tensor): 1D int32 tensor of shape (batch_size,) with original
            sequence lengths (number of valid tokens per sample).
        cu_seqlens_padded (torch.Tensor): 1D int32 tensor of shape (batch_size+1,) containing
            cumulative padded sequence lengths, used for indexing into the packed tensor.
    Raises:
        ValueError: If input_ids and attention_mask_1d have incompatible shapes, or tp_size < 1.
    """
    batch_size, seq_len = input_ids.shape
    if attention_mask_1d.shape != (batch_size, seq_len):
        raise ValueError("attention_mask_1d must have shape (batch_size, seq_len) matching input_ids")
    # A zero or negative alignment would make the modulo below fail or produce meaningless padding.
    if tp_size < 1:
        raise ValueError("tp_size must be a positive integer")

    device = input_ids.device

    # Compute actual sequence lengths per sample
    seqlens_in_batch = attention_mask_1d.sum(dim=1, dtype=torch.int32)

    # Compute padding needed to align lengths to tp_size
    pad_size = (tp_size - (seqlens_in_batch % tp_size)) % tp_size
    seqlens_in_batch_padded = seqlens_in_batch + pad_size

    # Cumulative padded lengths: block i spans [cu_seqlens_padded[i], cu_seqlens_padded[i + 1])
    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)

    # Move lengths/offsets to the host once instead of calling .item() per element per loop.
    lengths = seqlens_in_batch.tolist()
    offsets = cu_seqlens_padded.tolist()

    # Total packed length after padding is the last cumulative offset
    pack_length = offsets[-1]

    input_ids_packed = torch.zeros(pack_length, dtype=input_ids.dtype, device=device)
    labels_packed = torch.zeros(pack_length, dtype=input_ids.dtype, device=device)
    position_ids_packed = torch.zeros(pack_length, dtype=torch.int32, device=device)

    for i, (length, start, end) in enumerate(zip(lengths, offsets[:-1], offsets[1:])):
        # Copy valid tokens and labels to the start of this sample's padded block
        input_ids_packed[start:start + length] = input_ids[i, :length]
        labels_packed[start:start + length] = labels[i, :length]
        # Position IDs restart at 0 for every block and cover the alignment padding too
        position_ids_packed[start:end] = torch.arange(
            end - start, dtype=torch.int32, device=device
        )

    return (
        input_ids_packed.unsqueeze(0),
        position_ids_packed.unsqueeze(0),
        labels_packed.unsqueeze(0),
        seqlens_in_batch,
        cu_seqlens_padded
    )
def postprocess_packed_seqs(
    output: torch.Tensor,
    seqlens_in_batch: torch.Tensor,
    cu_seqlens_padded: torch.Tensor,
    seq_len: int,
    prompt_length: torch.Tensor = None
) -> torch.Tensor:
    """
    Scatters a packed output tensor back into a zero-padded (batch_size, seq_len, ...) tensor.

    Each sample's slice is read from the packed tensor starting at its padded offset and written
    to the beginning of its row in the result. When prompt_length is given and
    ``prompt_length[i] - 1`` is smaller than the sample's length, the read offset is advanced by
    ``prompt_length[i] - 1`` and the number of copied positions is reduced by ``prompt_length[i]``.

    Parameters:
        output (torch.Tensor): Packed tensor of shape (1, pack_length, ...), typically the model output.
        seqlens_in_batch (torch.Tensor): 1D int32 tensor of original sequence lengths, shape (batch_size,).
        cu_seqlens_padded (torch.Tensor): 1D int32 tensor of cumulative padded lengths, shape (batch_size+1,).
        seq_len (int): Size of the second dimension of the reconstructed tensor.
        prompt_length (torch.Tensor, optional): Per-sample prompt lengths; entry i must be readable
            via ``prompt_length[i].item()``. If None, no truncation is applied. Default is None.
    Returns:
        torch.Tensor: Tensor of shape (batch_size, seq_len, ...) with the same dtype and device as
        output; positions that were not written remain zero.
    Raises:
        ValueError: If output tensor does not have expected batch dimension of 1.
    """
    if output.shape[0] != 1:
        raise ValueError("Expected output tensor to have shape[0] == 1 (packed batch dimension)")

    num_seqs = seqlens_in_batch.shape[0]
    trailing_dims = list(output.shape[2:])
    restored = torch.zeros(
        [num_seqs, seq_len] + trailing_dims,
        dtype=output.dtype,
        device=output.device,
    )

    # Drop the leading packed-batch dimension once; slices below index the packed axis directly.
    packed = output[0]
    truncate = prompt_length is not None

    for row in range(num_seqs):
        offset = cu_seqlens_padded[row].item()
        valid = seqlens_in_batch[row].item()
        if truncate:
            skip = prompt_length[row].item() - 1
            if skip < valid:
                # Start reading skip positions later and copy prompt_length fewer positions.
                offset += skip
                valid -= skip + 1
        restored[row, :valid] = packed[offset:offset + valid]

    return restored
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ascend/MindSpeed-RL.git
git@gitee.com:ascend/MindSpeed-RL.git
ascend
MindSpeed-RL
MindSpeed-RL
master

搜索帮助