Synchronized-BatchNorm-PyTorch

Synchronized Batch Normalization implementation in PyTorch.

This module differs from the built-in PyTorch BatchNorm in that the mean and standard deviation are reduced across all devices during training.

For example, when one uses nn.DataParallel to wrap the network during training, PyTorch's implementation normalizes the tensor on each device using only the statistics of that device. This speeds up computation and is easy to implement, but the statistics may be inaccurate. In this synchronized version, the statistics are instead computed over all training samples distributed across the devices.

Note that, for the one-GPU or CPU-only case, this module behaves exactly the same as the built-in PyTorch implementation.
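
To make the difference concrete, here is a minimal sketch (not part of the library); the split into x0 and x1 mimics how nn.DataParallel scatters one batch across two GPUs:

import torch

x = torch.randn(4, 10)        # a full batch of 4 samples with 10 features
x0, x1 = x[:2], x[2:]         # the halves that would land on GPU0 and GPU1

per_device_mean = x0.mean(0)  # what the built-in BatchNorm sees on GPU0
global_mean = x.mean(0)       # what Synchronized BatchNorm computes
print((per_device_mean - global_mean).abs().max())  # generally non-zero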

This module is currently only a prototype version for research use. As mentioned below, it has limitations and may even suffer from some design problems. If you have any questions or suggestions, please feel free to open an issue or submit a pull request.

Why Synchronized BatchNorm?

Although the typical multi-device (GPU) implementation of BatchNorm is fast (it has no communication overhead), it inevitably reduces the batch size seen by each device, which can degrade performance. This is not a significant issue in some standard vision tasks such as ImageNet classification (the per-device batch size is usually large enough to obtain good statistics). However, it hurts performance in tasks where the per-device batch size is very small (e.g., 1 per GPU).

For example, the importance of synchronized batch normalization in object detection has recently been demonstrated with an extensive analysis in the paper MegDet: A Large Mini-Batch Object Detector.

Usage

To use Synchronized Batch Normalization, we add a data parallel replication callback. This introduces a slight difference from the typical usage of nn.DataParallel.

Use it with a provided, customized data parallel wrapper:

from sync_batchnorm import SynchronizedBatchNorm1d, DataParallelWithCallback

sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
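
The same wrapper also works for a full network containing SyncBN layers. The following is a hypothetical sketch (the layer sizes are arbitrary), assuming SynchronizedBatchNorm2d from the same package is used as a drop-in replacement for nn.BatchNorm2d:

import torch.nn as nn
from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    SynchronizedBatchNorm2d(16),   # drop-in replacement for nn.BatchNorm2d(16)
    nn.ReLU(inplace=True),
)
net = DataParallelWithCallback(net.cuda(), device_ids=[0, 1])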

Or, if you are using a customized data parallel module, you can use this library via monkey patching.

from torch.nn import DataParallel  # or your customized DataParallel module
from sync_batchnorm import SynchronizedBatchNorm1d, patch_replication_callback

sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
patch_replication_callback(sync_bn)  # monkey-patching

You can use convert_model to easily convert your model to use Synchronized BatchNorm.

import torch.nn as nn
from torchvision import models
from sync_batchnorm import convert_model
# m is a standard PyTorch model
m = models.resnet18(pretrained=True)
m = nn.DataParallel(m)
# after conversion, m uses SyncBN
m = convert_model(m)
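
As a quick sanity check (a hypothetical snippet, assuming two visible GPUs), the converted model is used exactly like the original wrapped model:

import torch

m = m.cuda()                            # parameters must live on the GPUs
x = torch.randn(8, 3, 224, 224).cuda()
y = m(x)      # during training, batch statistics are reduced across devices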

See also tests/test_sync_batchnorm.py for numeric result comparison.

Implementation details and highlights

If you are interested in how the batch statistics are reduced and broadcast among multiple devices, please take a look at the code, which has detailed comments. Here we only emphasize some highlights of the implementation:

  • This implementation is in pure Python; no extra C++ extension libraries are needed.
  • It is easy to use, as demonstrated above.
  • It is completely compatible with PyTorch's implementation. Specifically, it uses the unbiased variance to update the moving average, and uses sqrt(max(var, eps)) instead of sqrt(var + eps) (see the sketch after this list).
  • The implementation requires that each module on the different devices invoke the batchnorm exactly the SAME number of times in each forward pass. For example, you cannot call batchnorm on GPU0 but not on GPU1. The #i (i = 1, 2, 3, ...) calls of the batchnorm on each device are viewed as a whole, and their statistics are reduced together. This is tricky, but it is a good way to handle PyTorch's dynamic computation graph. Although it sounds complicated, this is usually not an issue for most models.
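
For the compatibility point above, here is a minimal single-device sketch (not the library's code; the function name and signature are illustrative) of the update rule it describes: the biased variance normalizes the current batch, the unbiased variance updates the moving average, and sqrt(max(var, eps)) is used in place of sqrt(var + eps):

import torch

def bn_step(x, running_mean, running_var, momentum=0.1, eps=1e-5):
    # x: (N, C) activations for one training step (illustrative helper)
    n = x.size(0)
    mean = x.mean(0)
    var_biased = x.var(0, unbiased=False)       # normalizes the current batch
    var_unbiased = var_biased * n / (n - 1)     # updates the moving average
    running_mean.mul_(1 - momentum).add_(momentum * mean)
    running_var.mul_(1 - momentum).add_(momentum * var_unbiased)
    x_hat = (x - mean) / torch.sqrt(torch.clamp(var_biased, min=eps))
    return x_hat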

Known issues

Runtime error on backward pass.

Due to a PyTorch bug, using old PyTorch libraries will trigger a RuntimeError with a message like:

Assertion `pos >= 0 && pos < buffer.size()` failed.

This has already been fixed in the newest PyTorch repo, which, unfortunately, has not yet been pushed to the official or Anaconda binary releases. Thus, you are required to build the PyTorch package from source, following the instructions here.

Numeric error.

Because this library does not fuse the normalization and statistics operations in C++ (or CUDA), it is less numerically stable than the original PyTorch implementation. A detailed analysis can be found in tests/test_sync_batchnorm.py.

Authors and License:

Copyright (c) 2018-, Jiayuan Mao.

Contributors: Tete Xiao, DTennant.

Distributed under the MIT License (see LICENSE).

