75 Star 582 Fork 1.1K

Ascend/pytorch

PyTorch 在 Ascend NPU 上的运行结果与 CUDA 不一致

WIP
缺陷
创建于  
2023-03-22 15:33

使用 torchattacks 攻击 ResNet-18 模型时,在 CUDA 服务器上模型在对抗样本上的准确率可以降到 0,但在 Ascend 上仍有 60% 多的准确率。

代码如下:

import torch
import torch_npu
import os
import inspect
import fire
import __main__
import shutil
import pytorch_lightning as pl
from torchmetrics.classification import MulticlassAccuracy
from torch import optim as torch_optim
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.nn.functional as F
from utils import EarlyStopping
# from accelerator_npu import NPUAccelerator
from data_loader import get_data_module
from torchattacks import FGSM, PGD
from model import ResNet18
import accelerator_npu_transfer
from torch_npu.contrib import transfer_to_npu


class TrainingStep(pl.LightningModule):
    """LightningModule that trains a classifier and validates it under FGSM/PGD attacks.

    Optimization is manual (``automatic_optimization = False``): ``training_step``
    zeroes the gradients, back-propagates and steps the optimizer itself, and the
    learning-rate scheduler is stepped once per epoch in ``training_epoch_end``.
    All accuracy metrics are moved to the CPU before being updated.
    """

    def __init__(self, config):
        """Build the model and the metrics.

        Args:
            config: Mapping providing the keys ``'model'``, ``'lr'``,
                ``'optimizer'`` and ``'optim'``.

        Raises:
            ValueError: If ``config['model']`` names an unsupported model.
        """
        super().__init__()
        self.model = self.get_model(config['model'])
        self.lr = config['lr']
        # NOTE(review): stored but never read in this class; the optimizer is
        # actually selected by ``self.optim`` in ``configure_optimizers``.
        self.optimizer = config['optimizer']
        self.automatic_optimization = False
        self.best_val_pgd_acc = 0
        # NOTE(review): not forwarded to the attacks created in ``on_fit_start``,
        # which are built with the library defaults -- confirm this is intended.
        self.test_eps = 0.031
        self.optim = config['optim']
        self._init_metrics()

    def _init_metrics(self):
        """Create one 10-class accuracy metric per logged quantity."""
        self.train_acc = MulticlassAccuracy(num_classes=10)
        self.train_adv_acc = MulticlassAccuracy(num_classes=10)
        self.val_acc = MulticlassAccuracy(num_classes=10)
        self.val_fgsm_acc = MulticlassAccuracy(num_classes=10)
        self.val_pgd_acc = MulticlassAccuracy(num_classes=10)

    def on_fit_start(self):
        """Create the FGSM and PGD attack objects around ``self.model``."""
        self.fgsm_attack = FGSM(self.model)
        self.pgd_attack = PGD(self.model)

    def training_step(self, batch, batch_idx):
        """Run one manually optimized cross-entropy step and log ``train_acc``."""
        optimizer = self.optimizers()

        x, y = batch

        y_pred = self.model(x)

        loss = F.cross_entropy(y_pred, y)

        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()

        # The metric is updated on the CPU with CPU copies of the tensors.
        self.train_acc.to('cpu')
        self.train_acc(y_pred.cpu(), y.cpu())

        self.log('train_acc', self.train_acc, on_epoch=True)

    def on_validation_model_eval(self, *args, **kwargs):
        """Switch to eval mode as usual, then re-enable gradient tracking.

        Presumably needed so that the attacks run from ``validation_step`` can
        compute gradients during validation.
        """
        super().on_validation_model_eval(*args, **kwargs)
        torch.set_grad_enabled(True)

    def training_epoch_end(self, outputs):
        """Advance the learning-rate scheduler once per training epoch."""
        self.lr_schedulers().step()

    def configure_optimizers(self):
        """Return the optimizer selected by ``self.optim`` and a MultiStepLR scheduler.

        Raises:
            ValueError: If ``self.optim`` is neither ``'adam'`` nor ``'sgd'``.
        """
        if self.optim == 'adam':
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        elif self.optim == 'sgd':
            optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9, weight_decay=2e-4)
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``optimizer``.
            raise ValueError(f"Unsupported optim {self.optim!r}; expected 'adam' or 'sgd'")

        lr_scheduler = torch_optim.lr_scheduler.MultiStepLR(
            optimizer=optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [lr_scheduler]

    def validation_step(self, batch, batch_idx):
        """Log clean loss/accuracy, then accuracy under the FGSM and PGD attacks."""
        x, y = batch
        y_pred = self.model(x)
        val_loss_classification = F.cross_entropy(y_pred, y)

        self.val_acc.to('cpu')
        self.val_acc(y_pred.cpu(), y.cpu())
        self.log('val_loss_classification', val_loss_classification, on_epoch=True)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True)

        self.attack_validation(batch, 'fgsm')
        self.attack_validation(batch, 'pgd')

    def attack_validation(self, batch, attack):
        """Attack ``batch`` and log the model's accuracy on the adversarial inputs.

        Args:
            batch: ``(inputs, labels)`` pair.
            attack: Attack name; the attack object is looked up as
                ``self.<attack>_attack`` and the metric as ``self.val_<attack>_acc``.
        """
        x, y = batch

        attack_method = getattr(self, f'{attack}_attack')
        y_pred_adv = self.model(attack_method(x, y))

        acc = getattr(self, f'val_{attack}_acc').to('cpu')
        acc(y_pred_adv.to('cpu'), y.to('cpu'))
        self.log(f'val_{attack}_acc', acc, on_epoch=True)

    def get_model(self, model_name):
        """Instantiate the network named ``model_name``.

        Raises:
            ValueError: If ``model_name`` is not a supported model; previously
                ``None`` was returned silently.
        """
        if model_name == 'resnet18':
            return ResNet18()
        raise ValueError(f"Unsupported model {model_name!r}; expected 'resnet18'")


def run(config):
    """Build the data module, callbacks and trainer from ``config`` and fit the model.

    Args:
        config: Mutable mapping of run settings. Reads ``'devices'``,
            ``'max_epochs'``, ``'fast_dev_run'`` and ``'checkpoint_dir'``, plus an
            optional ``'tune'`` flag (treated as ``False`` when absent). The keys
            ``'input_channels'`` and ``'num_classes'`` are written from the values
            returned by ``get_data_module``.
    """
    datamodule, config['input_channels'], config['num_classes'] = get_data_module(config)

    # Read the flag after get_data_module so any value it sets is honoured;
    # a missing key no longer raises KeyError.
    tune = config.get('tune', False)

    if not tune:
        checkpoint_callback = ModelCheckpoint(
            dirpath='checkpoint',
            filename="checkpoint-{epoch:03d}-{train_acc:.3f}",
            save_last=True,
            monitor="train_acc",
            save_top_k=3,
            save_on_train_epoch_end=False,
            mode="max",
        )
    else:
        # NOTE(review): TuneReportCheckpointCallback is not imported in the
        # visible part of this file; this branch raises NameError unless the
        # name is provided elsewhere.
        checkpoint_callback = TuneReportCheckpointCallback(
            metrics={
                "train_acc": "train_acc",
            })
    trainer = pl.Trainer(
        accelerator='gpu',
        devices=config['devices'],
        enable_progress_bar=not tune,
        callbacks=[
            EarlyStopping(monitor="val_loss_classification", patience=config['max_epochs'], strict=False),
            #    EarlyStopping(monitor="best_val_pgd_acc", patience=config['max_epochs'], threshold_patience=15, stopping_threshold=0.2, strict=False),
            checkpoint_callback
        ],
        max_epochs=config['max_epochs'],
        fast_dev_run=config['fast_dev_run'],
        # limit_train_batches=2,
        # limit_val_batches=2,
        num_sanity_val_steps=0,
        sync_batchnorm=True,
        precision=32,
    )
    # Resume from <checkpoint_dir>/checkpoint only when a directory was given.
    ckpt_path = os.path.join(config['checkpoint_dir'], 'checkpoint') if config['checkpoint_dir'] else None

    trainer.fit(TrainingStep(config), datamodule=datamodule, ckpt_path=ckpt_path)



def main(
    model                 : str   = 'resnet18',
    dataset               : str   = 'cifar10',
    batch_size            : int   = 128,
    lr                    : float = 0.0001,
    optimizer             : str   = 'sgd',
    seed                  : int   = 0,
    precision             : int   = 32,
    optim                 : str   = 'adam',
    eps                   : float = 0.031,
    beam_size             : int   = 3,
    max_epochs                    = 200,
    table_name                    = None,
    project_path                  = None,
    refresh                       = False,
    fast_dev_run                  = False,
    checkpoint_dir                = None,
):
    """Command-line entry point: collect the arguments into a config dict and run training.

    Every parameter (plus the local ``devices`` list) becomes a key of the
    config passed to ``run``. ``table_name`` and ``project_path`` default to the
    stem and the directory of the ``__main__`` script. When ``refresh`` is true
    and the config holds an existing ``'output_path'``, that directory is removed.
    """
    devices = [1]
    # Snapshot the locals: the mapping returned by locals() is tied to the
    # frame and may be re-synchronised, so mutating it directly is fragile.
    config = dict(locals())
    if table_name is None:
        config['table_name'] = os.path.splitext(os.path.basename(__main__.__file__))[0]
    if project_path is None:
        config['project_path'] = os.path.dirname(os.path.realpath(__main__.__file__))
    # 'output_path' is not set anywhere in this function, so look it up
    # defensively instead of raising KeyError whenever refresh is requested.
    output_path = config.get('output_path')
    if refresh and output_path is not None and os.path.exists(output_path):
        shutil.rmtree(output_path)
    config['function_signature'] = inspect.signature(main)
    pl.seed_everything(config['seed'])
    run(config)


# Script entry point: expose ``main``'s parameters as command-line flags via python-fire.
if __name__ == '__main__':
    fire.Fire(main)

机器为:
Ascend 910,鲲鹏处理器。

pytorch 1.8和1.11版本都有这个问题。

环境:

# Name                    Version                   Build  Channel
_openmp_mutex             4.5                       1_gnu    conda-forge
absl-py                   0.13.0                    <pip>
addict                    2.4.0                     <pip>
aiohttp                   3.8.4                     <pip>
aiosignal                 1.3.1                     <pip>
albumentations            0.4.5                     <pip>
antlr4-python3-runtime    4.9.3                     <pip>
asgiref                   3.5.0                     <pip>
astor                     0.8.1                     <pip>
asttokens                 2.0.5                     <pip>
async-timeout             4.0.2                     <pip>
asynctest                 0.13.0                    <pip>
attrs                     19.3.0                    <pip>
audioread                 3.0.0                     <pip>
autopep8                  2.0.2                     <pip>
backcall                  0.2.0                     <pip>
boto3                     1.12.22                   <pip>
botocore                  1.15.49                   <pip>
ca-certificates           2021.10.8            h4fd8a4c_0    conda-forge
cachetools                5.3.0                     <pip>
certifi                   2021.10.8                 <pip>
cffi                      1.14.0                    <pip>
chardet                   3.0.4                     <pip>
click                     8.0.4                     <pip>
cloudpickle               1.3.0                     <pip>
cycler                    0.11.0                    <pip>
Cython                    0.29.14                   <pip>
dask                      2.18.1                    <pip>
decorator                 4.4.1                     <pip>
dill                      0.3.6                     <pip>
distlib                   0.3.6                     <pip>
Django                    3.2.12                    <pip>
docutils                  0.15.2                    <pip>
easydict                  1.9                       <pip>
einops                    0.6.0                     <pip>
entrypoints               0.3                       <pip>
esdk-obs-python           3.20.1                    <pip>
et-xmlfile                1.1.0                     <pip>
filelock                  3.10.0                    <pip>
fire                      0.5.0                     <pip>
flake8                    4.0.1              pyhd3eb1b0_1    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
Flask                     2.0.1                     <pip>
flit-core                 3.6.0              pyhd3eb1b0_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
frozenlist                1.3.3                     <pip>
fsspec                    2023.1.0                  <pip>
future                    0.18.2.post20200723173923           <pip>
gast                      0.2.2                     <pip>
google-auth               2.16.2                    <pip>
google-auth-oauthlib      0.4.6                     <pip>
google-pasta              0.2.0                     <pip>
grpcio                    1.44.0                    <pip>
grpcio-tools              1.26.0                    <pip>
gunicorn                  20.0.4                    <pip>
h5py                      2.10.0                    <pip>
huaweicloud-sdk-python-modelarts-dataset 0.1.5                     <pip>
idna                      2.10                      <pip>
image                     1.5.28                    <pip>
imageio                   2.9.0                     <pip>
imgaug                    0.2.6                     <pip>
importlib-metadata        6.0.0                     <pip>
importlib-metadata        4.11.3           py37hd43f75c_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
importlib-resources       5.12.0                    <pip>
inflect                   6.0.2                     <pip>
ipykernel                 5.3.4                     <pip>
ipython                   7.31.1                    <pip>
ipython-genutils          0.2.0                     <pip>
itsdangerous              2.1.1                     <pip>
jdcal                     1.4.1                     <pip>
jedi                      0.18.1                    <pip>
Jinja2                    3.0.3                     <pip>
jmespath                  0.10.0                    <pip>
joblib                    1.1.0                     <pip>
jsonschema                4.17.3                    <pip>
jupyter-client            7.1.2                     <pip>
jupyter-core              4.9.1                     <pip>
Keras                     2.3.1                     <pip>
Keras-Applications        1.0.8                     <pip>
Keras-Preprocessing       1.1.2                     <pip>
kfac                      0.2.0                     <pip>
kiwisolver                1.1.0                     <pip>
lazy-import               0.2.2                     <pip>
ld_impl_linux-aarch64     2.36.1               h02ad14f_2    conda-forge
libffi                    3.4.2                h3557bc0_5    conda-forge
libgcc-ng                 11.2.0              hf1cc4e7_12    conda-forge
libgomp                   11.2.0              hf1cc4e7_12    conda-forge
libnsl                    2.0.0                hf897c2e_0    conda-forge
librosa                   0.8.0                     <pip>
libstdcxx-ng              11.2.0              h0d0a5bb_12    conda-forge
libzlib                   1.2.11            hb9de7d4_1013    conda-forge
lightning-utilities       0.8.0                     <pip>
llvmlite                  0.39.1                    <pip>
lxml                      4.4.2                     <pip>
Markdown                  3.3.6                     <pip>
MarkupSafe                2.1.1                     <pip>
marshmallow               3.15.0                    <pip>
matplotlib                3.2.1                     <pip>
matplotlib-inline         0.1.3                     <pip>
mccabe                    0.7.0              pyhd3eb1b0_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
mccabe                    0.6.1                     <pip>
mindinsight               1.6.1                     <pip>
mindspore-ascend          1.7.0                     <pip>
mindx-elastic             0.0.1                     <pip>
mmcv                      0.2.14                    <pip>
moxing-framework          2.0.1.rc0.ffd1c0c8           <pip>
mpmath                    1.2.1                     <pip>
msgpack                   1.0.5                     <pip>
multidict                 6.0.4                     <pip>
multiprocess              0.70.14                   <pip>
ncurses                   6.3                  h01db608_0    conda-forge
nest-asyncio              1.5.4                     <pip>
networkx                  2.6.3                     <pip>
numba                     0.56.4                    <pip>
numexpr                   2.7.1                     <pip>
numpy                     1.21.6                    <pip>
oauthlib                  3.2.2                     <pip>
omegaconf                 2.3.0                     <pip>
opencv-python             4.2.0.34                  <pip>
openpyxl                  3.0.3                     <pip>
openssl                   3.0.0                hf897c2e_1    conda-forge
opt-einsum                3.3.0                     <pip>
packaging                 21.3                      <pip>
pandas                    1.1.3                     <pip>
parso                     0.8.3                     <pip>
pathlib2                  2.3.7.post1               <pip>
pathos                    0.3.0                     <pip>
pexpect                   4.8.0                     <pip>
pickleshare               0.7.5                     <pip>
Pillow                    7.0.0                     <pip>
pip                       21.3.1             pyhd8ed1ab_0    conda-forge
pip                       23.0.1                    <pip>
pkgutil_resolve_name      1.3.10                    <pip>
platformdirs              3.1.1                     <pip>
pooch                     1.7.0                     <pip>
pox                       0.3.2                     <pip>
ppft                      1.7.6.6                   <pip>
prometheus-client         0.8.0                     <pip>
prompt-toolkit            3.0.24                    <pip>
protobuf                  3.19.4                    <pip>
psutil                    5.7.0                     <pip>
ptyprocess                0.7.0                     <pip>
pyasn1                    0.4.8                     <pip>
pyasn1-modules            0.2.8                     <pip>
pycocotools               2.0.0                     <pip>
pycodestyle               2.8.0              pyhd3eb1b0_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
pycodestyle               2.10.0                    <pip>
pycparser                 2.21                      <pip>
pydantic                  1.10.6                    <pip>
pyflakes                  2.4.0              pyhd3eb1b0_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
Pygments                  2.11.2                    <pip>
pyparsing                 3.0.7                     <pip>
pyrsistent                0.19.3                    <pip>
python                    3.7.10          h47f6e27_104_cpython    conda-forge
python-dateutil           2.8.2                     <pip>
python_abi                3.7                     2_cp37m    conda-forge
pytorch-lightning         1.9.0                     <pip>
pytz                      2021.3                    <pip>
PyWavelets                1.1.1                     <pip>
PyYAML                    6.0                       <pip>
pyzmq                     22.3.0                    <pip>
ray                       2.3.0                     <pip>
readline                  8.1                  h1a49cc3_0    conda-forge
requests                  2.23.0                    <pip>
requests-oauthlib         1.3.1                     <pip>
resampy                   0.4.2                     <pip>
rsa                       4.9                       <pip>
s3transfer                0.3.7                     <pip>
scikit-image              0.17.2                    <pip>
scikit-learn              0.24.0                    <pip>
scipy                     1.5.4                     <pip>
setuptools                60.5.0           py37hd9ded2f_0    conda-forge
Shapely                   1.8.1.post1               <pip>
six                       1.16.0                    <pip>
soundfile                 0.12.1                    <pip>
sqlite                    3.37.0               hc164836_0    conda-forge
sqlparse                  0.4.2                     <pip>
sympy                     1.4                       <pip>
tables                    3.6.1                     <pip>
tensorboard               2.11.2                    <pip>
tensorboard-data-server   0.6.1                     <pip>
tensorboard-plugin-wit    1.8.1                     <pip>
tensorboardX              2.6                       <pip>
tensorflow                1.15.0                    <pip>
tensorflow-estimator      1.15.1                    <pip>
tensorflow-probability    0.10.1                    <pip>
termcolor                 1.1.0                     <pip>
terminaltables            3.1.0                     <pip>
threadpoolctl             3.1.0                     <pip>
tifffile                  2021.11.2                 <pip>
tk                        8.6.11               hd8af866_1    conda-forge
toml                      0.10.1                    <pip>
tomli                     2.0.1                     <pip>
torch                     1.11.0                    <pip>
torch-npu                 1.11.0rc2                 <pip>
torchattacks              3.3.0                     <pip>
torchaudio                0.11.0                    <pip>
torchmetrics              0.11.4                    <pip>
torchvision               0.11.1                    <pip>
tornado                   6.1                       <pip>
tqdm                      4.65.0                    <pip>
traitlets                 5.1.1                     <pip>
treelib                   1.6.1                     <pip>
typing_extensions         4.4.0            py37hd43f75c_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
typing_extensions         4.5.0                     <pip>
umap-learn-modified       0.3.8                     <pip>
Unidecode                 1.3.6                     <pip>
urllib3                   1.25.11                   <pip>
virtualenv                20.21.0                   <pip>
wcwidth                   0.2.5                     <pip>
Werkzeug                  2.0.3                     <pip>
wheel                     0.37.1             pyhd8ed1ab_0    conda-forge
wrapt                     1.14.0                    <pip>
XlsxWriter                3.0.3                     <pip>
xmltodict                 0.12.0                    <pip>
xz                        5.2.5                h6dd45c4_1    conda-forge
yapf                      0.32.0                    <pip>
yarl                      1.8.2                     <pip>
zipp                      3.7.0                     <pip>
zipp                      3.11.0           py37hd43f75c_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
zlib                      1.2.11            hb9de7d4_1013    conda-forge

其中 accelerator_npu_transfer 用于对 PyTorch Lightning 的部分函数打补丁(monkey patch),使其兼容 NPU,代码如下:

import torch
import logging
from typing import Any, Optional, Union

import logging
import os
from typing import Optional, Union

import torch

from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from torch.nn import Module
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from lightning_fabric.accelerators.cuda import _check_cuda_matmul_precision


def setup_device(self, device: torch.device) -> None:
    """Validate that ``device`` is an NPU and select it as the current device.

    Written as a drop-in replacement for ``CUDAAccelerator.setup_device``.

    Args:
        device: The device chosen for this process.

    Raises:
        MisconfigurationException:
            If the selected device is not an NPU.
    """
    if device.type != "npu":
        # The check is for "npu", so the message names NPU (it used to say GPU).
        raise MisconfigurationException(f"Device should be NPU, got {device} instead")
    _check_cuda_matmul_precision(device)
    # NOTE(review): relies on torch.cuda.set_device accepting an NPU device
    # (e.g. through torch_npu's CUDA-to-NPU patching) -- confirm.
    torch.cuda.set_device(device)


# Monkey-patch: make Lightning's CUDAAccelerator use the NPU-aware setup_device defined above.
CUDAAccelerator.setup_device = setup_device


def to(self, *args: Any, **kwargs: Any):  # type: ignore[valid-type]
    """See :meth:`torch.nn.Module.to`.

    Records the target device (and dtype, when one is given) on every
    ``_DeviceDtypeModuleMixin`` sub-module before delegating to the parent
    ``to`` implementation.

    The arguments are first interpreted as a plain device specification via
    ``torch.device(*args, **kwargs)``. If that fails (dtype-only calls, a
    reference tensor, device plus dtype, extra keywords such as
    ``non_blocking``), the device and dtype are extracted from the arguments
    individually instead of raising.
    """
    device = None
    dtype = None
    try:
        # this converts `str` device to `torch.device`
        device = torch.device(*args, **kwargs)
    except (TypeError, RuntimeError):
        reference = args[0] if args else kwargs.get("tensor")
        if isinstance(reference, torch.Tensor):
            # ``module.to(tensor)`` adopts the tensor's device and dtype.
            device, dtype = reference.device, reference.dtype
        else:
            dtype = next(
                (value for value in (*args, kwargs.get("dtype")) if isinstance(value, torch.dtype)),
                None,
            )
            spec = kwargs.get("device")
            # A leading positional argument that is neither a dtype nor the
            # boolean ``non_blocking`` flag is taken as the device spec.
            if spec is None and args and not isinstance(args[0], (torch.dtype, bool)):
                spec = args[0]
            if spec is not None:
                # An invalid spec still raises here, as it did originally.
                device = torch.device(spec)
    __update_properties(self, device=device, dtype=dtype)
    return super(_DeviceDtypeModuleMixin, self).to(*args, **kwargs)


def __update_properties(
    self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:
    """Record ``device`` and/or ``dtype`` on every mixin-based sub-module of ``self``.

    ``self.apply`` is used to visit the modules; only instances of
    ``_DeviceDtypeModuleMixin`` are touched, and an argument left as ``None``
    leaves the corresponding attribute unchanged.
    """

    def _sync(submodule: Union[_DeviceDtypeModuleMixin, Module]) -> None:
        if isinstance(submodule, _DeviceDtypeModuleMixin):
            if device is not None:
                submodule._device = device
            if dtype is not None:
                submodule._dtype = dtype

    self.apply(_sync)


# Monkey-patch: replace the mixin's ``to`` with the version defined above.
_DeviceDtypeModuleMixin.to = to

评论 (2)

孔飞 创建了缺陷 2年前

Please add labels. You can also visit https://gitee.com/ascend/community/blob/master/labels.md to find more.
为了让代码尽快被审核,请您为Issue打上标签,打上标签的Issue可以直接推送给责任人进行审核。
更多的标签可以查看https://gitee.com/ascend/community/blob/master/labels.md
以模型训练相关代码提交为例,如果你提交的是模型训练代码,你可以这样评论:
//train/model
另外你还可以给这个Issue标记类型,例如是bugfix或者是特性需求:
//kind/bug or //kind/feature
恭喜你,你已经学会了使用命令来打标签,接下来就在下面的评论里打上标签吧!

Destiny 任务状态TODO 修改为WIP 2年前

登录 后才可以发表评论

状态
负责人
项目
里程碑
Pull Requests
关联的 Pull Requests 被合并后可能会关闭此 issue
分支
开始日期   -   截止日期
-
置顶选项
优先级
预计工期 (小时)
参与者(2)
ascend-robot-ascend-robot 孔飞-kong13661
Python
1
https://gitee.com/ascend/pytorch.git
git@gitee.com:ascend/pytorch.git
ascend
pytorch
pytorch

搜索帮助