代码拉取完成,页面将自动刷新
# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
from typing import Tuple
import numpy as np
import torch
import torch_npu
def preprocess_packed_seqs(
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    attention_mask_1d: torch.Tensor,
    tp_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Packs variable-length sequences from a batch into a single contiguous tensor for efficient processing.

    Each sample's valid tokens (the positions where ``attention_mask_1d`` is non-zero) are written
    back-to-back into one flat buffer. Every sample occupies a segment whose length is rounded up to
    a multiple of ``tp_size``; the unused tail of a segment is left as zeros.

    Parameters:
        input_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing token IDs.
        labels (torch.Tensor): Tensor of shape (batch_size, seq_len) containing the label for each
            position of ``input_ids``. It is packed with the same layout as ``input_ids``.
        attention_mask_1d (torch.Tensor): Binary mask tensor of shape (batch_size, seq_len) where
            each entry indicates valid token positions (1) vs padding (0). dtype should be
            torch.int or torch.bool. Valid positions may sit anywhere in the row (left or right
            padding); they are selected through the mask, not by assuming a prefix.
        tp_size (int): Alignment factor for packing; sequences are padded so that their lengths
            are multiples of this size. Must be >= 1.

    Returns:
        input_ids_packed (torch.Tensor): Tensor of shape (1, pack_length) with all valid tokens
            packed sequentially.
        position_ids_packed (torch.Tensor): int32 tensor of shape (1, pack_length) containing
            positional indices within each padded sequence block (restarting at 0 per block and
            also covering the alignment padding of that block).
        labels_packed (torch.Tensor): Tensor of shape (1, pack_length) with the labels of the
            valid positions, laid out exactly like ``input_ids_packed`` and keeping ``labels.dtype``.
        seqlens_in_batch (torch.Tensor): 1D int32 tensor of shape (batch_size,) with original
            sequence lengths (number of valid tokens per sample).
        cu_seqlens_padded (torch.Tensor): 1D int32 tensor of shape (batch_size+1,) containing
            cumulative padded sequence lengths, used for indexing into the packed tensor.

    Raises:
        ValueError: If attention_mask_1d or labels do not match the shape of input_ids, or if
            tp_size is not a positive integer.
    """
    batch_size, seq_len = input_ids.shape
    if attention_mask_1d.shape != (batch_size, seq_len):
        raise ValueError("attention_mask_1d must have shape (batch_size, seq_len) matching input_ids")
    if labels.shape != (batch_size, seq_len):
        raise ValueError("labels must have shape (batch_size, seq_len) matching input_ids")
    if tp_size < 1:
        raise ValueError("tp_size must be a positive integer")

    device = input_ids.device
    # Boolean view of the mask: used both for counting and for selecting valid tokens,
    # so the two can never disagree.
    valid_mask = attention_mask_1d.to(device=device, dtype=torch.bool)

    # Actual sequence lengths per sample
    seqlens_in_batch = valid_mask.sum(dim=1, dtype=torch.int32)

    # Round every length up to the next multiple of tp_size
    pad_size = (tp_size - (seqlens_in_batch % tp_size)) % tp_size
    seqlens_in_batch_padded = seqlens_in_batch + pad_size

    # Cumulative padded lengths: segment i lives in [cu[i], cu[i + 1])
    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)

    # Pull the small index tensors to the host once instead of calling .item()
    # inside the loop (each .item() forces a device synchronisation).
    lengths = seqlens_in_batch.tolist()
    boundaries = cu_seqlens_padded.tolist()
    pack_length = boundaries[-1]

    input_ids_packed = torch.zeros(pack_length, dtype=input_ids.dtype, device=device)
    labels_packed = torch.zeros(pack_length, dtype=labels.dtype, device=device)
    position_ids_packed = torch.zeros(pack_length, dtype=torch.int32, device=device)

    for i, length in enumerate(lengths):
        start, end = boundaries[i], boundaries[i + 1]
        row_mask = valid_mask[i]
        # Select tokens through the mask so left-padded rows are packed correctly
        input_ids_packed[start:start + length] = input_ids[i][row_mask]
        labels_packed[start:start + length] = labels[i][row_mask]
        # Position IDs restart at 0 for every padded segment
        position_ids_packed[start:end] = torch.arange(
            end - start, dtype=torch.int32, device=device
        )

    return (
        input_ids_packed.unsqueeze(0),
        position_ids_packed.unsqueeze(0),
        labels_packed.unsqueeze(0),
        seqlens_in_batch,
        cu_seqlens_padded
    )
def postprocess_packed_seqs(
    output: torch.Tensor,
    seqlens_in_batch: torch.Tensor,
    cu_seqlens_padded: torch.Tensor,
    seq_len: int,
    prompt_length: torch.Tensor = None
) -> torch.Tensor:
    """
    Scatter a packed tensor back into a zero-padded (batch_size, seq_len, ...) tensor.

    Row ``i`` of the result is filled from the front with the slice of ``output`` that belongs
    to sample ``i``; everything after that slice stays zero. When ``prompt_length`` is supplied,
    the slice is shifted and shortened per sample (see below).

    Parameters:
        output (torch.Tensor): Packed tensor of shape (1, pack_length, ...), typically the model output.
        seqlens_in_batch (torch.Tensor): Original (unpadded) length of every sample; its first
            dimension defines batch_size.
        cu_seqlens_padded (torch.Tensor): Cumulative padded lengths; entry ``i`` is the offset of
            sample ``i`` inside the packed dimension.
        seq_len (int): Size of the sequence dimension of the returned tensor.
        prompt_length (torch.Tensor, optional): Per-sample prompt length ``p``. When given and
            ``p - 1`` is smaller than the sample length ``L``, the copied slice starts at packed
            offset ``start + p - 1`` and holds ``L - p`` entries. When ``p - 1 >= L`` the sample
            is copied untruncated. If None, no truncation is applied. Default is None.
            NOTE(review): ``p == 0`` produces a start offset of ``start - 1``; confirm callers
            never pass a zero prompt length.

    Returns:
        torch.Tensor: Tensor of shape (batch_size, seq_len, ...) with the same dtype and device as
        ``output``; copied values at the front of each row, zeros elsewhere.

    Raises:
        ValueError: If output tensor does not have expected batch dimension of 1.
    """
    if output.shape[0] != 1:
        raise ValueError("Expected output tensor to have shape[0] == 1 (packed batch dimension)")

    packed = output[0]
    num_samples = seqlens_in_batch.shape[0]

    # Zero-initialised destination; untouched positions act as padding
    restored = output.new_zeros((num_samples, seq_len, *output.shape[2:]))

    for row in range(num_samples):
        offset = cu_seqlens_padded[row].item()
        n_tokens = seqlens_in_batch[row].item()

        if prompt_length is not None:
            skip = prompt_length[row].item() - 1
            if skip < n_tokens:
                # Drop the prompt part: begin one position before the prompt ends
                offset += skip
                n_tokens -= skip + 1

        restored[row, :n_tokens] = packed[offset:offset + n_tokens]

    return restored
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。