
Ascend/pytorch

Simple FSDP test fails with pytorch-npu 2.0.1-5.0.rc2

Status: DONE
Label: Bug-Report
Created on 2023-08-03 11:06

1. Problem description (with error log context):

A minimal FSDP test extracted from test_fsdp.py (this commit) fails with PyTorch 2.0.1 + NPU adaptor v2.0.1-5.0.rc2 + CANN 6.3.RC2.

2. Software versions:
-- CANN version: CANN 6.3.RC2
-- Python version: 3.9.17
-- PyTorch version: 2.0.1 + v2.0.1-5.0.rc2 NPU adaptor

3. Steps to reproduce:

Install the CANN, PyTorch, and NPU adaptor environment:

# Assume Python 3.9 environment

wget https://download.pytorch.org/whl/torch-2.0.1-cp39-cp39-manylinux2014_aarch64.whl
wget https://gitee.com/ascend/pytorch/releases/download/v5.0.rc2-pytorch2.0.1/torch_npu-2.0.1rc1-cp39-cp39-linux_aarch64.whl

pip install torch-2.0.1-cp39-cp39-manylinux2014_aarch64.whl
pip install torch_npu-2.0.1rc1-cp39-cp39-linux_aarch64.whl

wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Florence-ASL/Florence-ASL%20V100R001C30SPC813/Ascend-cann-nnae_6.3.RC2.alpha003_linux-aarch64.run
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Florence-ASL/Florence-ASL%20V100R001C30SPC703/Ascend-cann-toolkit_6.3.RC2.alpha003_linux-aarch64.run

sh ./Ascend-cann-nnae_6.3.RC2.alpha003_linux-aarch64.run --install
sh ./Ascend-cann-toolkit_6.3.RC2.alpha003_linux-aarch64.run --install

source /usr/local/Ascend/nnae/set_env.sh
source /usr/local/Ascend/ascend-toolkit/set_env.sh

export LD_PRELOAD=/lib64/libgomp.so.1  # avoid multi-threading bug in some cases
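
Before running the FSDP test, optionally confirm that the NPU stack loads with a quick sanity check (a minimal sketch assuming the environment above; on a working setup torch.npu.is_available() prints True):

# check_npu.py -- optional NPU environment sanity check
import torch
import torch_npu  # registers the "npu" device type and the torch.npu.* APIs

print(torch.__version__)                                   # expect 2.0.1
print(torch.npu.is_available(), torch.npu.device_count())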

Run the script below with torchrun --nproc_per_node 2 --nnodes 1 fsdp_npu_debug.py:

import os
import functools

import torch
import torch_npu
import torch.distributed as dist

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(2, 2)
        self.fc2 = torch.nn.Linear(2, 2)

    def forward(self, x):
        x = self.fc1(x)
        return self.fc2(x)

def main():
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = "npu"
    torch.npu.set_device(local_rank)  # bind this process to its own NPU
    dist.init_process_group(backend="hccl")  # HCCL is the collective backend for NPU
    
    # Each Linear(2, 2) has only 6 parameters, so with min_num_params=100
    # nothing is auto-wrapped and only the root becomes an FSDP unit.
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=100
    )
    
    model = FSDP(
        MyModule(),
        auto_wrap_policy=my_auto_wrap_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.device(f"npu:{local_rank}"),
    )
    
    optimizer = torch.optim.AdamW(model.parameters())

    for step in range(5):
        x = model(torch.ones(2, 2).to(device))
        loss = x.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f"(rank {rank}) step {step}, loss: {loss.item()}")
            
if __name__ == '__main__':
    main()

4. Log output:

WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
Traceback (most recent call last):
  File "/root/test_dir/work_code/fsdp_npu_debug.py", line 61, in <module>
    main()
  File "/root/test_dir/work_code/fsdp_npu_debug.py", line 43, in main
    model = FSDP(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 408, in __init__
    _init_param_handle_from_module(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 418, in _init_param_handle_from_module
    state.compute_device = _get_compute_device(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 882, in _get_compute_device
    compute_device = torch.device("cuda", torch.cuda.current_device())
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/cuda/__init__.py", line 674, in current_device
    _lazy_init()
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/cuda/__init__.py", line 239, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
Traceback (most recent call last):
  File "/root/test_dir/work_code/fsdp_npu_debug.py", line 61, in <module>
    main()
  File "/root/test_dir/work_code/fsdp_npu_debug.py", line 43, in main
    model = FSDP(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 408, in __init__
    _init_param_handle_from_module(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 418, in _init_param_handle_from_module
    state.compute_device = _get_compute_device(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 882, in _get_compute_device
    compute_device = torch.device("cuda", torch.cuda.current_device())
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/cuda/__init__.py", line 674, in current_device
    _lazy_init()
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/cuda/__init__.py", line 239, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 42742) of binary: /root/test_dir/miniconda/envs/py39_pt2_npu/bin/python
Traceback (most recent call last):
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
fsdp_npu_debug.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-08-03_10:52:28
  host      : devserver-com
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 42743)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-08-03_10:52:28
  host      : devserver-com
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 42742)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Comments (2)

佳威 created the Bug-Report 2 years ago

The error AssertionError: Torch not compiled with CUDA enabled is caused by the call to torch.cuda.current_device() in fsdp/_init_utils.py, which is not replaced by the NPU equivalent torch.npu.current_device() in the 2.0.1-5.0.rc2 adaptor. Is FSDP supported and tested in 2.0.1-5.0.rc2? Would it be useful if I sent a PR?
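
For readers hitting the same failure, a minimal local workaround sketch follows; this is an assumption-laden monkey-patch, not the adaptor's fix. It overrides the _get_compute_device helper named in the traceback so FSDP resolves the compute device on the NPU, applied in the script before constructing FSDP:

# Hedged workaround sketch, not an official fix: override the CUDA-specific
# compute-device lookup from the traceback above. The real upstream helper
# also honors device_id and inspects parameter devices; this simplified
# version ignores its arguments and always returns the current NPU.
import torch
import torch_npu
import torch.distributed.fsdp._init_utils as fsdp_init_utils

def _get_compute_device_npu(*args, **kwargs):
    # Uses the NPU selected earlier via torch.npu.set_device(local_rank).
    return torch.device("npu", torch.npu.current_device())

fsdp_init_utils._get_compute_device = _get_compute_device_npu

Note that this only clears the first failure: FSDP in stock PyTorch 2.0.1 also uses torch.cuda streams internally, so further CUDA-specific errors can be expected, consistent with the reply below.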

佳威 edited the description 2 years ago

Not supported.

Destiny changed the task status from TODO to DONE 2 years ago
