This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in PyTorch. Some of the code here will be included in upstream PyTorch eventually. The intention of Apex is to make up-to-date utilities available to users as quickly as possible.
apex.amp is a tool to enable mixed precision training by changing only 3 lines of your script. Users can easily experiment with different pure and mixed precision training modes by supplying different flags to amp.initialize.
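Mixed precision keeps some values in float16, whose smallest positive (subnormal) value is about 6e-8, so very small gradients can flush to zero. The stdlib-only sketch below is not Apex code; the gradient value and the scale factor are illustrative assumptions. It shows why multiplying the loss by a scale factor before the backward pass, which is what amp.scale_loss does, lets such values survive a round trip through float16.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    # struct's 'e' format packs a binary16 value (Python 3.6+).
    return struct.unpack("<e", struct.pack("<e", x))[0]

grad = 1e-8       # hypothetical tiny gradient value
scale = 65536.0   # hypothetical loss scale (2**16)

# Stored directly in fp16, the value underflows to zero.
unscaled = to_fp16(grad)
# Scaled before the fp16 round trip, then unscaled in higher precision.
scaled = to_fp16(grad * scale) / scale

print(unscaled)  # 0.0
print(scaled)    # close to 1e-8 (within fp16's ~3 significant digits)
```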
Webinar introducing Amp (the flag cast_batchnorm has been renamed to keep_batchnorm_fp32).
Comprehensive Imagenet example
Moving to the new Amp API (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)
apex.parallel.DistributedDataParallel is a module wrapper, similar to torch.nn.parallel.DistributedDataParallel. It enables convenient multiprocess distributed training, optimized for NVIDIA's NCCL communication library.
The Imagenet example shows use of apex.parallel.DistributedDataParallel along with apex.amp.
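Data-parallel wrappers like this one keep model replicas consistent by allreducing gradients after the backward pass. The single-process simulation below is plain Python, not Apex's or NCCL's implementation; the 1-D linear model, the data, and the helper names are hypothetical. It shows that averaging per-rank gradients computed on equal-sized shards reproduces the full-batch gradient.

```python
from typing import List, Sequence

def local_grad(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Gradient of mean squared error for the 1-D model y = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def allreduce_mean(values: List[float]) -> float:
    """Stand-in for an allreduce (sum) followed by division by world size."""
    return sum(values) / len(values)

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.1, 5.9, 8.2]
w = 0.5
world_size = 2

# Each simulated rank sees an equal-sized contiguous shard of the batch.
shard = len(xs) // world_size
rank_grads = [
    local_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
    for r in range(world_size)
]

averaged = allreduce_mean(rank_grads)
full_batch = local_grad(w, xs, ys)
print(averaged, full_batch)  # equal up to floating-point rounding
```

The equality relies on equal shard sizes: the mean of equally weighted shard means is the global mean.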
apex.parallel.SyncBatchNorm extends torch.nn.modules.batchnorm._BatchNorm to support synchronized BN. It allreduces stats across processes during multiprocess (DistributedDataParallel) training. Synchronized BN has been used in cases where only a small local minibatch can fit on each GPU. Allreduced stats increase the effective batch size for the BN layer to the global batch size across all processes (which, technically, is the correct formulation). Synchronized BN has been observed to improve converged accuracy in some of our research models.
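One simple way to see how allreduced stats yield global-batch statistics is to reduce per-process counts, sums, and sums of squares, then derive the mean and variance from the totals. The sketch below is plain Python and is not how Apex's kernels compute these values; the per-GPU activations and helper names are hypothetical.

```python
from statistics import mean, pvariance
from typing import List, Sequence, Tuple

Moments = Tuple[int, float, float]

def local_moments(xs: Sequence[float]) -> Moments:
    """Per-process count, sum, and sum of squares for one BN channel."""
    return len(xs), sum(xs), sum(x * x for x in xs)

def allreduce_moments(parts: List[Moments]) -> Tuple[float, float]:
    """Sum moments over all processes, then derive global mean and variance."""
    n = sum(p[0] for p in parts)
    total = sum(p[1] for p in parts)
    total_sq = sum(p[2] for p in parts)
    g_mean = total / n
    # Biased (population) variance, which BN uses to normalize during training.
    g_var = total_sq / n - g_mean * g_mean
    return g_mean, g_var

# Hypothetical activations for one channel on three GPUs, two samples each.
per_gpu = [[0.5, 1.5], [3.0, 5.0], [-1.0, 2.0]]

g_mean, g_var = allreduce_moments([local_moments(xs) for xs in per_gpu])
everything = [x for xs in per_gpu for x in xs]

# Synced stats match the global batch; GPU 0's local stats alone do not.
print(g_mean, mean(everything))
print(g_var, pvariance(everything))
print(mean(per_gpu[0]), pvariance(per_gpu[0]))
```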
To properly save and load your amp training, we introduce amp.state_dict(), which contains all loss scalers and their corresponding unskipped steps, and amp.load_state_dict() to restore these attributes.
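The loss scale and its count of unskipped steps are what let dynamic loss scaling resume exactly where it stopped. The plain-Python sketch below is hypothetical: it is not Apex's loss scaler, and the class name, the halve-on-overflow / double-after-growth_interval policy, and the dictionary keys are illustrative assumptions. It shows why saving and restoring both values (with the same configuration) matters.

```python
import math
from typing import Sequence

class DynamicLossScaler:
    """Hypothetical dynamic loss scaler; not Apex's implementation."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.unskipped = 0  # consecutive steps without overflow

    def update(self, grads: Sequence[float]) -> bool:
        """Return True if the optimizer step should run, False if it is skipped."""
        if not all(math.isfinite(g) for g in grads):
            self.scale /= 2.0  # overflow: back off and skip this step
            self.unskipped = 0
            return False
        self.unskipped += 1
        if self.unskipped == self.growth_interval:
            self.scale *= 2.0  # grow after a run of clean steps
            self.unskipped = 0
        return True

    def state_dict(self) -> dict:
        return {"loss_scale": self.scale, "unskipped": self.unskipped}

    def load_state_dict(self, state: dict) -> None:
        self.scale = float(state["loss_scale"])
        self.unskipped = int(state["unskipped"])

# Usage: save after one clean step, restore into a new scaler, keep going.
scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
scaler.update([0.1, 0.2])       # clean step: unskipped becomes 1
saved = scaler.state_dict()     # {'loss_scale': 8.0, 'unskipped': 1}

resumed = DynamicLossScaler(growth_interval=2)  # same config, like the same opt_level
resumed.load_state_dict(saved)
resumed.update([0.3])           # second clean step: scale grows to 16.0
skipped = not resumed.update([float("inf")])  # overflow: skipped, scale back to 8.0
```

Without restoring the unskipped count, the resumed scaler would wait a full growth_interval before growing again, so its scale schedule would drift from an uninterrupted run.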
In order to get bitwise accuracy, we recommend the following workflow:
import torch
from apex import amp

# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

# Train your model
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...

# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...

# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')

model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])

# Continue training
...
Note that we recommend restoring the model using the same opt_level. Also note that we recommend calling the load_state_dict methods after amp.initialize.
Python 3
CUDA 9 or newer
PyTorch 0.4 or newer. The CUDA and C++ extensions require PyTorch 1.0 or newer.
We recommend the latest stable release, obtainable from https://pytorch.org/. We also test against the latest master branch, obtainable from https://github.com/pytorch/pytorch.
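The version requirements above can be checked programmatically. The helpers below are hypothetical (not part of Apex or PyTorch); in practice you would pass torch.__version__, whose local suffix such as "+cu117" is ignored by the parser.

```python
import re
from typing import Tuple

def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings like '1.13.1+cu117' or '0.4.1'."""
    m = re.match(r"(\d+)\.(\d+)", version)
    if m is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(m.group(1)), int(m.group(2))

def is_supported(torch_version: str) -> bool:
    """Apex requires PyTorch 0.4 or newer."""
    return parse_major_minor(torch_version) >= (0, 4)

def can_build_extensions(torch_version: str) -> bool:
    """The CUDA and C++ extensions require PyTorch 1.0 or newer."""
    return parse_major_minor(torch_version) >= (1, 0)

print(is_supported("0.4.1"), can_build_extensions("0.4.1"))  # True False
```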
It's often convenient to use Apex in Docker containers. Compatible options include:
Containers that already include Apex: run pip uninstall apex, then reinstall Apex using the Quick Start commands below.
PyTorch development images, e.g. docker pull pytorch/pytorch:nightly-devel-cuda10.0-cudnn7, in which you can install Apex using the Quick Start commands.
See the Docker example folder for details.
For performance and full functionality, we recommend installing Apex with CUDA and C++ extensions via
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
Apex also supports a Python-only build (required with PyTorch 0.4) via
pip install -v --disable-pip-version-check --no-cache-dir ./
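To see which kind of build ended up installed, one option is to check whether Python can locate Apex's compiled extension modules. This is a sketch, not an Apex utility: the module names (apex_C, amp_C, syncbn, fused_layer_norm_cuda) and their mapping to build flags are assumptions based on Apex's setup.py and may differ between versions.

```python
import importlib.util
from typing import Dict

# Assumed compiled module names and the install flag that builds each one.
EXTENSION_MODULES = {
    "apex_C": "--cpp_ext",
    "amp_C": "--cuda_ext",
    "syncbn": "--cuda_ext",
    "fused_layer_norm_cuda": "--cuda_ext",
}

def installed_extensions() -> Dict[str, bool]:
    """Map each expected compiled module to whether Python can locate it."""
    # find_spec on a top-level name only searches; it does not import the module.
    return {name: importlib.util.find_spec(name) is not None for name in EXTENSION_MODULES}

if __name__ == "__main__":
    for name, found in installed_extensions().items():
        status = "found" if found else "missing"
        print(f"{name:<24} {status} (built with {EXTENSION_MODULES[name]})")
```

If every module reports missing while import apex succeeds, the install is most likely a Python-only build.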
A Python-only build omits the fused kernels used by:
apex.optimizers.FusedAdam
apex.normalization.FusedLayerNorm
apex.parallel.SyncBatchNorm
apex.parallel.DistributedDataParallel and apex.amp
DistributedDataParallel, amp, and SyncBatchNorm will still be usable, but they may be slower.
Pyprof support has been moved to its own dedicated repository. The codebase is deprecated in Apex and will be removed soon.
Windows support is experimental, and Linux is recommended. If you were able to build PyTorch from source on your system,
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
may work. A Python-only build (without CUDA/C++ extensions),
pip install -v --no-cache-dir .
is more likely to work. If you installed PyTorch in a Conda environment, make sure to install Apex in that same environment.