
Ascend/pytorch

LoRA training of Mistral fails

Analysing
Bug-Report
Created on 2023-11-20 15:34

1. Problem description (with error-log context attached):
Training Mistral with FastChat's LoRA script fails with NotImplementedError: Could not run 'npu::npu_format_cast' with arguments from the 'CPU' backend.
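For reference, a minimal sketch of the dispatcher behaviour behind this error (not part of the original report; it assumes torch_npu 2.1 and at least one visible NPU): npu_format_cast is registered only for the NPU (PrivateUse1) backend, so it works on an NPU tensor but raises the same NotImplementedError when handed a CPU tensor.

import torch
import torch_npu  # Ascend adapter; registers the npu:: custom ops

cpu_t = torch.randn(2, 3)   # plain CPU tensor
npu_t = cpu_t.npu()         # same data moved to the NPU

torch_npu.npu_format_cast(npu_t, 2)  # OK: cast to ACL format 2 (ND) on the NPU
torch_npu.npu_format_cast(cpu_t, 2)  # NotImplementedError: 'CPU' backend, as above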

2. Software versions:
-- CANN version: community edition Ascend-cann-toolkit_7.0.0.alpha001_linux-aarch64.run
-- Tensorflow/PyTorch/MindSpore version: torch 2.1.0
-- Python version (e.g., Python 3.7.5): 3.9.18
-- OS version (e.g., Ubuntu 18.04): Ubuntu 20.04
-- NPU: 910B

3. Steps to reproduce:

git clone https://github.com/lm-sys/FastChat.git
cd FastChat
pip install git+https://github.com/microsoft/DeepSpeed.git@master
pip install git+https://github.com/huggingface/accelerate.git@main
deepspeed fastchat/train/train_lora.py \
    --model_name_or_path /opt/big_models/Mistral-7B-v0.1/ \
    --lora_r 8 \
    --lora_alpha 16 \
    --lora_dropout 0.05 \
    --data_path /opt/nlp_data/evol-instruct-chinese-subset.json \
    --output_dir ./checkpoints \
    --num_train_epochs 2 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 100 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --model_max_length 1024 \
    --deepspeed playground/deepspeed_config_s2.json

4. Log output:

root@ascend910b-01:/opt/projects/FastChat# deepspeed fastchat/train/train_lora.py     --model_name_or_path /opt/big_models/Mistral-7B-v0.1/      --lora_r 8     --lora_alpha 16     --lora_dropout 0.05     --data_path /opt/nlp_data/evol-instruct-chinese-subset.json     --output_dir ./checkpoints     --num_train_epochs 2     --per_device_train_batch_size 8     --per_device_eval_batch_size 2     --gradient_accumulation_steps 1     --evaluation_strategy "no"     --save_strategy "steps"     --save_steps 1200     --save_total_limit 100     --learning_rate 2e-5     --weight_decay 0.     --warmup_ratio 0.03     --lr_scheduler_type "cosine"     --logging_steps 1     --model_max_length 1024     --deepspeed playground/deepspeed_config_s2.json
[2023-11-20 06:36:58,072] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to npu (auto detect)
[2023-11-20 06:36:58,938] [WARNING] [runner.py:203:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2023-11-20 06:36:58,938] [INFO] [runner.py:570:main] cmd = /root/miniconda/envs/torch_npu/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMV19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None fastchat/train/train_lora.py --model_name_or_path /opt/big_models/Mistral-7B-v0.1/ --lora_r 8 --lora_alpha 16 --lora_dropout 0.05 --data_path /opt/nlp_data/evol-instruct-chinese-subset.json --output_dir ./checkpoints --num_train_epochs 2 --per_device_train_batch_size 8 --per_device_eval_batch_size 2 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 1200 --save_total_limit 100 --learning_rate 2e-5 --weight_decay 0. --warmup_ratio 0.03 --lr_scheduler_type cosine --logging_steps 1 --model_max_length 1024 --deepspeed playground/deepspeed_config_s2.json
[2023-11-20 06:37:03,366] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to npu (auto detect)
[2023-11-20 06:37:04,240] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0, 1]}
[2023-11-20 06:37:04,240] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=2, node_rank=0
[2023-11-20 06:37:04,241] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0, 1]})
[2023-11-20 06:37:04,241] [INFO] [launch.py:163:main] dist_world_size=2
[2023-11-20 06:37:04,241] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0,1
[2023-11-20 06:37:08,369] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to npu (auto detect)
[2023-11-20 06:37:08,653] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to npu (auto detect)
/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/transformers/deepspeed.py:23: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
  warnings.warn(
[2023-11-20 06:37:09,191] [WARNING] [comm.py:163:init_deepspeed_backend] HCCL backend in DeepSpeed not yet implemented
[2023-11-20 06:37:09,191] [INFO] [comm.py:637:init_distributed] cdb=None
[2023-11-20 06:37:09,191] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend hccl
/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/transformers/deepspeed.py:23: FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations
  warnings.warn(
[2023-11-20 06:37:09,592] [WARNING] [comm.py:163:init_deepspeed_backend] HCCL backend in DeepSpeed not yet implemented
[2023-11-20 06:37:09,592] [INFO] [comm.py:637:init_distributed] cdb=None
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:28<00:00, 14.47s/it]
trainable params: 3,407,872 || all params: 7,245,139,968 || trainable%: 0.04703666202518836
/opt/projects/FastChat/fastchat/train/train.py:235: ResourceWarning: unclosed file <_io.TextIOWrapper name='/opt/nlp_data/evol-instruct-chinese-subset.json' mode='r' encoding='UTF-8'>
  train_json = json.load(open(data_args.data_path, "r"))
Detected kernel version 4.19.90, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Loading checkpoint shards:   0%|                                                                                                                            | 0/2 [00:00<?, ?it/s]Using /root/.cache/torch_extensions/py39_cpu as PyTorch extensions root...
Creating extension directory /root/.cache/torch_extensions/py39_cpu/cpu_adam...
Emitting ninja build file /root/.cache/torch_extensions/py39_cpu/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:23<00:00, 11.68s/it]
/opt/projects/FastChat/fastchat/train/train.py:235: ResourceWarning: unclosed file <_io.TextIOWrapper name='/opt/nlp_data/evol-instruct-chinese-subset.json' mode='r' encoding='UTF-8'>
  train_json = json.load(open(data_args.data_path, "r"))
Using /root/.cache/torch_extensions/py39_cpu as PyTorch extensions root...
[1/3] c++ -MMD -MF cpu_adam.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/usr/local/Ascend/ascend-toolkit/latest/include -I/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch_npu/include -I/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/deepspeed/ops/csrc/includes -isystem /root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch/include -isystem /root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch/include/TH -isystem /root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch/include/THC -isystem /root/miniconda/envs/torch_npu/include/python3.9 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -O3 -std=c++17 -g -Wno-reorder -fopenmp -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -Wl,--disable-new-dtags,--rpath -D__ENABLE_CANN__ -march=native -D__SCALAR__ -L/usr/local/Ascend/ascend-toolkit/latest/lib64 -L/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch_npu/lib -c /root/miniconda/envs/torch_npu/lib/python3.9/site-packages/deepspeed/ops/csrc/adam/cpu_adam.cpp -o cpu_adam.o 
[2/3] c++ -MMD -MF cpu_adam_impl.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/usr/local/Ascend/ascend-toolkit/latest/include -I/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch_npu/include -I/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/deepspeed/ops/csrc/includes -isystem /root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch/include -isystem /root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch/include/TH -isystem /root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch/include/THC -isystem /root/miniconda/envs/torch_npu/include/python3.9 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -O3 -std=c++17 -g -Wno-reorder -fopenmp -fstack-protector-all -Wl,-z,relro,-z,now,-z,noexecstack -Wl,--disable-new-dtags,--rpath -D__ENABLE_CANN__ -march=native -D__SCALAR__ -L/usr/local/Ascend/ascend-toolkit/latest/lib64 -L/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch_npu/lib -c /root/miniconda/envs/torch_npu/lib/python3.9/site-packages/deepspeed/ops/csrc/adam/cpu_adam_impl.cpp -o cpu_adam_impl.o 
[3/3] c++ cpu_adam.o cpu_adam_impl.o -shared -L/usr/local/Ascend/ascend-toolkit/latest/lib64 -lascendcl -L/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch_npu/lib -ltorch_npu -L/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch/lib -lc10 -ltorch_cpu -ltorch -ltorch_python -o cpu_adam.so
Loading extension module cpu_adam...
Time to load cpu_adam op: 68.06045770645142 seconds
Loading extension module cpu_adam...
Time to load cpu_adam op: 33.589106798172 seconds
param_group
param_group
Traceback (most recent call last):
  File "/opt/projects/FastChat/fastchat/train/train_lora.py", line 223, in <module>
    train()
  File "/opt/projects/FastChat/fastchat/train/train_lora.py", line 199, in train
    trainer.train()
  File "/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/transformers/trainer.py", line 1689, in _inner_training_loop
    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
  File "/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/accelerate/accelerator.py", line 1282, in prepare
    result = self._prepare_deepspeed(*args)
  File "/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/accelerate/accelerator.py", line 1663, in _prepare_deepspeed
    engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
  File "/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/deepspeed/__init__.py", line 171, in initialize
    engine = DeepSpeedEngine(args=args,
  File "/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 304, in __init__
    self._configure_optimizer(optimizer, model_parameters)
  File "/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1222, in _configure_optimizer
    self.optimizer = self._configure_zero_optimizer(basic_optimizer)
  File "/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1483, in _configure_zero_optimizer
    optimizer = DeepSpeedZeroOptimizer(
  File "/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 356, in __init__
    self._update_model_bit16_weights(i)
  File "/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 591, in _update_model_bit16_weights
    updated_params = self.unflatten(self.bit16_groups_flat[group_index],
  File "/root/miniconda/envs/torch_npu/lib/python3.9/site-packages/torch/_utils.py", line 534, in _unflatten_dense_tensors
    return torch._C._nn.unflatten_dense_tensors(flat, tensors)
NotImplementedError: Could not run 'npu::npu_format_cast' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'npu::npu_format_cast' is only available for these backends: [PrivateUse1, SparsePrivateUse1, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradPrivateUse1, AutogradMeta, Tracer, AutocastCPU, AutocastCUDA, AutocastPrivateUse1, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

PrivateUse1: registered at /usr1/02/workspace/j_vqN6BFvg/pytorch/torch_npu/csrc/aten/CustomRegisterSchema.cpp:977 [kernel]
SparsePrivateUse1: registered at /usr1/02/workspace/j_vqN6BFvg/pytorch/torch_npu/csrc/aten/VariableFallbackKernel.cpp:70 [backend fallback]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:153 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at /pytorch/aten/src/ATen/FunctionalizeFallbackKernel.cpp:290 [backend fallback]
Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /pytorch/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at /pytorch/aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at /pytorch/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:53 [backend fallback]
AutogradCPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:57 [backend fallback]
AutogradCUDA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:65 [backend fallback]
AutogradXLA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:69 [backend fallback]
AutogradMPS: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:77 [backend fallback]
AutogradXPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:61 [backend fallback]
AutogradHPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:90 [backend fallback]
AutogradLazy: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:73 [backend fallback]
AutogradPrivateUse1: fallthrough registered at /usr1/02/workspace/j_vqN6BFvg/pytorch/torch_npu/csrc/aten/VariableFallbackKernel.cpp:36 [backend fallback]
AutogradMeta: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:81 [backend fallback]
Tracer: registered at /pytorch/torch/csrc/autograd/TraceTypeManual.cpp:296 [backend fallback]
AutocastCPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:382 [backend fallback]
AutocastCUDA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:249 [backend fallback]
AutocastPrivateUse1: fallthrough registered at /usr1/02/workspace/j_vqN6BFvg/pytorch/torch_npu/csrc/aten/AutoCastOps.cpp:19 [backend fallback]
FuncTorchBatched: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:710 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /pytorch/aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:161 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:165 [backend fallback]
PythonDispatcher: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:157 [backend fallback]

[The second worker process exits with an identical traceback and backend-registration listing; the duplicate is omitted.]

/root/miniconda/envs/torch_npu/lib/python3.9/tempfile.py:821: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmp_tfmzq_z'>
  _warnings.warn(warn_message, ResourceWarning)
/root/miniconda/envs/torch_npu/lib/python3.9/tempfile.py:821: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmpglrsmbyf'>
  _warnings.warn(warn_message, ResourceWarning)
[2023-11-20 06:40:35,503] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 2089
[2023-11-20 06:40:35,505] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 2090

Comments (5)


npu_format_cast can only be applied to NPU tensors; it does not work on CPU tensors.

That matches the error, but my understanding is that either this code path should never reach npu_format_cast, or npu_format_cast needs to be adapted to handle CPU tensors.
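For context, a rough sketch of the call site where the traceback ends (this is a reading of the stack, not a verified root-cause analysis): DeepSpeed's ZeRO stage 1/2 optimizer flattens each parameter group into one contiguous buffer and later rebuilds per-parameter views with torch._utils._unflatten_dense_tensors, and it is inside that rebuild that the npu_format_cast dispatch fails.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Stand-ins for the trainable LoRA tensors of one optimizer group.
params = [torch.randn(4, 4), torch.randn(8)]

flat = _flatten_dense_tensors(params)            # one contiguous buffer, as ZeRO builds it
views = _unflatten_dense_tensors(flat, params)   # fine when everything is a plain CPU tensor

# In the failing run the same call is made from
# DeepSpeedZeroOptimizer._update_model_bit16_weights (see the traceback above);
# presumably torch_npu hooks this path to restore NPU storage formats, but the
# tensors involved here are on the CPU, so npu::npu_format_cast has no kernel to run.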


Where does this CPU tensor come from?

DeepSpeed is a library for saving device memory; it offloads some parameters to the CPU for computation.
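That matches the rest of the log: the cpu_adam extension built above is what DeepSpeed compiles when optimizer state is offloaded to the host. As an illustration only (the actual contents of playground/deepspeed_config_s2.json are not reproduced here), a ZeRO stage-2 config with CPU offload typically carries a block like the following, and offload_optimizer.device is the setting that puts those tensors on the CPU:

# Illustrative sketch only; not the real playground/deepspeed_config_s2.json.
zero2_cpu_offload = {
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},  # optimizer state lives in host memory
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "bf16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
}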
