Gradient accumulation is an optimization technique that enables training a network with a larger Batch Size when memory is limited. Training large neural networks typically requires a large amount of memory, because computing the gradient of each Batch and updating the model parameters requires keeping the gradient values in memory; a larger Batch Size requires more memory and may lead to out-of-memory problems. Gradient accumulation works by summing the gradient values of multiple MicroBatches before updating the parameters, which allows the model to be trained with an effectively larger Batch Size without increasing the memory requirement. This article focuses on gradient accumulation in distributed scenarios.
The core idea of gradient accumulation is to sum the gradients of multiple MicroBatches and then use the accumulated gradient to update the model parameters. The steps are as follows (a minimal code sketch is given after the list):

1. Select a MicroBatch Size: a MicroBatch is the basic unit of data for each forward and backward propagation. Dividing the Batch Size by the MicroBatch Size gives the number of accumulation steps, i.e. after how many MicroBatches a parameter update is performed. For example, with a Batch Size of 32 and a MicroBatch Size of 8, the gradients of 4 MicroBatches are accumulated before each update.
2. Forward and backward propagation: for each MicroBatch, perform the standard forward and backward propagation and compute its gradient.
3. Gradient accumulation: add up the gradient values of the MicroBatches until the number of accumulation steps is reached.
4. Gradient update: once the number of accumulation steps is reached, the accumulated gradient is used to update the model parameters via the optimizer.
5. Gradient clearing: after the parameters are updated, the accumulated gradient is cleared to zero for the next accumulation cycle.
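To make these steps concrete, here is a minimal single-card sketch written against MindSpore's standard functional API. It is not the distributed GradAccumulation interface used later in this article; the toy network, the randomly generated MicroBatches, and the Batch/MicroBatch sizes are illustrative assumptions.

import numpy as np
import mindspore as ms
from mindspore import nn

# Toy model, loss and optimizer; sizes are illustrative only.
net = nn.Dense(16, 10)
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(net.trainable_params(), 1e-2)

def forward_fn(data, label):
    return loss_fn(net(data), label)

# Returns the loss and the gradients w.r.t. the trainable parameters.
grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params())

# Random MicroBatches standing in for a real data pipeline: 8 MicroBatches of size 8.
micro_batches = [
    (ms.Tensor(np.random.randn(8, 16), ms.float32),
     ms.Tensor(np.random.randint(0, 10, (8,)), ms.int32))
    for _ in range(8)
]

accumulation_steps = 4        # Batch Size 32 / MicroBatch Size 8
accumulated = None            # running sum of the MicroBatch gradients

for step, (data, label) in enumerate(micro_batches, start=1):
    loss, grads = grad_fn(data, label)               # forward + backward on one MicroBatch
    accumulated = grads if accumulated is None else tuple(
        a + g for a, g in zip(accumulated, grads))   # gradient accumulation
    if step % accumulation_steps == 0:
        optimizer(accumulated)                       # gradient update
        accumulated = None                           # gradient clearing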
mindspore.parallel.GradAccumulation(network, micro_size)
: Wraps the network so that it runs with finer-grained MicroBatches; micro_size is the MicroBatch size.
- When gradient accumulation is used, it is recommended to apply the lazy_inline decorator to reduce compile time; setting the lazy_inline decorator is only supported on the outermost cell.
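As a hedged sketch of that recommendation, the lazy_inline decorator can be placed on the __init__ of the outermost cell that GradAccumulation wraps. The cell name, placeholder backbone, and loss function below are illustrative assumptions, not part of the example that follows.

from mindspore import nn, lazy_inline
from mindspore.parallel import GradAccumulation

class OutermostNetWithLoss(nn.Cell):
    # The outermost cell is the only place the lazy_inline decorator is applied.
    @lazy_inline
    def __init__(self, backbone, loss_fn):
        super().__init__()
        self.net_with_loss = nn.WithLossCell(backbone, loss_fn)

    def construct(self, data, label):
        return self.net_with_loss(data, label)

backbone = nn.Dense(28 * 28, 10)       # placeholder backbone, illustrative only
# Accumulate the gradients of 4 MicroBatches before each parameter update.
net = GradAccumulation(OutermostNetWithLoss(backbone, nn.CrossEntropyLoss()), 4)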
The following illustrates the gradient accumulation operation, using a single-machine 8-card setup on Ascend or GPU as an example:
Download the complete example code: distributed_gradient_accumulation.
The directory structure is as follows:
└─ sample_code
    ├─ distributed_gradient_accumulation
       ├── train.py
       └── run.sh
    ...
train.py
is the script that defines the network structure and the training process. run.sh
is the execution script.
Initialize the communication with init.
import mindspore as ms
from mindspore.communication import init
ms.set_context(mode=ms.GRAPH_MODE)
init()
Here, the dataset loading and network definition are consistent with the single-card model, except that the initialization of the network parameters and optimizer parameters is deferred through the no_init_parameters
interface. The code is as follows:
import os
import mindspore.dataset as ds
from mindspore import nn
from mindspore.parallel.auto_parallel import AutoParallel
from mindspore.nn.utils import no_init_parameters
def create_dataset(batch_size):
    dataset_path = os.getenv("DATA_PATH")
    dataset = ds.MnistDataset(dataset_path)
    image_transforms = [
        ds.vision.Rescale(1.0 / 255.0, 0),
        ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        ds.vision.HWC2CHW()
    ]
    label_transform = ds.transforms.TypeCast(ms.int32)
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

data_set = create_dataset(32)
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

with no_init_parameters():
    net = Network()
    optimizer = nn.SGD(net.trainable_params(), 1e-2)
In this step, we need to define the loss function and the training process. The parallel mode is set to semi-automatic parallel with optimizer parallelism via the top-level AutoParallel interface, and two interfaces are called to configure gradient accumulation:

- The nn.WithLossCell interface is called to wrap the network and the loss function.
- GradAccumulation is wrapped around the LossCell with a MicroBatch size of 4. Refer to the relevant interfaces in the overview of this chapter for more details.

import mindspore as ms
from mindspore import nn, train
from mindspore.parallel import GradAccumulation
loss_fn = nn.CrossEntropyLoss()
loss_cb = train.LossMonitor(100)
net = GradAccumulation(nn.WithLossCell(net, loss_fn), 4)
# set parallel mode and enable parallel optimizer
net = AutoParallel(net)
net.hsdp()
model = ms.Model(net, optimizer=optimizer)
model.train(10, data_set, callbacks=[loss_cb])
Gradient accumulation training is better suited to the model.train
approach: the TrainOneStep logic under gradient accumulation is complex, whereas model.train
internally wraps the TrainOneStepCell for gradient accumulation, which makes it much easier to use.
Next, the corresponding script is invoked from the command line. Taking the mpirun
startup method and the 8-card distributed training script as an example, perform the distributed training:
bash run.sh
After training, the Loss results are saved in log_output/worker_*.log
. An example is shown below:
epoch: 1 step: 100, loss is 7.793933868408203
epoch: 1 step: 200, loss is 2.6476094722747803
epoch: 1 step: 300, loss is 1.784448266029358
epoch: 1 step: 400, loss is 1.402374029159546
epoch: 1 step: 500, loss is 1.355136752128601
epoch: 1 step: 600, loss is 1.1950846910476685
...