Comparing the function differences with torch.optim.lr_scheduler.ExponentialLR

torch.optim.lr_scheduler.ExponentialLR

torch.optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma,
    last_epoch=-1,
    verbose=False
)

For more information, see torch.optim.lr_scheduler.ExponentialLR.

mindspore.nn.exponential_decay_lr

mindspore.nn.exponential_decay_lr(
    learning_rate,
    decay_rate,
    total_step,
    step_per_epoch,
    decay_epoch,
    is_stair=False
)

For more information, see mindspore.nn.exponential_decay_lr.

mindspore.nn.ExponentialDecayLR

mindspore.nn.ExponentialDecayLR(
    learning_rate,
    decay_rate,
    decay_steps,
    is_stair=False
)

For more information, see mindspore.nn.ExponentialDecayLR.

Differences

PyTorch (torch.optim.lr_scheduler.ExponentialLR): computes the learning rate as $lr * gamma^{epoch}$. The optimizer is passed in when the scheduler is constructed, and the learning rate is updated by calling the step method. When verbose is True, a message is printed for each update.

MindSpore (mindspore.nn.exponential_decay_lr): computes the learning rate as $lr * decay_rate^{p}$, where p = current_epoch / decay_epoch and current_epoch = floor(current_step / step_per_epoch). exponential_decay_lr pre-generates the complete learning-rate list, which is then passed to the optimizer (a sketch of this usage follows the table below).

| Categories | Subcategories | PyTorch | MindSpore | Differences |
| --- | --- | --- | --- | --- |
| Parameters | Parameter 1 | optimizer | - | The optimizer whose learning rate is scheduled in PyTorch. MindSpore does not have this parameter |
| | Parameter 2 | gamma | decay_rate | Decay rate of the learning rate; same function, different parameter name |
| | Parameter 3 | last_epoch | - | MindSpore does not have this parameter |
| | Parameter 4 | verbose | - | When verbose is True, PyTorch prints a message for each update. MindSpore does not have this parameter |
| | Parameter 5 | - | learning_rate | Initial value of the learning rate in MindSpore |
| | Parameter 6 | - | total_step | Total number of steps in MindSpore |
| | Parameter 7 | - | step_per_epoch | Number of steps per epoch in MindSpore |
| | Parameter 8 | - | decay_epoch | Number of epochs over which the learning rate decays in MindSpore |
| | Parameter 9 | - | is_stair | When is_stair is True, the learning rate decays once every decay_epoch epochs |
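
The following is a minimal sketch of how the pre-generated list is consumed; the one-layer network and the use of nn.SGD are illustrative assumptions, not part of the original example. MindSpore optimizers accept a list of per-step values as a dynamic learning rate.

import mindspore.nn as nn

# Hypothetical toy network, for illustration only.
net = nn.Dense(20, 1)
# exponential_decay_lr returns one learning-rate value per step.
lr_list = nn.exponential_decay_lr(0.1, 0.9, total_step=6, step_per_epoch=2, decay_epoch=1)
# A list passed as learning_rate is interpreted as a dynamic per-step schedule.
optimizer = nn.SGD(net.trainable_params(), learning_rate=lr_list)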

MindSpore (mindspore.nn.ExponentialDecayLR): computes the learning rate as $lr * decay_rate^{p}$, where p = current_step / decay_steps (when is_stair is True, p = floor(current_step / decay_steps)). An ExponentialDecayLR instance is passed to the optimizer, and the learning rate is computed inside the training graph at each step (a sketch of this usage follows the table below).

| Categories | Subcategories | PyTorch | MindSpore | Differences |
| --- | --- | --- | --- | --- |
| Parameters | Parameter 1 | optimizer | - | The optimizer whose learning rate is scheduled in PyTorch. MindSpore does not have this parameter |
| | Parameter 2 | gamma | decay_rate | Decay rate of the learning rate; same function, different parameter name |
| | Parameter 3 | last_epoch | - | MindSpore does not have this parameter |
| | Parameter 4 | verbose | - | When verbose is True, PyTorch prints a message for each update. MindSpore does not have this parameter |
| | Parameter 5 | - | learning_rate | Initial value of the learning rate in MindSpore |
| | Parameter 6 | - | decay_steps | Number of steps over which the learning rate decays in MindSpore |
| | Parameter 7 | - | is_stair | When is_stair is True, the learning rate decays once every decay_steps steps |
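
The following is a minimal sketch of passing the schedule cell to an optimizer; the toy network and the use of nn.SGD are illustrative assumptions, not part of the original example. A LearningRateSchedule such as ExponentialDecayLR handed to the optimizer is evaluated on the current global step inside the graph.

import mindspore.nn as nn

# Hypothetical toy network, for illustration only.
net = nn.Dense(20, 1)
schedule = nn.ExponentialDecayLR(0.1, 0.9, decay_steps=4)
# The schedule cell is evaluated on the current step during training.
optimizer = nn.SGD(net.trainable_params(), learning_rate=schedule)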

Code Example

# In MindSpore:
import mindspore as ms
from mindspore import nn

# In MindSpore: exponential_decay_lr
learning_rate = 0.1
decay_rate = 0.9
total_step = 6
step_per_epoch = 2
decay_epoch = 1
output = nn.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
print(output)
# out
# [0.1, 0.1, 0.09000000000000001, 0.09000000000000001, 0.08100000000000002, 0.08100000000000002]

# In MindSpore: ExponentialDecayLR
learning_rate = 0.1
decay_rate = 0.9
decay_steps = 4
global_step = ms.Tensor(2, ms.int32)
exponential_decay_lr = nn.ExponentialDecayLR(learning_rate, decay_rate, decay_steps)
result = exponential_decay_lr(global_step)
print(result)
# out
# 0.094868325
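
# As an assumed equivalence check (not in the original): with is_stair=True and
# decay_steps = step_per_epoch * decay_epoch, ExponentialDecayLR evaluated step
# by step should reproduce the staircase list from exponential_decay_lr above.
staircase_lr = nn.ExponentialDecayLR(0.1, 0.9, 2, is_stair=True)
for step in range(6):
    print(staircase_lr(ms.Tensor(step, ms.int32)))
# expected: 0.1, 0.1, 0.09, 0.09, 0.081, 0.081 (up to float32 rounding)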

# In torch:
import torch
import numpy as np
from torch import optim

model = torch.nn.Sequential(torch.nn.Linear(20, 1))
optimizer = optim.SGD(model.parameters(), 0.1)
exponential_decay_lr = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
myloss = torch.nn.MSELoss()
dataset = [(torch.tensor(np.random.rand(1, 20).astype(np.float32)), torch.tensor([1.]))]

for epoch in range(5):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = myloss(output.view(-1), target)
        loss.backward()
        optimizer.step()
    exponential_decay_lr.step()
    print(exponential_decay_lr.get_last_lr())
# out
# [0.09000000000000001]
# [0.08100000000000002]
# [0.07290000000000002]
# [0.06561000000000002]
# [0.05904900000000002]