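1. Symptom (with error log context):
After migrating a DeepSpeed model, training runs on a single NPU, but multi-card training fails with error EJ0001 shortly after launch. Following the hint in the error log, I waited for a while and relaunched, but the error persisted. The full error output is in the attachment.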
2. Software versions:
-- CANN version (e.g., CANN 3.0.x, 5.x.x): CANN 7.0.RC1
-- TensorFlow/PyTorch/MindSpore version: PyTorch 2.1.0
-- Python version (e.g., Python 3.7.5): Python 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] on linux
-- MindStudio version (e.g., MindStudio 2.0.0 (beta3)): not installed
-- OS version (e.g., Ubuntu 18.04): TencentOS Server 3.1
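The version list above does not include torch_npu or DeepSpeed, which are also part of this setup. A minimal sketch for printing the Python-side versions using only the standard library (importlib.metadata). The distribution names it queries are assumptions about how the packages were installed, and it cannot report the CANN or driver versions, which are not Python packages.

```python
import platform
import sys
from importlib import metadata


def package_versions(names):
    """Return {distribution name: installed version, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("Python", sys.version)
    print("OS", platform.platform())
    # Assumed distribution names; adjust if the packages were installed differently.
    for name, version in package_versions(["torch", "torch_npu", "deepspeed", "omegaconf"]).items():
        print(name, version if version is not None else "not installed")
```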
3. Test steps:
Minimal reproduction script: train_voicebox_test.py
import argparse
import os

import deepspeed
import torch
import torch_npu  # registers the NPU device and the torch.npu namespace
from omegaconf import OmegaConf


class VoiceBoxDeepspeed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(128, 128)

    def forward(self, data):
        # With fp16 enabled DeepSpeed keeps the weights in half precision,
        # so cast the inputs to the parameter dtype to avoid a dtype mismatch.
        dtype = self.linear1.weight.dtype
        x = data['x'].npu().to(dtype)
        y = data['y'].npu().to(dtype)
        return torch.nn.functional.mse_loss(self.linear1(x), y)


def train_voicebox(rank, cfg):
    torch_npu.npu.set_device(f'npu:{rank}')
    deepspeed.init_distributed(
        dist_backend='hccl',
        rank=int(os.environ['LOCAL_RANK']),
        world_size=int(os.environ['WORLD_SIZE']),
    )
    voicebox_model = VoiceBoxDeepspeed()
    voicebox_model.train()
    model_engine, _, _, _ = deepspeed.initialize(
        # to_container turns nested DictConfig nodes into plain dicts,
        # which DeepSpeed's config parser expects.
        config=OmegaConf.to_container(cfg.Voicebox.train.deepspeed, resolve=True),
        model=voicebox_model,
    )
    step = 1
    while step <= 1024:
        # Random tensors stand in for a real DataLoader in this minimal repro.
        data = {
            'x': torch.randn(128),
            'y': torch.randn(128),
        }
        loss = model_engine(data)
        model_engine.backward(loss)
        model_engine.step()
        step += 1
        if rank == 0:
            print(f"step: {step}, total loss: {loss.item()}")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("cfg", type=str, help="path to config file")
    parser.add_argument("--local_rank", type=int, help="local rank")
    return parser.parse_args()


def main():
    args = get_args()
    cfg = OmegaConf.load(args.cfg)
    train_voicebox(args.local_rank, cfg)


if __name__ == "__main__":
    torch.npu.set_compile_mode(jit_compile=True)
    main()
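In the script, set_device() takes the rank from the --local_rank argument, while init_distributed() reads LOCAL_RANK from the environment. The DeepSpeed launcher normally supplies both, but they can diverge under other launchers. A sketch of a get_args() variant (resolve_local_rank is a hypothetical name) that uses --local_rank when it is given and otherwise falls back to the LOCAL_RANK environment variable; defaulting to 0 when neither is set is an assumption for single-process debugging.

```python
import argparse
import os


def resolve_local_rank(argv=None, environ=None):
    """Parse the script's arguments, taking the local rank from --local_rank
    when it is given and from the LOCAL_RANK environment variable otherwise."""
    environ = os.environ if environ is None else environ
    parser = argparse.ArgumentParser()
    parser.add_argument("cfg", type=str, help="path to config file")
    parser.add_argument("--local_rank", type=int, default=None, help="local rank")
    args = parser.parse_args(argv)
    if args.local_rank is None:
        # Assumption: fall back to rank 0 when neither source is set.
        args.local_rank = int(environ.get("LOCAL_RANK", 0))
    return args
```

main() would then call resolve_local_rank() instead of get_args() and pass args.local_rank to both set_device() and init_distributed(rank=...), so the device and the process-group rank always come from the same place.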
Training config file: config.yaml
Voicebox:
  train:
    num_workers: 8
    save_interval: 10000
    seed: 123456
    prefetch_factor: 4
    deepspeed:
      gradient_accumulation_steps: 1
      # train_batch_size: 32
      train_micro_batch_size_per_gpu: 4
      optimizer:
        type: Adam
      scheduler:
        type: WarmupDecayLR
        params:
          warmup_min_lr: 0
          warmup_max_lr: 1.0e-4
          warmup_num_steps: 5000
          total_num_steps: 1500000
          warmup_type: linear
      fp16:
        enabled: true
        auto_cast: false
        loss_scale: 0
      gradient_clipping: 0.2
      tensorboard:
        enabled: true
        output_path: /root/test
      bf16:
        enabled: false
      zero_optimization:
        stage: 3
        reduce_bucket_size: 500000000
        offload_param:
          device: cpu
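The config leaves train_batch_size commented out. DeepSpeed requires train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * data-parallel world size and derives the missing value, so with --num_gpus 2 it should resolve to 4 * 1 * 2 = 8. The sketch below computes that value and the learning rate implied by the scheduler block. warmup_decay_lr is my reading of DeepSpeed's WarmupDecayLR with warmup_type: linear (a linear ramp from warmup_min_lr to warmup_max_lr, then a linear decay back to warmup_min_lr at total_num_steps), not DeepSpeed's own code, so check deepspeed/runtime/lr_schedules.py before relying on exact values.

```python
def train_batch_size(micro_batch_per_gpu, grad_accum_steps, world_size):
    """Global batch size DeepSpeed derives when train_batch_size is omitted."""
    return micro_batch_per_gpu * grad_accum_steps * world_size


def warmup_decay_lr(step, warmup_min_lr, warmup_max_lr, warmup_num_steps, total_num_steps):
    """Learning rate at a scheduler step: linear warmup, then linear decay.

    Sketch of WarmupDecayLR with warmup_type=linear, not DeepSpeed's code.
    """
    if step < warmup_num_steps:
        gamma = step / warmup_num_steps
    else:
        gamma = max(0.0, (total_num_steps - step) / max(1.0, total_num_steps - warmup_num_steps))
    return warmup_min_lr + (warmup_max_lr - warmup_min_lr) * gamma


if __name__ == "__main__":
    # Values from config.yaml, launched with --num_gpus 2.
    print("train_batch_size:", train_batch_size(4, 1, 2))
    for step in (0, 2500, 5000, 752500, 1500000):
        print(step, warmup_decay_lr(step, 0.0, 1.0e-4, 5000, 1500000))
```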
Launch command:
export ASCEND_GLOBAL_LOG_LEVEL=0 HCCL_WHITELIST_DISABLE=1 HCCL_CONNECT_TIMEOUT=1200 ASCEND_LAUNCH_BLOCKING=1
deepspeed --master_port 23333 --num_gpus 2 train_voicebox_test.py config.yaml
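To confirm that each rank process actually sees the environment variables from the launch command, plus LOCAL_RANK and WORLD_SIZE that the script reads in init_distributed(), a small check could be called at the top of train_voicebox() and its output compared across ranks. The function name launch_env_report and the variable list are illustrative choices, not part of DeepSpeed or CANN.

```python
import os

# Variables set by the launch command, plus the two the script reads
# from the environment in init_distributed().
EXPECTED_VARS = (
    "ASCEND_GLOBAL_LOG_LEVEL",
    "HCCL_WHITELIST_DISABLE",
    "HCCL_CONNECT_TIMEOUT",
    "ASCEND_LAUNCH_BLOCKING",
    "LOCAL_RANK",
    "WORLD_SIZE",
)


def launch_env_report(environ=None):
    """Return (present, missing): values of EXPECTED_VARS found, and names not found."""
    environ = os.environ if environ is None else environ
    present = {name: environ[name] for name in EXPECTED_VARS if name in environ}
    missing = [name for name in EXPECTED_VARS if name not in environ]
    return present, missing


if __name__ == "__main__":
    present, missing = launch_env_report()
    print("present:", present)
    print("missing:", missing)
```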
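4. Log information:
The wiki link for log collection is broken, so I am not sure how to collect the logs.
Before starting training, I made sure no process was occupying the NPUs. The output of npu-smi info is below: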
npu-smi 23.0.0    Version: 23.0.0
NPU | Name   | Chip | Bus-Id       | Health | Power(W) | Temp(C) | Hugepages-Usage(page) | AICore(%) | Memory-Usage(MB) | HBM-Usage(MB)
----+--------+------+--------------+--------+----------+---------+-----------------------+-----------+------------------+--------------
0   | 910B2C | 0    | 0000:23:00.0 | OK     | 90.7     | 36      | 0 / 0                 | 0         | 0 / 0            | 4168 / 65536
1   | 910B2C | 0    | 0000:24:00.0 | OK     | 93.8     | 38      | 0 / 0                 | 0         | 0 / 0            | 4241 / 65536
2   | 910B2C | 0    | 0000:33:00.0 | OK     | 92.4     | 37      | 0 / 0                 | 0         | 0 / 0            | 4241 / 65536
3   | 910B2C | 0    | 0000:34:00.0 | OK     | 93.8     | 37      | 0 / 0                 | 0         | 0 / 0            | 4241 / 65536
4   | 910B2C | 0    | 0000:43:00.0 | OK     | 91.1     | 37      | 0 / 0                 | 0         | 0 / 0            | 4240 / 65536
5   | 910B2C | 0    | 0000:44:00.0 | OK     | 99.4     | 38      | 0 / 0                 | 0         | 0 / 0            | 4240 / 65536
6   | 910B2C | 0    | 0000:63:00.0 | OK     | 90.7     | 37      | 0 / 0                 | 0         | 0 / 0            | 4240 / 65536
7   | 910B2C | 0    | 0000:64:00.0 | OK     | 92.0     | 37      | 0 / 0                 | 0         | 0 / 0            | 4240 / 65536
8   | 910B2C | 0    | 0000:83:00.0 | OK     | 90.1     | 37      | 0 / 0                 | 0         | 0 / 0            | 4240 / 65536
9   | 910B2C | 0    | 0000:84:00.0 | OK     | 95.8     | 36      | 0 / 0                 | 0         | 0 / 0            | 4240 / 65536
10  | 910B2C | 0    | 0000:A3:00.0 | OK     | 91.4     | 36      | 0 / 0                 | 0         | 0 / 0            | 4240 / 65536
11  | 910B2C | 0    | 0000:A4:00.0 | OK     | 89.8     | 39      | 0 / 0                 | 0         | 0 / 0            | 4240 / 65536
12  | 910B2C | 0    | 0000:C3:00.0 | OK     | 88.7     | 37      | 0 / 0                 | 0         | 0 / 0            | 4240 / 65536
13  | 910B2C | 0    | 0000:C4:00.0 | OK     | 96.9     | 38      | 0 / 0                 | 0         | 0 / 0            | 4240 / 65536
14  | 910B2C | 0    | 0000:E3:00.0 | OK     | 92.5     | 37      | 0 / 0                 | 0         | 0 / 0            | 4240 / 65536
15  | 910B2C | 0    | 0000:E4:00.0 | OK     | 93.3     | 38      | 0 / 0                 | 0         | 0 / 0            | 4240 / 65536
+---------------------------+---------------+----------------------------------------------------+
| NPU Chip | Process id | Process name | Process memory(MB) |
+===========================+===============+====================================================+
| No running processes found in NPU 0 |
+===========================+===============+====================================================+
| No running processes found in NPU 1 |
+===========================+===============+====================================================+
| No running processes found in NPU 2 |
+===========================+===============+====================================================+
| No running processes found in NPU 3 |
+===========================+===============+====================================================+
| No running processes found in NPU 4 |
+===========================+===============+====================================================+
| No running processes found in NPU 5 |
+===========================+===============+====================================================+
| No running processes found in NPU 6 |
+===========================+===============+====================================================+
| No running processes found in NPU 7 |
+===========================+===============+====================================================+
| No running processes found in NPU 8 |
+===========================+===============+====================================================+
| No running processes found in NPU 9 |
+===========================+===============+====================================================+
| No running processes found in NPU 10 |
+===========================+===============+====================================================+
| No running processes found in NPU 11 |
+===========================+===============+====================================================+
| No running processes found in NPU 12 |
+===========================+===============+====================================================+
| No running processes found in NPU 13 |
+===========================+===============+====================================================+
| No running processes found in NPU 14 |
+===========================+===============+====================================================+
| No running processes found in NPU 15 |
+===========================+===============+====================================================+