When torchattacks is used to attack a ResNet-18 model, the attack drives accuracy down to 0 on a CUDA server, but on Ascend the accuracy under attack stays above 60%.
The code is as follows:
import torch
import torch_npu
import os
import inspect
import fire
import __main__
import shutil
import pytorch_lightning as pl
from torchmetrics.classification import MulticlassAccuracy
from torch import optim as torch_optim
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.nn.functional as F
from utils import EarlyStopping
# from accelerator_npu import NPUAccelerator
from data_loader import get_data_module
from torchattacks import FGSM, PGD
from model import ResNet18
import accelerator_npu_transfer  # local module (listed below) that patches pytorch_lightning for the NPU
from torch_npu.contrib import transfer_to_npu  # redirects torch.cuda.* calls to the NPU backend
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback  # used by the tune branch in run(); ray 2.3 is in the environment
class TrainingStep(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.model = self.get_model(config['model'])
        self.lr = config['lr']
        self.optimizer = config['optimizer']
        self.automatic_optimization = False
        self.best_val_pgd_acc = 0
        self.test_eps = 0.031
        self.optim = config['optim']
        self._init_metrics()

    def _init_metrics(self):
        self.train_acc = MulticlassAccuracy(num_classes=10)
        self.train_adv_acc = MulticlassAccuracy(num_classes=10)
        self.val_acc = MulticlassAccuracy(num_classes=10)
        self.val_fgsm_acc = MulticlassAccuracy(num_classes=10)
        self.val_pgd_acc = MulticlassAccuracy(num_classes=10)

    def on_fit_start(self):
        # build the attacks once the model has been moved to its device
        self.fgsm_attack = FGSM(self.model)
        self.pgd_attack = PGD(self.model)

    def training_step(self, batch, batch_idx):
        # manual optimization: one optimizer step per batch
        optimizer = self.optimizers()
        x, y = batch
        y_pred = self.model(x)
        loss = F.cross_entropy(y_pred, y)
        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        self.train_acc.to('cpu')
        self.train_acc(y_pred.cpu(), y.cpu())
        self.log('train_acc', self.train_acc, on_epoch=True)

    def on_validation_model_eval(self, *args, **kwargs):
        # keep gradients enabled during validation so the attacks can backprop
        super().on_validation_model_eval(*args, **kwargs)
        torch.set_grad_enabled(True)

    def training_epoch_end(self, outputs):
        self.lr_schedulers().step()

    def configure_optimizers(self):
        if self.optim == 'adam':
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        elif self.optim == 'sgd':
            optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9, weight_decay=2e-4)
        lr_scheduler = torch_optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [lr_scheduler]

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.model(x)
        val_loss_classification = F.cross_entropy(y_pred, y)
        self.val_acc.to('cpu')
        self.val_acc(y_pred.cpu(), y.cpu())
        self.log('val_loss_classification', val_loss_classification, on_epoch=True)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True)
        # adversarial evaluation with FGSM and PGD
        self.attack_validation(batch, 'fgsm')
        self.attack_validation(batch, 'pgd')

    def attack_validation(self, batch, attack):
        x, y = batch
        attack_method = getattr(self, f'{attack}_attack')
        y_pred_adv = self.model(attack_method(*batch))
        acc = getattr(self, f'val_{attack}_acc').to('cpu')
        acc(y_pred_adv.to('cpu'), y.to('cpu'))
        self.log(f'val_{attack}_acc', acc, on_epoch=True)

    def get_model(self, model_name):
        if model_name == 'resnet18':
            return ResNet18()
def run(config):
    datamodule, config['input_channels'], config['num_classes'] = get_data_module(config)
    if not config['tune']:
        checkpoint_callback = ModelCheckpoint('checkpoint', "checkpoint-{epoch:03d}-{train_acc:.3f}", save_last=True,
                                              monitor="train_acc", save_top_k=3, save_on_train_epoch_end=False, mode="max")
    else:
        checkpoint_callback = TuneReportCheckpointCallback(
            metrics={
                "train_acc": "train_acc",
            })
    trainer = pl.Trainer(
        accelerator='gpu',
        devices=config['devices'],
        enable_progress_bar=not config['tune'],
        callbacks=[
            EarlyStopping(monitor="val_loss_classification", patience=config['max_epochs'], strict=False),
            # EarlyStopping(monitor="best_val_pgd_acc", patience=config['max_epochs'], threshold_patience=15, stopping_threshold=0.2, strict=False),
            checkpoint_callback
        ],
        max_epochs=config['max_epochs'],
        fast_dev_run=config['fast_dev_run'],
        # limit_train_batches=2,
        # limit_val_batches=2,
        num_sanity_val_steps=0,
        sync_batchnorm=True,
        precision=32,
    )
    ckpt_path = os.path.join(config['checkpoint_dir'], 'checkpoint') if config['checkpoint_dir'] else None
    trainer.fit(TrainingStep(config), datamodule=datamodule, ckpt_path=ckpt_path)
def main(
    model : str = 'resnet18',
    dataset : str = 'cifar10',
    batch_size : int = 128,
    lr : float = 0.0001,
    optimizer : str = 'sgd',
    seed : int = 0,
    precision : int = 32,
    optim : str = 'adam',
    eps : float = 0.031,
    beam_size : int = 3,
    max_epochs = 200,
    table_name = None,
    project_path = None,
    refresh = False,
    fast_dev_run = False,
    checkpoint_dir = None,
    tune = False,  # added here: run() reads config['tune'], but the original signature did not define it
):
    devices = [1]
    config = locals()
    if table_name is None:
        config['table_name'] = os.path.splitext(os.path.basename(__main__.__file__))[0]
    if project_path is None:
        config['project_path'] = os.path.dirname(os.path.realpath(__main__.__file__))
    if refresh and os.path.exists(config['output_path']):
        shutil.rmtree(config['output_path'])
    config['function_signature'] = inspect.signature(main)
    pl.seed_everything(config['seed'])
    run(config)


if __name__ == '__main__':
    fire.Fire(main)
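To narrow down where the CUDA and Ascend runs diverge, a minimal standalone check along the following lines may help. This is only a sketch: check_one_batch is a hypothetical helper, the random batch is a stand-in for a real CIFAR-10 batch, and the model should ideally be loaded from the trained checkpoint. It verifies that input gradients actually flow on the NPU and that FGSM produces a perturbation of roughly size eps:

import torch
import torch_npu
import torch.nn.functional as F
from torchattacks import FGSM
from model import ResNet18


def check_one_batch(model, x, y, eps=0.031):
    """Check that input gradients flow and that FGSM really perturbs the batch."""
    model.eval()
    x = x.clone().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    grad, = torch.autograd.grad(loss, x)
    print('max |dL/dx|:', grad.abs().max().item())          # ~0 here would explain a weak attack

    adv = FGSM(model, eps=eps)(x.detach(), y)
    print('max |adv - x|:', (adv - x).abs().max().item())   # should be close to eps
    clean_acc = (model(x).argmax(1) == y).float().mean().item()
    adv_acc = (model(adv).argmax(1) == y).float().mean().item()
    print('clean acc / adv acc:', clean_acc, adv_acc)


if __name__ == '__main__':
    device = 'npu:0' if torch.npu.is_available() else 'cpu'
    model = ResNet18().to(device)                   # ideally: load the trained checkpoint here
    x = torch.rand(16, 3, 32, 32, device=device)    # stand-in batch; replace with real CIFAR-10 data
    y = torch.randint(0, 10, (16,), device=device)
    check_one_batch(model, x, y)

If max |dL/dx| comes out as 0, or max |adv - x| is far from eps, on the NPU while the same script behaves normally on CUDA, the problem would be in the backward pass through the NPU model rather than in the Lightning wiring.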
The machine is:
Ascend 910, Kunpeng CPU.
The problem occurs with both PyTorch 1.8 and 1.11.
Environment:
# Name Version Build Channel
_openmp_mutex 4.5 1_gnu conda-forge
absl-py 0.13.0 <pip>
addict 2.4.0 <pip>
aiohttp 3.8.4 <pip>
aiosignal 1.3.1 <pip>
albumentations 0.4.5 <pip>
antlr4-python3-runtime 4.9.3 <pip>
asgiref 3.5.0 <pip>
astor 0.8.1 <pip>
asttokens 2.0.5 <pip>
async-timeout 4.0.2 <pip>
asynctest 0.13.0 <pip>
attrs 19.3.0 <pip>
audioread 3.0.0 <pip>
autopep8 2.0.2 <pip>
backcall 0.2.0 <pip>
boto3 1.12.22 <pip>
botocore 1.15.49 <pip>
ca-certificates 2021.10.8 h4fd8a4c_0 conda-forge
cachetools 5.3.0 <pip>
certifi 2021.10.8 <pip>
cffi 1.14.0 <pip>
chardet 3.0.4 <pip>
click 8.0.4 <pip>
cloudpickle 1.3.0 <pip>
cycler 0.11.0 <pip>
Cython 0.29.14 <pip>
dask 2.18.1 <pip>
decorator 4.4.1 <pip>
dill 0.3.6 <pip>
distlib 0.3.6 <pip>
Django 3.2.12 <pip>
docutils 0.15.2 <pip>
easydict 1.9 <pip>
einops 0.6.0 <pip>
entrypoints 0.3 <pip>
esdk-obs-python 3.20.1 <pip>
et-xmlfile 1.1.0 <pip>
filelock 3.10.0 <pip>
fire 0.5.0 <pip>
flake8 4.0.1 pyhd3eb1b0_1 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
Flask 2.0.1 <pip>
flit-core 3.6.0 pyhd3eb1b0_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
frozenlist 1.3.3 <pip>
fsspec 2023.1.0 <pip>
future 0.18.2.post20200723173923 <pip>
gast 0.2.2 <pip>
google-auth 2.16.2 <pip>
google-auth-oauthlib 0.4.6 <pip>
google-pasta 0.2.0 <pip>
grpcio 1.44.0 <pip>
grpcio-tools 1.26.0 <pip>
gunicorn 20.0.4 <pip>
h5py 2.10.0 <pip>
huaweicloud-sdk-python-modelarts-dataset 0.1.5 <pip>
idna 2.10 <pip>
image 1.5.28 <pip>
imageio 2.9.0 <pip>
imgaug 0.2.6 <pip>
importlib-metadata 6.0.0 <pip>
importlib-metadata 4.11.3 py37hd43f75c_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
importlib-resources 5.12.0 <pip>
inflect 6.0.2 <pip>
ipykernel 5.3.4 <pip>
ipython 7.31.1 <pip>
ipython-genutils 0.2.0 <pip>
itsdangerous 2.1.1 <pip>
jdcal 1.4.1 <pip>
jedi 0.18.1 <pip>
Jinja2 3.0.3 <pip>
jmespath 0.10.0 <pip>
joblib 1.1.0 <pip>
jsonschema 4.17.3 <pip>
jupyter-client 7.1.2 <pip>
jupyter-core 4.9.1 <pip>
Keras 2.3.1 <pip>
Keras-Applications 1.0.8 <pip>
Keras-Preprocessing 1.1.2 <pip>
kfac 0.2.0 <pip>
kiwisolver 1.1.0 <pip>
lazy-import 0.2.2 <pip>
ld_impl_linux-aarch64 2.36.1 h02ad14f_2 conda-forge
libffi 3.4.2 h3557bc0_5 conda-forge
libgcc-ng 11.2.0 hf1cc4e7_12 conda-forge
libgomp 11.2.0 hf1cc4e7_12 conda-forge
libnsl 2.0.0 hf897c2e_0 conda-forge
librosa 0.8.0 <pip>
libstdcxx-ng 11.2.0 h0d0a5bb_12 conda-forge
libzlib 1.2.11 hb9de7d4_1013 conda-forge
lightning-utilities 0.8.0 <pip>
llvmlite 0.39.1 <pip>
lxml 4.4.2 <pip>
Markdown 3.3.6 <pip>
MarkupSafe 2.1.1 <pip>
marshmallow 3.15.0 <pip>
matplotlib 3.2.1 <pip>
matplotlib-inline 0.1.3 <pip>
mccabe 0.7.0 pyhd3eb1b0_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
mccabe 0.6.1 <pip>
mindinsight 1.6.1 <pip>
mindspore-ascend 1.7.0 <pip>
mindx-elastic 0.0.1 <pip>
mmcv 0.2.14 <pip>
moxing-framework 2.0.1.rc0.ffd1c0c8 <pip>
mpmath 1.2.1 <pip>
msgpack 1.0.5 <pip>
multidict 6.0.4 <pip>
multiprocess 0.70.14 <pip>
ncurses 6.3 h01db608_0 conda-forge
nest-asyncio 1.5.4 <pip>
networkx 2.6.3 <pip>
numba 0.56.4 <pip>
numexpr 2.7.1 <pip>
numpy 1.21.6 <pip>
oauthlib 3.2.2 <pip>
omegaconf 2.3.0 <pip>
opencv-python 4.2.0.34 <pip>
openpyxl 3.0.3 <pip>
openssl 3.0.0 hf897c2e_1 conda-forge
opt-einsum 3.3.0 <pip>
packaging 21.3 <pip>
pandas 1.1.3 <pip>
parso 0.8.3 <pip>
pathlib2 2.3.7.post1 <pip>
pathos 0.3.0 <pip>
pexpect 4.8.0 <pip>
pickleshare 0.7.5 <pip>
Pillow 7.0.0 <pip>
pip 21.3.1 pyhd8ed1ab_0 conda-forge
pip 23.0.1 <pip>
pkgutil_resolve_name 1.3.10 <pip>
platformdirs 3.1.1 <pip>
pooch 1.7.0 <pip>
pox 0.3.2 <pip>
ppft 1.7.6.6 <pip>
prometheus-client 0.8.0 <pip>
prompt-toolkit 3.0.24 <pip>
protobuf 3.19.4 <pip>
psutil 5.7.0 <pip>
ptyprocess 0.7.0 <pip>
pyasn1 0.4.8 <pip>
pyasn1-modules 0.2.8 <pip>
pycocotools 2.0.0 <pip>
pycodestyle 2.8.0 pyhd3eb1b0_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
pycodestyle 2.10.0 <pip>
pycparser 2.21 <pip>
pydantic 1.10.6 <pip>
pyflakes 2.4.0 pyhd3eb1b0_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
Pygments 2.11.2 <pip>
pyparsing 3.0.7 <pip>
pyrsistent 0.19.3 <pip>
python 3.7.10 h47f6e27_104_cpython conda-forge
python-dateutil 2.8.2 <pip>
python_abi 3.7 2_cp37m conda-forge
pytorch-lightning 1.9.0 <pip>
pytz 2021.3 <pip>
PyWavelets 1.1.1 <pip>
PyYAML 6.0 <pip>
pyzmq 22.3.0 <pip>
ray 2.3.0 <pip>
readline 8.1 h1a49cc3_0 conda-forge
requests 2.23.0 <pip>
requests-oauthlib 1.3.1 <pip>
resampy 0.4.2 <pip>
rsa 4.9 <pip>
s3transfer 0.3.7 <pip>
scikit-image 0.17.2 <pip>
scikit-learn 0.24.0 <pip>
scipy 1.5.4 <pip>
setuptools 60.5.0 py37hd9ded2f_0 conda-forge
Shapely 1.8.1.post1 <pip>
six 1.16.0 <pip>
soundfile 0.12.1 <pip>
sqlite 3.37.0 hc164836_0 conda-forge
sqlparse 0.4.2 <pip>
sympy 1.4 <pip>
tables 3.6.1 <pip>
tensorboard 2.11.2 <pip>
tensorboard-data-server 0.6.1 <pip>
tensorboard-plugin-wit 1.8.1 <pip>
tensorboardX 2.6 <pip>
tensorflow 1.15.0 <pip>
tensorflow-estimator 1.15.1 <pip>
tensorflow-probability 0.10.1 <pip>
termcolor 1.1.0 <pip>
terminaltables 3.1.0 <pip>
threadpoolctl 3.1.0 <pip>
tifffile 2021.11.2 <pip>
tk 8.6.11 hd8af866_1 conda-forge
toml 0.10.1 <pip>
tomli 2.0.1 <pip>
torch 1.11.0 <pip>
torch-npu 1.11.0rc2 <pip>
torchattacks 3.3.0 <pip>
torchaudio 0.11.0 <pip>
torchmetrics 0.11.4 <pip>
torchvision 0.11.1 <pip>
tornado 6.1 <pip>
tqdm 4.65.0 <pip>
traitlets 5.1.1 <pip>
treelib 1.6.1 <pip>
typing_extensions 4.4.0 py37hd43f75c_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
typing_extensions 4.5.0 <pip>
umap-learn-modified 0.3.8 <pip>
Unidecode 1.3.6 <pip>
urllib3 1.25.11 <pip>
virtualenv 20.21.0 <pip>
wcwidth 0.2.5 <pip>
Werkzeug 2.0.3 <pip>
wheel 0.37.1 pyhd8ed1ab_0 conda-forge
wrapt 1.14.0 <pip>
XlsxWriter 3.0.3 <pip>
xmltodict 0.12.0 <pip>
xz 5.2.5 h6dd45c4_1 conda-forge
yapf 0.32.0 <pip>
yarl 1.8.2 <pip>
zipp 3.7.0 <pip>
zipp 3.11.0 py37hd43f75c_0 https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
zlib 1.2.11 hb9de7d4_1013 conda-forge
Here accelerator_npu_transfer patches a few pytorch_lightning functions so that they work on the NPU; its code is as follows:
import logging
import os
from typing import Any, Optional, Union

import torch
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from torch.nn import Module
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from lightning_fabric.accelerators.cuda import _check_cuda_matmul_precision


def setup_device(self, device: torch.device) -> None:
    """
    Raises:
        MisconfigurationException:
            If the selected device is not an NPU.
    """
    if device.type != "npu":
        raise MisconfigurationException(f"Device should be NPU, got {device} instead")
    _check_cuda_matmul_precision(device)
    torch.cuda.set_device(device)  # redirected to the NPU by transfer_to_npu


CUDAAccelerator.setup_device = setup_device


def to(self, *args: Any, **kwargs: Any):  # type: ignore[valid-type]
    """See :meth:`torch.nn.Module.to`."""
    # this converts a `str` device to `torch.device`
    device = torch.device(*args, **kwargs)
    __update_properties(self, device=device, dtype=None)
    return super(_DeviceDtypeModuleMixin, self).to(*args, **kwargs)


def __update_properties(
    self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:
    def apply_fn(module: Union[_DeviceDtypeModuleMixin, Module]) -> None:
        if not isinstance(module, _DeviceDtypeModuleMixin):
            return
        if device is not None:
            module._device = device
        if dtype is not None:
            module._dtype = dtype

    self.apply(apply_fn)


_DeviceDtypeModuleMixin.to = to
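These patches take effect purely as an import side effect, so accelerator_npu_transfer has to be imported before the Trainer is created, as in the training script above. A quick sanity check that the patched to() is active might look like the sketch below; Dummy is a throwaway module used only for illustration:

import torch
import torch_npu
import accelerator_npu_transfer                  # applying the patches is a side effect of this import
from torch_npu.contrib import transfer_to_npu    # maps torch.cuda.* calls onto the NPU
import pytorch_lightning as pl


class Dummy(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)


m = Dummy().to('npu:0')
# with the patched _DeviceDtypeModuleMixin.to, both should report npu:0
print(m.device, next(m.parameters()).device)

If either of the two prints does not show npu:0, the patches were not in effect when the module was moved.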