一、问题现象(附报错日志上下文):
单机8卡分布式训练报错,
二、软件版本:
torch=2.4.0
torch-npu=2.4.0.post2
CANN version=8.0.RC3
arch=aarch64
python=3.10.16
npu-smi=23.0.6
ubuntu=22.04
910B2 64G 8卡
三、测试步骤:
使用模型并行+数据并行,环境变量export HCCL_CONNECT_TIMEOUT=6000
报错代码如下:
def evaluate_ppo(self): # noqa: C901
# self.model.eval()
"""Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided"""
stats = {}
all_full_ids = []
all_rev_kl = []
all_lens = []
table = []
with torch.no_grad():
for batch in tqdm(self.eval_dataloader, "Generation Evaluation", disable=(not get_rank() == 0)):
batch, no_model_batch = batch
batch, _ = self.eval_pipeline.move_to_device(batch, no_model_batch, self.device)
# 2. 插入NPU流同步(新增)
stream = torch.npu.current_stream()
stream.synchronize()
# 获取学生模型的输出
gen_out = self.generate(
**batch,
return_dict_in_generate=True,
output_scores=True
)
full_ids = gen_out.sequences
gen_logits = gen_out.scores # NOTE: [b, s, h_p]
inf_mask = torch.isinf(gen_logits)
# 5. 同步NPU流后再收集数据(新增)
stream = torch.npu.current_stream()
stream.synchronize()
all_full_ids.append(full_ids)
input_ids = batch["input_ids"]
gen_ids = full_ids[:, input_ids.size(1):]
mask = self.get_mask(full_ids)
mask = mask[:, input_ids.size(1)-1:input_ids.size(1)+gen_ids.size(1)-1]
lens = torch.sum(mask, dim=-1)
teacher_rewards = self.reward_fn(input_ids, gen_ids)["rewards"] # \log p(y_t | y_{<t}, x)
_, logprobs = self.compute_logits_and_log_probs(input_ids, gen_ids, inf_mask=inf_mask, base="base") # \log q_{\theta}(y_t | y_{<t}, x)
kl = get_rev_kl(teacher_rewards, logprobs, mask) # 获取反向KL分数
kl = kl.sum(-1)
if self.args.length_norm:
kl = kl / lens
all_rev_kl.append(kl)
all_lens.append(lens)
# 7. 添加分布式同步点(新增)
torch.distributed.barrier()
all_full_ids = torch.cat(all_full_ids, dim=0)
all_rev_kl = torch.cat(all_rev_kl, dim=0)
all_lens = torch.cat(all_lens, dim=0)
full_ids = all_gather(all_full_ids, dim=1, world_size=self.dp_world_size, group=self.dp_group, op="stack")
full_ids = full_ids.view(-1, full_ids.size(-1))
prompt_ids = full_ids[:, :self.eval_pipeline.max_prompt_length]
all_rev_kl = all_gather(all_rev_kl, dim=0, world_size=self.dp_world_size, group=self.dp_group)
stats["rev_kl"] = all_rev_kl.mean()
all_lens = all_gather(all_lens, dim=0, world_size=self.dp_world_size, group=self.dp_group)
stats["lens"] = all_lens.float().mean()
# 7. 添加分布式同步点(新增)
torch.distributed.barrier()
response_texts = []
if get_rank() == 0:
# 解码前同步流(新增)
stream = torch.npu.current_stream()
stream.synchronize()
prompt_texts = self.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
response_texts = self.tokenizer.batch_decode(full_ids[:, self.eval_pipeline.max_prompt_length:], skip_special_tokens=True)
gen_texts = [p + g for p, g in zip(prompt_texts, response_texts)]
columns = ["prompts"]
columns_data = [prompt_texts]
# in online setting, compute the reward for validation
columns.append("samples")
if isinstance(gen_texts[0], str):
columns_data.append(gen_texts)
else:
columns_data.append(gen_texts.tolist())
table.append(list(zip(*columns_data)))
# 9. 添加全局同步点(关键修复点)
torch.distributed.barrier()
# Log and display evaluation metrics
if get_rank() == 0:
rows = sum(list(map(list, zip(*table))), [])
# Add metrics/rewards to the table's title
table_title = f"Evaluation #{self.nth_evaluation}"
for k, x in stats.items():
if k.startswith("reward") or k.startswith("metrics"):
table_title += f" {k}: {significant(x)}"
rich_table = Table(*columns, title=table_title, show_lines=True)
for ix in range(min(3, len(rows))):
rich_table.add_row(*[str(significant(x)) for x in rows[ix]])
try:
Console().print(rich_table)
except:
pass
# 10. 返回前同步所有进程(新增)
torch.distributed.barrier()
self.nth_evaluation += 1
return stats, table, response_texts
四、报错信息:
在以下这几行都报过一模一样的错误:all_lens.append(lens)
、torch.distributed.barrier()
、full_ids = all_gather(all_full_ids, dim=1, world_size=self.dp_world_size, group=self.dp_group, op="stack")
、prompt_texts = self.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
【03-11补充】:当eval dataloader样本数量少(如64)的时候不会报错,样本数量增多至200就报错
Generation Evaluation: 0%| | 0/106 [00:00<?, ?it/s]len(self.train_dataloader), 4
len(self.train_dataloader), 4
len(self.train_dataloader), 4
len(self.train_dataloader), 4
Generation Evaluation: 37%|███████████████████████████████████████████████████████▏ | 39/106 [39:39<1:11:12, 63.76s/it]
Generation Evaluation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 106/106 [1:26:00<00:00, 48.68s/it]
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
File "/home/jovyan/gqy/minillm/./train_minillm.py", line 113, in
File "/home/jovyan/gqy/minillm/./train_minillm.py", line 113, in
File "/home/jovyan/gqy/minillm/./train_minillm.py", line 113, in
Traceback (most recent call last):
File "/home/jovyan/gqy/minillm/./train_minillm.py", line 113, in
main()main()
File "/home/jovyan/gqy/minillm/./train_minillm.py", line 99, in main
File "/home/jovyan/gqy/minillm/./train_minillm.py", line 99, in main
main()
File "/home/jovyan/gqy/minillm/./train_minillm.py", line 99, in main
main()
File "/home/jovyan/gqy/minillm/./train_minillm.py", line 99, in main
train(train(
File "/home/jovyan/gqy/minillm/minillm/init.py", line 52, in train
File "/home/jovyan/gqy/minillm/minillm/init.py", line 52, in train
train(
File "/home/jovyan/gqy/minillm/minillm/init.py", line 52, in train
train(
File "/home/jovyan/gqy/minillm/minillm/init.py", line 52, in train
trainer.train()trainer.train()
File "/home/jovyan/gqy/minillm/minillm/trainer.py", line 244, in train
File "/home/jovyan/gqy/minillm/minillm/trainer.py", line 244, in train
trainer.train()
File "/home/jovyan/gqy/minillm/minillm/trainer.py", line 244, in train
trainer.train()
File "/home/jovyan/gqy/minillm/minillm/trainer.py", line 244, in train
self.global_iter_count = 1self.global_iter_count = 1
File "/home/jovyan/gqy/minillm/minillm/trainer.py", line 432, in evaluate
File "/home/jovyan/gqy/minillm/minillm/trainer.py", line 432, in evaluate
self.global_iter_count = 1
File "/home/jovyan/gqy/minillm/minillm/trainer.py", line 432, in evaluate
self.global_iter_count = 1
File "/home/jovyan/gqy/minillm/minillm/trainer.py", line 432, in evaluate
File "/home/jovyan/gqy/minillm/minillm/trainer.py", line 505, in evaluate_ppo
File "/home/jovyan/gqy/minillm/minillm/trainer.py", line 505, in evaluate_ppo
File "/home/jovyan/gqy/minillm/minillm/trainer.py", line 505, in evaluate_ppo
File "/home/jovyan/gqy/minillm/minillm/trainer.py", line 505, in evaluate_ppo
all_lens.append(lens)all_lens.append(lens)
File "/home/jovyan/.conda/envs/model_dis/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
File "/home/jovyan/.conda/envs/model_dis/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
all_lens.append(lens)all_lens.append(lens)
File "/home/jovyan/.conda/envs/model_dis/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
File "/home/jovyan/.conda/envs/model_dis/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
return func(*args, **kwargs)return func(*args, **kwargs)return func(*args, **kwargs)
File "/home/jovyan/.conda/envs/model_dis/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3703, in barrier
File "/home/jovyan/.conda/envs/model_dis/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3703, in barrier
File "/home/jovyan/.conda/envs/model_dis/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3703, in barrier
return func(*args, **kwargs)
File "/home/jovyan/.conda/envs/model_dis/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3703, in barrier
work.wait()work.wait()
work.wait()RuntimeError
RuntimeErrorRuntimeError: : npuSynchronizeDevice:build/CMakeFiles/torch_npu.dir/compiler_depend.ts:467 NPU function error: AclrtSynchronizeDeviceWithTimeout, error code is 507048
[ERROR] 2025-03-10-04:19:37 (PID:1185975, Device:2, RankID:2) ERR00100 PTA call acl api failed
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EI0002: [PID: 1185975] 2025-03-10-04:19:37.576.321 The wait execution of the Notify register times out. Reason: The Notify register has not received the Notify record from remote rank [unknown].base information: [streamID:[4265454028], taskID[13], tag[AllReduce_100.124.104.201%eth0_60000_0_1741569404003403], AlgType(level 0-1-2):[fullmesh-ring-ring].] task information: []
Possible Cause: 1. An exception occurs during the execution on some NPUs in the cluster. As a result, collective communication operation failed.2. The execution speed on some NPU in the cluster is too slow to complete a communication operation within the timeout interval. (default 1800s, You can set the interval by using HCCL_EXEC_TIMEOUT.)3. The number of training samples of each NPU is inconsistent.4. Packet loss or other connectivity problems occur on the communication link.
Solution: 1. If this error is reported on part of these ranks, check other ranks to see whether other errors have been reported earlier.2. If this error is reported for all ranks, check whether the error reporting time is consistent (the maximum difference must not exceed 1800s). If not, locate the cause or adjust the locate the cause or set the HCCL_EXEC_TIMEOUT environment variable to a larger value.3. Check whether the completion queue element (CQE) of the error exists in the plog(grep -rn 'error cqe'). If so, check the network connection status. (For details, see the TLS command and HCCN connectivity check examples.)4. Ensure that the number of training samples of each NPU is consistent. For details:https://www.hiascend.com/document
TraceBack (most recent call last):
The error from device(chipId:2, dieId:0), serial number is 3, hccl fftsplus task timeout occurred during task execution, stream_id:4, sq_id:4, task_id:13, stuck notify num:1, timeout:1836.[FUNC:ProcessStarsHcclFftsPlusTimeoutErrorInfo][FILE:device_error_proc.cc][LINE:1645]
The 0 stuck notify wait context info:(context_id=4, notify_id=3).[FUNC:ProcessStarsHcclFftsPlusTimeoutErrorInfo][FILE:device_error_proc.cc][LINE:1652]
The wait execution of the Notify register times out. Reason: The Notify register has not received the Notify record from remote rank [unknown].base information: [streamID:[4265454028], taskID[13], tag[AllReduce_100.124.104.201%eth0_60000_0_1741569404003403], AlgType(level 0-1-2):[fullmesh-ring-ring].] task information: []
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
npuSynchronizeDevice:build/CMakeFiles/torch_npu.dir/compiler_depend.ts:467 NPU function error: AclrtSynchronizeDeviceWithTimeout, error code is 507048
[ERROR] 2025-03-10-04:19:37 (PID:1185974, Device:1, RankID:1) ERR00100 PTA call acl api failed
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EI0002: [PID: 1185974] 2025-03-10-04:19:37.566.086 The wait execution of the Notify register times out. Reason: The Notify register has not received the Notify record from remote rank [unknown].base information: [streamID:[3917014364], taskID[13], tag[AllReduce_100.124.104.201%eth0_60000_0_1741569404003403], AlgType(level 0-1-2):[fullmesh-ring-ring].] task information: []
Possible Cause: 1. An exception occurs during the execution on some NPUs in the cluster. As a result, collective communication operation failed.2. The execution speed on some NPU in the cluster is too slow to complete a communication operation within the timeout interval. (default 1800s, You can set the interval by using HCCL_EXEC_TIMEOUT.)3. The number of training samples of each NPU is inconsistent.4. Packet loss or other connectivity problems occur on the communication link.
Solution: 1. If this error is reported on part of these ranks, check other ranks to see whether other errors have been reported earlier.2. If this error is reported for all ranks, check whether the error reporting time is consistent (the maximum difference must not exceed 1800s). If not, locate the cause or adjust the locate the cause or set the HCCL_EXEC_TIMEOUT environment variable to a larger value.3. Check whether the completion queue element (CQE) of the error exists in the plog(grep -rn 'error cqe'). If so, check the network connection status. (For details, see the TLS command and HCCN connectivity check examples.)4. Ensure that the number of training samples of each NPU is consistent. For details:https://www.hiascend.com/document
TraceBack (most recent call last):
The error from device(chipId:1, dieId:0), serial number is 3, hccl fftsplus task timeout occurred during task execution, stream_id:4, sq_id:4, task_id:13, stuck notify num:1, timeout:1836.[FUNC:ProcessStarsHcclFftsPlusTimeoutErrorInfo][FILE:device_error_proc.cc][LINE:1645]
The 0 stuck notify wait context info:(context_id=4, notify_id=11).[FUNC:ProcessStarsHcclFftsPlusTimeoutErrorInfo][FILE:device_error_proc.cc][LINE:1652]
The wait execution of the Notify register times out. Reason: The Notify register has not received the Notify record from remote rank [unknown].base information: [streamID:[3917014364], taskID[13], tag[AllReduce_100.124.104.201%eth0_60000_0_1741569404003403], AlgType(level 0-1-2):[fullmesh-ring-ring].] task information: []
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
:
work.wait()npuSynchronizeDevice:build/CMakeFiles/torch_npu.dir/compiler_depend.ts:467 NPU function error: AclrtSynchronizeDeviceWithTimeout, error code is 507048
[ERROR] 2025-03-10-04:19:37 (PID:1185976, Device:3, RankID:3) ERR00100 PTA call acl api failed
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EI0002: [PID: 1185976] 2025-03-10-04:19:37.588.583 The wait execution of the Notify register times out. Reason: The Notify register has not received the Notify record from remote rank [unknown].base information: [streamID:[3517536076], taskID[13], tag[AllReduce_100.124.104.201%eth0_60000_0_1741569404003403], AlgType(level 0-1-2):[fullmesh-ring-ring].] task information: []
Possible Cause: 1. An exception occurs during the execution on some NPUs in the cluster. As a result, collective communication operation failed.2. The execution speed on some NPU in the cluster is too slow to complete a communication operation within the timeout interval. (default 1800s, You can set the interval by using HCCL_EXEC_TIMEOUT.)3. The number of training samples of each NPU is inconsistent.4. Packet loss or other connectivity problems occur on the communication link.
Solution: 1. If this error is reported on part of these ranks, check other ranks to see whether other errors have been reported earlier.2. If this error is reported for all ranks, check whether the error reporting time is consistent (the maximum difference must not exceed 1800s). If not, locate the cause or adjust the locate the cause or set the HCCL_EXEC_TIMEOUT environment variable to a larger value.3. Check whether the completion queue element (CQE) of the error exists in the plog(grep -rn 'error cqe'). If so, check the network connection status. (For details, see the TLS command and HCCN connectivity check examples.)4. Ensure that the number of training samples of each NPU is consistent. For details:https://www.hiascend.com/document
TraceBack (most recent call last):
The error from device(chipId:3, dieId:0), serial number is 3, hccl fftsplus task timeout occurred during task execution, stream_id:4, sq_id:4, task_id:13, stuck notify num:1, timeout:1836.[FUNC:ProcessStarsHcclFftsPlusTimeoutErrorInfo][FILE:device_error_proc.cc][LINE:1645]
The 0 stuck notify wait context info:(context_id=4, notify_id=17).[FUNC:ProcessStarsHcclFftsPlusTimeoutErrorInfo][FILE:device_error_proc.cc][LINE:1652]
The wait execution of the Notify register times out. Reason: The Notify register has not received the Notify record from remote rank [unknown].base information: [streamID:[3517536076], taskID[13], tag[AllReduce_100.124.104.201%eth0_60000_0_1741569404003403], AlgType(level 0-1-2):[fullmesh-ring-ring].] task information: []
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
RuntimeError: npuSynchronizeDevice:build/CMakeFiles/torch_npu.dir/compiler_depend.ts:467 NPU function error: AclrtSynchronizeDeviceWithTimeout, error code is 507048
[ERROR] 2025-03-10-04:19:37 (PID:1185973, Device:0, RankID:0) ERR00100 PTA call acl api failed
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EI0002: [PID: 1185973] 2025-03-10-04:19:37.591.361 The wait execution of the Notify register times out. Reason: The Notify register has not received the Notify record from remote rank [unknown].base information: [streamID:[3366440588], taskID[13], tag[AllReduce_100.124.104.201%eth0_60000_0_1741569404003403], AlgType(level 0-1-2):[fullmesh-ring-ring].] task information: []
Possible Cause: 1. An exception occurs during the execution on some NPUs in the cluster. As a result, collective communication operation failed.2. The execution speed on some NPU in the cluster is too slow to complete a communication operation within the timeout interval. (default 1800s, You can set the interval by using HCCL_EXEC_TIMEOUT.)3. The number of training samples of each NPU is inconsistent.4. Packet loss or other connectivity problems occur on the communication link.
Solution: 1. If this error is reported on part of these ranks, check other ranks to see whether other errors have been reported earlier.2. If this error is reported for all ranks, check whether the error reporting time is consistent (the maximum difference must not exceed 1800s). If not, locate the cause or adjust the locate the cause or set the HCCL_EXEC_TIMEOUT environment variable to a larger value.3. Check whether the completion queue element (CQE) of the error exists in the plog(grep -rn 'error cqe'). If so, check the network connection status. (For details, see the TLS command and HCCN connectivity check examples.)4. Ensure that the number of training samples of each NPU is consistent. For details:https://www.hiascend.com/document
TraceBack (most recent call last):
The error from device(chipId:0, dieId:0), serial number is 3, hccl fftsplus task timeout occurred during task execution, stream_id:4, sq_id:4, task_id:13, stuck notify num:4, timeout:1836.[FUNC:ProcessStarsHcclFftsPlusTimeoutErrorInfo][FILE:device_error_proc.cc][LINE:1645]
The 0 stuck notify wait context info:(context_id=9, notify_id=14).[FUNC:ProcessStarsHcclFftsPlusTimeoutErrorInfo][FILE:device_error_proc.cc][LINE:1652]
The 1 stuck notify wait context info:(context_id=11, notify_id=24).[FUNC:ProcessStarsHcclFftsPlusTimeoutErrorInfo][FILE:device_error_proc.cc][LINE:1652]
The 2 stuck notify wait context info:(context_id=13, notify_id=20).[FUNC:ProcessStarsHcclFftsPlusTimeoutErrorInfo][FILE:device_error_proc.cc][LINE:1652]
The 3 stuck notify wait context info:(context_id=15, notify_id=18).[FUNC:ProcessStarsHcclFftsPlusTimeoutErrorInfo][FILE:device_error_proc.cc][LINE:1652]
The wait execution of the Notify register times out. Reason: The Notify register has not received the Notify record from remote rank [unknown].base information: [streamID:[3366440588], taskID[13], tag[AllReduce_100.124.104.201%eth0_60000_0_1741569404003403], AlgType(level 0-1-2):[fullmesh-ring-ring].] task information: []
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
[W compiler_depend.ts:487] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:19:38.591.420 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeUsedDevices)
[W compiler_depend.ts:122] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:19:58.399.507 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function empty_cache)
[W compiler_depend.ts:469] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:19:58.943.347 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:122] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:19:59.775.426 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function empty_cache)
[W compiler_depend.ts:469] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:20:00.287.260 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:122] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:20:00.803.311 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function empty_cache)
[W compiler_depend.ts:469] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:20:01.311.298 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:122] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:20:01.823.326 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function empty_cache)
[W compiler_depend.ts:469] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:20:02.335.348 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:122] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:20:02.851.347 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function empty_cache)
[W compiler_depend.ts:469] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:20:03.359.328 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:122] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:20:03.875.310 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function empty_cache)
[W compiler_depend.ts:469] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:20:04.383.291 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:122] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:20:04.895.384 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function empty_cache)
[W compiler_depend.ts:469] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:20:05.407.455 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[W compiler_depend.ts:122] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:20:05.923.406 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function empty_cache)
[W compiler_depend.ts:469] Warning: NPU warning, error code is 507048[Error]:
[Error]: The execution of the internal task times out.
Rectify the fault based on the error information in the ascend log.
EH9999: Inner Error!
rtDeviceSynchronizeWithTimeout execute failed, reason=[fftsplus timeout][FUNC:FuncErrorReason][FILE:error_message_manage.cc][LINE:53]
EH9999: [PID: 1185973] 2025-03-10-04:20:06.467.417 wait for compute device to finish failed, runtime result = 507048.[FUNC:ReportCallError][FILE:log_inner.cpp][LINE:161]
TraceBack (most recent call last):
(function npuSynchronizeDevice)
[2025-03-10 04:20:20,376] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1185977 closing signal SIGTERM
[2025-03-10 04:20:20,376] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1185978 closing signal SIGTERM
[2025-03-10 04:20:20,376] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1185979 closing signal SIGTERM
[2025-03-10 04:20:20,376] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1185980 closing signal SIGTERM
[2025-03-10 04:20:27,429] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 1185973) of binary: /home/jovyan/.conda/envs/model_dis/bin/python
Traceback (most recent call last):
File "/home/jovyan/.conda/envs/model_dis/bin/torchrun", line 8, in
sys.exit(main())
File "/home/jovyan/.conda/envs/model_dis/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/init.py", line 346, in wrapper
return f(*args, **kwargs)
File "/home/jovyan/.conda/envs/model_dis/lib/python3.10/site-packages/torch/distributed/run.py", line 806, in main
run(args)
File "/home/jovyan/.conda/envs/model_dis/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
elastic_launch(
File "/home/jovyan/.conda/envs/model_dis/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in call
return launch_agent(self._config, self._entrypoint, list(args))
File "/home/jovyan/.conda/envs/model_dis/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
./train_minillm.py FAILED
Failures:
[1]:
time : 2025-03-10_04:20:20
host : nb-546241047869523123-0
rank : 1 (local_rank: 1)
exitcode : 1 (pid: 1185974)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
time : 2025-03-10_04:20:20
host : nb-546241047869523123-0
rank : 2 (local_rank: 2)
exitcode : 1 (pid: 1185975)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
time : 2025-03-10_04:20:20
host : nb-546241047869523123-0
rank : 3 (local_rank: 3)
exitcode : 1 (pid: 1185976)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
Root Cause (first observed failure):
[0]:
time : 2025-03-10_04:20:20
host : nb-546241047869523123-0
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 1185973)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
五、日志信息:
日志太大,点击链接查看
日志文件链接
当前日志只看到了4个进程的error级别日志,请按照置顶issue提供相关日志信息,可以的话上传到gitee上方便查看
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。
已更新event日志信息,请查看:
event日志链接
请配置如下环境变量,并将之前的日志备份下,防止日志混在一起,然后运行脚本,将完整日志提供下(默认路径/root/ascend/log)
export ASCEND_GLOBAL_LOG_LEVEL=1
export ASCEND_GLOBAL_EVENT_ENABLE=1
配置环境变量后重新运行脚本,把/root/ascend/log下所有相关日志都打包了。
见链接:https://pan.quark.cn/s/45e21d44526a
已联系对应工程师分析中
当前看到4张卡超时,需要采集算子下发日志进一步定位分析
export ASCEND_GLOBAL_LOG_LEVEL=1
export HCCL_DIAGNOSE_ENABLE=1
export HCCL_ENTRY_LOG_ENABLE=1
请开启以上环境变量,运行用例,把/root/ascend/log下日志提供下(运行前请把之前的日志备份清理下,防止新旧日志混在一起)
或者参考资料排查其余卡是否忘记调用allreduce,卡死,提前退出等情况
https://www.hiascend.com/document/detail/zh/canncommercial/800/developmentguide/hccl/hcclug/hcclug_000031.html
代码中采用了模型并行+数据并行,其中(0,1,2,3)和(4,5,6,7)是两个模型并行组,(0,4),(1,5),(2,6),(3,7)是四个数据并行组,所以四卡超时可能是数据并行组出错了。
另:开启上述环境变量重新运行脚本,得到日志如链接所示:ascend_log_0320
最新日志看到0,1,2,3卡allreduce超时报错,但是用的是8卡通信域,其他卡没有调用这个allreduce算子。
请排查模型脚本,看下为啥使用8张卡的通信域,却只有4张卡调用了对应通信算子。
evaluate_ppo函数中的奖励函数teacher_rewards = self.reward_fn(input_ids, gen_ids)["rewards"]
有3处调用了all reduce算子,reward_fn具体如下,能帮忙看下为啥吗?
def reward_fn(self, input_ids: torch.Tensor, gen_ids: torch.Tensor,
inf_mask: Optional[torch.Tensor] = None, output_pos: bool = True) -> dict[str, torch.Tensor]:
# not include eos token
self.model.eval()
# input_ids = input_ids.repeat(1, 1)
model_inputs = self.get_input_batch(input_ids, gen_ids, output_pos=output_pos)
with torch.no_grad():
outputs = self.model(**model_inputs)
logits = outputs.logits # (B, L, V)
if self.args.model_parallel:
# ==================第1处调用===================
logits = logits - mpu.parallel_mean(logits.float(), dim=-1).unsqueeze(-1)
else:
logits = logits - torch.mean(logits, dim=-1, keepdim=True)
mask = model_inputs["attention_mask"]
logits = logits * mask.unsqueeze(-1) # set logits output by padding to 0
logits = logits[:, input_ids.size(-1)-1:, :]
mask = mask[:, input_ids.size(-1)-1:]
if self.args.model_parallel:
# ==================第2处调用===================
selection_value = mpu.parallel_gather(logits[:, :-1, :], -1, model_inputs["input_ids"][:, input_ids.size(-1):, None]).squeeze(-1) # 先单设备gather,后all_reduce
else:
selection_value = torch.gather(logits[:, :-1, :], -1, model_inputs["input_ids"][:, input_ids.size(-1):, None]).squeeze(-1)
current_logits = logits[:, :-1, :]
if self.args.model_parallel:
# ==================第3处调用===================
next_state_value = mpu.parallel_logsumexp(current_logits.float(), dim=-1)
else:
next_state_value = torch.logsumexp(current_logits, dim=-1)
next_state_value = next_state_value * mask[:, :-1]
scores = selection_value - next_state_value
assert all((~torch.isinf(scores.view(-1))) & (~torch.isnan(scores.view(-1))))
assert scores.size() == gen_ids.size()
return {
"rewards": scores,
"inf_mask": inf_mask
}
使用self.args.model_parallel模型并行的3个并行计算算子如下,均使用了all_reduce算子:
# mpu.parallel_mean
def parallel_mean(x, dim=-1):
# NOTE: dim is the model parallel dim
dim_size = x.size(dim)
x = torch.sum(x, dim=dim)
dist.all_reduce(x,
op=dist.ReduceOp.SUM,
group=get_model_parallel_group())
full_dim = dim_size * get_model_parallel_world_size()
x = x / full_dim
return x
# mpu.parallel_gather
def parallel_gather(logits, dim, ids):
return _ParallelGather.apply(logits, dim, ids)
class _ParallelGather(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, dim, ids):
partition_size = logits.size(dim)
rank = get_model_parallel_rank()
world_size = get_model_parallel_world_size()
get_range = VocabUtility.vocab_range_from_per_partition_vocab_size
start_index, end_index = get_range(partition_size, rank, world_size)
ids_mask = (ids < start_index) | (ids >= end_index)
masked_ids = ids - start_index
masked_ids[ids_mask] = 0
gathered_logits = torch.gather(logits, dim, masked_ids)
gathered_logits[ids_mask] = 0
dist.all_reduce(gathered_logits,
op=dist.ReduceOp.SUM,
group=get_model_parallel_group())
ctx.save_for_backward(ids_mask, masked_ids, torch.tensor([partition_size, dim], device=logits.device))
return gathered_logits
@staticmethod
def backward(ctx, grad_output):
raise NotImplementedError # not tested
# Retreive tensors from the forward path.
ids_mask, masked_ids, ints = ctx.saved_tensors
partition_size, dim = ints
size = ids_mask.size()[:-1] + (partition_size,)
grad_input = torch.zeros(size, dtype=grad_output.dtype, device=grad_output.device)
grad_input.scatter_(dim, masked_ids, (1 - ids_mask.to(grad_output.dtype)))
return grad_input, None, None
# mpu.parallel_logsumexp
def parallel_logsumexp(logits, dim=-1):
# NOTE: dim is the model parallel dim
sum_exp_x = torch.sum(torch.exp(logits), dim=dim)
dist.all_reduce(sum_exp_x,
op=dist.ReduceOp.SUM,
group=get_model_parallel_group())
log_sum_exp_x = torch.log(sum_exp_x)
return log_sum_exp_x
模型问题建议咨询模型提供方或者代码打点查看下
登录 后才可以发表评论