Simple FSDP test fails with pytorch-npu 2.0.1-5.0.rc2
DONE
#I7Q760
Bug-Report
佳威
Created on 2023-08-03 11:06
1. Problem (with error-log context):

A minimal FSDP test extracted from [test_fsdp.py](https://gitee.com/ascend/pytorch/blob/master/test/test_fsdp/test_fsdp.py) (this [commit](https://gitee.com/ascend/pytorch/commit/5306b277649946304590c2d6916db57c6371d885)) fails with PyTorch 2.0.1 + NPU adaptor v2.0.1-5.0.rc2 + CANN 6.3.RC2.

2. Software versions:

- CANN version: CANN 6.3.RC2
- Python version: 3.9.17
- PyTorch version: 2.0.1 + v2.0.1-5.0.rc2 NPU adaptor

3. Steps to reproduce:

Install the CANN + PyTorch + NPU environment:

```bash
# Assume Python 3.9 environment
wget https://download.pytorch.org/whl/torch-2.0.1-cp39-cp39-manylinux2014_aarch64.whl
wget https://gitee.com/ascend/pytorch/releases/download/v5.0.rc2-pytorch2.0.1/torch_npu-2.0.1rc1-cp39-cp39-linux_aarch64.whl
pip install torch-2.0.1-cp39-cp39-manylinux2014_aarch64.whl
pip install torch_npu-2.0.1rc1-cp39-cp39-linux_aarch64.whl

wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Florence-ASL/Florence-ASL%20V100R001C30SPC813/Ascend-cann-nnae_6.3.RC2.alpha003_linux-aarch64.run
wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/Florence-ASL/Florence-ASL%20V100R001C30SPC703/Ascend-cann-toolkit_6.3.RC2.alpha003_linux-aarch64.run
sh ./Ascend-cann-nnae_6.3.RC2.alpha003_linux-aarch64.run --install
sh ./Ascend-cann-toolkit_6.3.RC2.alpha003_linux-aarch64.run --install

source /usr/local/Ascend/nnae/set_env.sh
source /usr/local/Ascend/ascend-toolkit/set_env.sh
export LD_PRELOAD=/lib64/libgomp.so.1  # avoid multi-threading bug in some cases
```

Run the script below with `torchrun --nproc_per_node 2 --nnodes 1 fsdp_npu_debug.py`:

```python
import os
import functools

import torch
import torch_npu
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(2, 2)
        self.fc2 = torch.nn.Linear(2, 2)

    def forward(self, x):
        x = self.fc1(x)
        return self.fc2(x)


def main():
    # torchrun sets these environment variables for every worker process.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    device = "npu"
    torch.npu.set_device(local_rank)
    dist.init_process_group(backend="hccl")

    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=100
    )

    # The failure is raised here, while FSDP is being constructed.
    model = FSDP(
        MyModule(),
        auto_wrap_policy=my_auto_wrap_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.device(f"npu:{local_rank}"),
    )

    optimizer = torch.optim.AdamW(model.parameters())

    for step in range(5):
        x = model(torch.ones(2, 2).to(device))
        loss = x.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f"(rank {rank}) step {step}, loss: {loss.item()}")


if __name__ == '__main__':
    main()
```

4. Log output:

```
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Traceback (most recent call last):
  File "/root/test_dir/work_code/fsdp_npu_debug.py", line 61, in <module>
    main()
  File "/root/test_dir/work_code/fsdp_npu_debug.py", line 43, in main
    model = FSDP(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 408, in __init__
    _init_param_handle_from_module(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 418, in _init_param_handle_from_module
    state.compute_device = _get_compute_device(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 882, in _get_compute_device
    compute_device = torch.device("cuda", torch.cuda.current_device())
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/cuda/__init__.py", line 674, in current_device
    _lazy_init()
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/cuda/__init__.py", line 239, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
Traceback (most recent call last):
  File "/root/test_dir/work_code/fsdp_npu_debug.py", line 61, in <module>
    main()
  File "/root/test_dir/work_code/fsdp_npu_debug.py", line 43, in main
    model = FSDP(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 408, in __init__
    _init_param_handle_from_module(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 418, in _init_param_handle_from_module
    state.compute_device = _get_compute_device(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/fsdp/_init_utils.py", line 882, in _get_compute_device
    compute_device = torch.device("cuda", torch.cuda.current_device())
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/cuda/__init__.py", line 674, in current_device
    _lazy_init()
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/cuda/__init__.py", line 239, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 42742) of binary: /root/test_dir/miniconda/envs/py39_pt2_npu/bin/python
Traceback (most recent call last):
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/root/test_dir/miniconda/envs/py39_pt2_npu/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
fsdp_npu_debug.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2023-08-03_10:52:28
  host      : devserver-com
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 42743)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-08-03_10:52:28
  host      : devserver-com
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 42742)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
```
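Both ranks fail at the same frame: `_get_compute_device` evaluates `torch.device("cuda", torch.cuda.current_device())` even though the script passes `device_id=torch.device(f"npu:{local_rank}")`. The sketch below is a torch-free model of that code path, written only to illustrate why the call raises on a build without CUDA. It is not PyTorch's or torch_npu's actual code. The `Device` tuple, both `pick_compute_device_*` functions, and `cuda_unavailable` are hypothetical names, and the condition guarding the CUDA query is an assumption inferred from the traceback. The second function shows one possible device-agnostic ordering, not the fix adopted upstream.

```python
from typing import Callable, Optional, Tuple

# Hypothetical stand-in for torch.device: (device type, device index).
Device = Tuple[str, int]


def pick_compute_device_cuda_only(
    param_device: Optional[Device],
    cuda_current_device: Callable[[], int],
) -> Device:
    """Model of the failing branch (assumed shape, inferred from the traceback).

    Unless a parameter already lives on a CUDA device, the CUDA runtime is
    queried unconditionally, which raises on a build without CUDA.
    """
    if param_device is not None and param_device[0] == "cuda":
        return param_device
    return ("cuda", cuda_current_device())


def pick_compute_device_agnostic(
    param_device: Optional[Device],
    device_from_device_id: Optional[Device],
    cuda_current_device: Callable[[], int],
) -> Device:
    """Hypothetical alternative: honor an explicit device_id before touching CUDA."""
    if device_from_device_id is not None:
        return device_from_device_id
    return pick_compute_device_cuda_only(param_device, cuda_current_device)


def cuda_unavailable() -> int:
    # Mimics the error that torch.cuda.current_device() raised in the log above.
    raise AssertionError("Torch not compiled with CUDA enabled")


if __name__ == "__main__":
    cpu_param = ("cpu", 0)      # freshly constructed module parameters
    npu_device_id = ("npu", 0)  # what the script passes as device_id

    try:
        pick_compute_device_cuda_only(cpu_param, cuda_unavailable)
    except AssertionError as exc:
        print(f"CUDA-only selection fails: {exc}")

    print(pick_compute_device_agnostic(cpu_param, npu_device_id, cuda_unavailable))
```

In this model the CUDA-only selection never looks at the requested device, so passing `device_id` cannot avoid the assertion. That matches the log, where the error is raised inside the `FSDP(...)` constructor before any training step runs.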