# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate bprop for comm ops"""
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.communication import get_rank, get_group_size
from mindspore.parallel._utils import _get_enable_parallel_optimizer, _get_grad_accumulation_shard
from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, NeighborExchange, AlltoAll, NeighborExchangeV2,
Broadcast, _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, _AllSwap,
_VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, _MicroStepAllGather)
from .grad_base import bprop_getters
from ..operations._inner_ops import Send, Receive
from ..operations import _grad_ops as G
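
# Gradient of AllReduce, as implemented below:
#   ReduceOp.SUM:     the adjoint of a sum-allreduce is another sum-allreduce of dout;
#                     sparse RowTensor gradients are allgathered instead.
#   ReduceOp.PROD:    for y = prod_k(x_k), dy/dx = y / x, so dx = allreduce(dout * y) / x.
#   ReduceOp.MAX/MIN: the gradient is routed only to the inputs that equal the reduced
#                     output, via the equal(x, out) mask.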
@bprop_getters.register(AllReduce)
def get_bprop_all_reduce(self):
"""Generate bprop for AllReduce, do allreduce or allgather, allgather for sparse feature."""
all_reduce_grad = AllReduce(ReduceOp.SUM, self.group)
all_gather = AllGather(group=self.group)
if self.instance_name:
instance_name = "grad" + self.instance_name
all_reduce_grad.set_prim_instance_name(instance_name)
equal = P.Equal()
cast = P.Cast()
mul = P.Mul()
div = P.RealDiv()
dtype = P.DType()
if self.op == ReduceOp.PROD:
def bprop(x, out, dout):
dy1 = mul(dout, out)
dy2 = all_reduce_grad(dy1)
dx = div(dy2, x)
return (dx,)
elif self.op == ReduceOp.SUM:
def bprop(x, out, dout):
if F.issubclass_(F.typeof(dout), mstype.tensor):
dx = all_reduce_grad(dout)
else:
indices = all_gather(dout.indices)
grad = all_gather(dout.values)
dx = RowTensor(indices, grad, dout.dense_shape)
return (dx,)
else:
def bprop(x, out, dout):
if F.issubclass_(F.typeof(dout), mstype.tensor):
dx = all_reduce_grad(dout)
z = equal(x, out)
z = cast(z, dtype(dx))
dx = mul(dx, z)
else:
indices = all_gather(dout.indices)
grad = all_gather(dout.values)
z = equal(x, out)
z = cast(z, dtype(grad))
grad = mul(grad, z)
dx = RowTensor(indices, grad, dout.dense_shape)
return (dx,)
return bprop
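
# The gradient of Send comes from the peer device: the bprop below issues a Receive on
# group_back to pull dout back from the peer rank; virtual_input is only a placeholder
# input for that Receive.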
@bprop_getters.register(Send)
def get_bprop_send(self):
"""Generate bprop for Send."""
shape = self.get_attr_dict()["shape"]
dtype = self.get_attr_dict()["dtype"]
send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group_back)
virtual_input = Tensor(0.0, dtype)
def bprop(x, out, dout):
dx = send_grad(virtual_input)
return (dx,)
return bprop
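
# The gradient of Receive is sent back to the sender: the bprop below forwards dout
# through a Send on group_back and returns a zero placeholder for x (zeros_like(x)
# under optimizer parallelism, a scalar zero otherwise), with Depend keeping an
# execution dependency on the Send.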
@bprop_getters.register(Receive)
def get_bprop_receive(self):
"""Generate bprop for Receive."""
receive_grad = Send(self.tag, self.rank, self.group_back)
depend = P.Depend()
cast = P.Cast()
out_tensor = Tensor(0.0, mstype.float16)
is_opt_shard = _get_enable_parallel_optimizer()
def bprop(x, out, dout):
send_out = receive_grad(dout)
if is_opt_shard:
dx = depend(F.zeros_like(x), send_out)
else:
dx = depend(cast(out_tensor, F.dtype(x)), send_out)
return (dx,)
return bprop
@bprop_getters.register(_VirtualAdd)
def get_bprop_virtual_add(self):
"""Generate bprop for _VirtualAdd"""
def bprop(x, grad_accu, out, dout):
return (dout + grad_accu, zeros_like(grad_accu))
return bprop
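
# _VirtualAssignAdd accumulates the incoming gradient into the accumulation parameter y
# via AssignAdd. When a group is configured (gradient accumulation sharding), dout is
# first reduce-scattered so that each rank keeps only its own shard of the gradient.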
@bprop_getters.register(_VirtualAssignAdd)
def get_bprop_virtual_assign_add(self):
"""Generate bprop for VirtualAssignAdd."""
assign_add = P.AssignAdd()
cast = P.Cast()
dtype = P.DType()
out_tensor = Tensor(0.0, mstype.float16)
reduce_scatter = None
group = self.get_attr_dict().get("group", None)
fusion = self.get_attr_dict().get("fusion", 0)
if group:
reduce_scatter = ReduceScatter(ReduceOp.SUM, group).add_prim_attr("fusion", fusion)
if self.instance_name:
instance_name = "_grad_accumulation_shard_grad" + self.instance_name
reduce_scatter.set_prim_instance_name(instance_name)
            # For pipeline training, the fused communication would be visited later,
            # which may increase memory usage, so add a tag here so that the delayed
            # fusion does not take effect for this operator.
reduce_scatter.add_prim_attr("not_delay_fusion", True)
def bprop(x, y, out, dout):
if reduce_scatter:
dout = reduce_scatter(dout)
temp = assign_add(y, dout)
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(y))), temp)
return bprop
@bprop_getters.register(_VirtualAccuGrad)
def get_bprop_virtual_accu_grad(self):
"""Generate bprop for VirtualAccuGrad."""
cast = P.Cast()
dtype = P.DType()
out_tensor = Tensor(0.0, mstype.float16)
def bprop(x, y, out, dout):
return (F.depend(y, dout), cast(out_tensor, dtype(y)))
return bprop
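
# _MirrorMicroStepOperator mirrors gradients when training with micro steps (pipeline
# parallelism): the bprop below allreduces the accumulation buffer z across the group,
# scales it by 1 / dev_num when mean_flag is set, and writes the result back into z
# with Assign; under optimizer sharding the allreduced gradient is returned directly.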
@bprop_getters.register(_MirrorMicroStepOperator)
def get_bprop_mirror_micro_step_operator(self):
"""
    Backpropagator for _MirrorMicroStepOperator: allreduce the accumulated gradient
    across the devices in the group.
"""
group = self.group
dev_num = self.dev_num
mean_flag = self.mean_flag
scale = 1 / dev_num
all_reduce = AllReduce(group=group)
fusion = self.get_attr_dict()["fusion"]
all_reduce.add_prim_attr("fusion", fusion)
if hasattr(self, 'parameter'):
parameter = self.parameter
all_reduce.add_prim_attr("parameter", parameter)
if self.instance_name:
instance_name = "grad_mirror" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
cast = P.Cast()
dtype = P.DType()
assign = P.Assign()
if "parameter_micro" in self.get_attr_dict():
assign.add_prim_attr("parameter_micro", 0)
out_tensor = Tensor(1.0, mstype.float16)
opt_shard = _get_enable_parallel_optimizer()
def bprop(x, z, out, dout):
real_grad = z
assign_out = dout
if mean_flag:
if F.issubclass_(F.typeof(dout), mstype.tensor):
z = F.depend(z, dout)
real_grad = all_reduce(z)
real_grad = F.tensor_mul(real_grad, scale)
assign_out = assign(z, real_grad)
else:
if F.issubclass_(F.typeof(dout), mstype.tensor):
z = F.depend(z, dout)
real_grad = all_reduce(z)
assign_out = assign(z, real_grad)
if opt_shard:
return (real_grad, cast(out_tensor, dtype(z)))
return F.depend((cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))), assign_out)
return bprop
@bprop_getters.register(Broadcast)
def get_bprop_broad_cast(self):
"""Generate bprop for Broadcast."""
def bprop(x, out, dout):
return (dout,)
return bprop
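
# The adjoint of AllGather is ReduceScatter(SUM): each rank contributed one slice in the
# forward pass, so in the backward pass it receives the sum over all ranks of the
# gradient segment corresponding to that slice. With mean_flag the result is further
# scaled by 1 / rank_size.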
@bprop_getters.register(AllGather)
def get_bprop_all_gather(self):
"""Generate bprop for AllGather"""
fusion = self.get_attr_dict()["fusion"]
reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
if self.instance_name:
instance_name = "grad_" + self.instance_name
reduce_scatter.set_prim_instance_name(instance_name)
mean_flag = self.get_attr_dict()["mean_flag"]
scale = 1 / self.rank_size
def bprop(x, out, dout):
dx = reduce_scatter(dout)
if mean_flag:
dx = F.tensor_mul(dx, scale)
return (dx,)
return bprop
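
# _MiniStepAllGather combines AllGather with gradient accumulation (mini steps). With
# do_mirror, the bprop below sums the gradient across the group with AllReduce and keeps
# the local shard via Split()[rank] (allreduce + split acts as a reduce-scatter here),
# scaling by 1 / rank_size when mean_flag is set; the accumulation buffer z is updated
# with AssignAdd according to the gradient-shard setting. Without do_mirror, dout is
# passed through unchanged.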
@bprop_getters.register(_MiniStepAllGather)
def get_bprop_mini_step_all_gather(self):
"""Generate bprop for _MiniStepAllGather"""
fusion = self.get_attr_dict()["fusion"]
mean_flag = self.get_attr_dict()["mean_flag"]
do_mirror = self.get_attr_dict()["do_mirror"]
add_accu = self.get_attr_dict().get("add_accu", False)
gradient_shard = _get_grad_accumulation_shard()
scale = 1 / self.rank_size
all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
assign_add = P.AssignAdd()
if self.instance_name:
instance_name = "grad_" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
rank = get_rank(self.group)
dev_num = get_group_size(self.group)
split = P.Split(output_num=dev_num)
def bprop(x, z, out, dout):
if do_mirror:
if not gradient_shard:
z = F.depend(z, F.assign_add(z, dout))
grad = all_reduce(z)
dx = split(grad)[rank]
if mean_flag:
dx = F.tensor_mul(dx, scale)
else:
dout = F.depend(dout, z)
grad = all_reduce(dout)
dx = split(grad)[rank]
if mean_flag:
dx = F.tensor_mul(dx, scale)
if add_accu:
z = assign_add(z, dx)
dx = F.depend(dx, z)
else:
dx = dout
return (dx, zeros_like(z))
return bprop
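
# _MicroStepAllGather: z holds the accumulated gradient. Without do_mirror it is
# returned as-is; with do_mirror it is allreduced across the group, the local shard is
# taken with Split()[rank], and the result is scaled by 1 / rank_size when mean_flag
# is set.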
@bprop_getters.register(_MicroStepAllGather)
def get_bprop_micro_step_all_gather(self):
"""Generate bprop for _MicroStepAllGather"""
fusion = self.get_attr_dict()["fusion"]
mean_flag = self.get_attr_dict()["mean_flag"]
do_mirror = self.get_attr_dict()["do_mirror"]
scale = 1 / self.rank_size
all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
rank = get_rank(self.group)
dev_num = get_group_size(self.group)
split = P.Split(output_num=dev_num)
if self.instance_name:
instance_name = "grad_" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
cast = P.Cast()
dtype = P.DType()
out_tensor = Tensor(1.0, mstype.float16)
# z: accu_grad
def bprop(x, z, out, dout):
z = F.depend(z, dout)
if not do_mirror:
return (z, cast(out_tensor, dtype(z)))
real_grad = all_reduce(z)
real_grad = split(real_grad)[rank]
if mean_flag:
real_grad = F.tensor_mul(real_grad, scale)
return (real_grad, cast(out_tensor, dtype(z)))
return bprop
@bprop_getters.register(_HostAllGather)
def get_bprop_host_all_gather(self):
"""Generate bprop for _HostAllGather"""
host_all_gather_grad = _HostReduceScatter(ReduceOp.SUM, self.group)
if self.instance_name:
instance_name = "grad" + self.instance_name
host_all_gather_grad.set_prim_instance_name(instance_name)
def bprop(x, out, dout):
dx = host_all_gather_grad(dout)
return (dx,)
return bprop
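
# The adjoint of ReduceScatter(SUM) is AllGather: each rank kept the sum of one segment
# in the forward pass, so in the backward pass every rank needs the concatenation of all
# ranks' gradient segments. Other reduce ops are rejected below.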
@bprop_getters.register(ReduceScatter)
def get_bprop_reduce_scatter(self):
"""Generate bprop for ReduceScatter"""
reduce_scatter_grad = AllGather(self.group)
if self.instance_name:
instance_name = "grad" + self.instance_name
reduce_scatter_grad.set_prim_instance_name(instance_name)
if self.op != ReduceOp.SUM:
raise RuntimeError("The reducescatter bprop only support ReduceOp.SUM until now.")
def bprop(x, out, dout):
dx = reduce_scatter_grad(dout)
return (dx,)
return bprop
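
# The gradient of _AllSwap is another _AllSwap in the reverse direction, with the send
# and receive sizes exchanged.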
@bprop_getters.register(_AllSwap)
def get_bprop_allswap(self):
"""Generate bprop for _AllSwap."""
all_swap_grad = _AllSwap(self.group)
if self.instance_name:
instance_name = "grad" + self.instance_name
all_swap_grad.set_prim_instance_name(instance_name)
def bprop(x, send_size, recv_size, out, dout):
dx = all_swap_grad(dout, recv_size, send_size)
return (dx, zeros_like(send_size), zeros_like(recv_size))
return bprop
@bprop_getters.register(_HostReduceScatter)
def get_bprop_host_reduce_scatter(self):
"""Generate bprop for _HostReduceScatter"""
host_reduce_scatter_grad = _HostAllGather(self.group)
if self.instance_name:
instance_name = "grad" + self.instance_name
host_reduce_scatter_grad.set_prim_instance_name(instance_name)
if self.op != ReduceOp.SUM:
raise RuntimeError("The hostreducescatter bprop only support ReduceOp.SUM until now.")
def bprop(x, out, dout):
dx = host_reduce_scatter_grad(dout)
return (dx,)
return bprop
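
# The gradient of NeighborExchange is an exchange in the opposite direction: the bprop
# below swaps the send/recv rank ids and shapes so that each gradient tensor travels
# back to the rank that produced the corresponding forward output.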
@bprop_getters.register(NeighborExchange)
def get_bprop_neighborexchange(self):
"""Generate bprop for NeighborExchange."""
group = self.group
send_rank_ids = self.recv_rank_ids
recv_rank_ids = self.send_rank_ids
recv_shapes = self.send_shapes
send_shapes = self.recv_shapes
recv_type = self.recv_type
neighborexchange_grad = NeighborExchange(send_rank_ids, recv_rank_ids, recv_shapes, send_shapes, recv_type, group)
def bprop(x, out, dout):
return (neighborexchange_grad(dout),)
return bprop
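
# The gradient of AlltoAll is another AlltoAll with split_dim and concat_dim swapped,
# which routes every gradient block back to the rank and position it came from.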
@bprop_getters.register(AlltoAll)
def get_bprop_all_to_all(self):
"""Generate bprop for AlltoAll."""
all_to_all_grad = AlltoAll(self.split_count, self.concat_dim, self.split_dim, self.group)
if self.instance_name:
instance_name = "grad" + self.instance_name
all_to_all_grad.set_prim_instance_name(instance_name)
def bprop(x, out, dout):
dx = all_to_all_grad(dout)
return (dx,)
return bprop
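
# NeighborExchangeV2 uses the dedicated NeighborExchangeV2Grad primitive for its
# gradient, again with the send/recv rank ids and lengths swapped so that each received
# gradient region flows back to the rank it came from.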
@bprop_getters.register(NeighborExchangeV2)
def get_bprop_neighborexchangev2(self):
"""Generate bprop for NeighborExchangeV2."""
group = self.group
send_rank_ids = self.recv_rank_ids
recv_rank_ids = self.send_rank_ids
send_lens = self.recv_lens
recv_lens = self.send_lens
data_format = self.data_format
neighborexchangev2_grad = G.NeighborExchangeV2Grad(send_rank_ids, send_lens, recv_rank_ids,
recv_lens, data_format, group)
def bprop(x, out, dout):
return (neighborexchangev2_grad(dout),)
return bprop
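
# _MirrorOperator mirrors a parameter's gradient across its parallel group: the bprop
# below allreduces dout over the group (or allgathers the indices/values of a sparse
# RowTensor gradient) and divides by dev_num when mean_flag is set; with a single
# device the gradient is returned unchanged.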
@bprop_getters.register(_MirrorOperator)
def get_bprop_mirror_operator(self):
"""
    Backpropagator for _MirrorOperator: allreduce the gradient across the devices in the
    group (a single group only), or allgather it for sparse features.
"""
group = self.get_attr_dict()['group']
dev_num = self.get_attr_dict()['dev_num']
mean_flag = self.get_attr_dict()['mean_flag']
if dev_num > 1:
all_reduce = AllReduce(group=group)
all_gather = AllGather(group=group)
mul = P.Mul()
cast = P.Cast()
fusion = self.get_attr_dict()["fusion"]
all_reduce.add_prim_attr("fusion", fusion)
if hasattr(self, 'parameter'):
parameter = self.parameter
all_reduce.add_prim_attr("parameter", parameter)
if self.instance_name:
instance_name = "grad_mirror" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
def bprop(x, out, dout):
if dev_num == 1:
return (dout,)
if mean_flag:
if F.issubclass_(F.typeof(dout), mstype.tensor):
dx = all_reduce(dout)
float_one = F.scalar_cast(1.0, F.dtype(dx))
num = F.scalar_cast(dev_num, F.dtype(dx))
dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
else:
indices = all_gather(dout.indices)
grad = all_gather(dout.values)
float_one = F.scalar_cast(1.0, F.dtype(grad))
num = F.scalar_cast(dev_num, F.dtype(grad))
grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
dx = RowTensor(indices, grad, dout.dense_shape)
else:
if F.issubclass_(F.typeof(dout), mstype.tensor):
dx = all_reduce(dout)
else:
indices = all_gather(dout.indices)
grad = all_gather(dout.values)
dx = RowTensor(indices, grad, dout.dense_shape)
return (dx,)
return bprop
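
# _MirrorMiniStepOperator mirrors gradients when gradient accumulation (mini steps) is
# enabled: with do_mirror the bprop below adds dout into the accumulation buffer z and
# allreduces it across the group; when mean_flag is set the result is divided by
# dev_num. Sparse RowTensor gradients are not supported here and get a zero gradient.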
@bprop_getters.register(_MirrorMiniStepOperator)
def get_bprop_mirror_mini_step_operator(self):
"""
    Backpropagator for _MirrorMiniStepOperator: allreduce the gradient across the devices
    in the group; sparse features are not supported under gradient accumulation.
"""
group = self.group
dev_num = self.dev_num
mean_flag = self.mean_flag
all_reduce = AllReduce(group=group)
mul = P.Mul()
cast = P.Cast()
fusion = self.get_attr_dict()["fusion"]
all_reduce.add_prim_attr("fusion", fusion)
if hasattr(self, 'parameter'):
parameter = self.parameter
all_reduce.add_prim_attr("parameter", parameter)
if self.instance_name:
instance_name = "grad_mirror" + self.instance_name
all_reduce.set_prim_instance_name(instance_name)
do_mirror = self.get_attr_dict()["do_mirror"]
def bprop(x, z, out, dout):
if mean_flag:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if do_mirror:
z = F.depend(z, F.assign_add(z, dout))
real_grad = all_reduce(z)
dx = real_grad
else:
dx = dout
float_one = F.scalar_cast(1.0, F.dtype(dx))
num = F.scalar_cast(dev_num, F.dtype(dx))
dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
else:
                dx = zeros_like(x)  # Gradient accumulation does not support row tensors yet
else:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if do_mirror:
z = F.depend(z, F.assign_add(z, dout))
real_grad = all_reduce(z)
dx = real_grad
else:
dx = dout
else:
                dx = zeros_like(x)  # Gradient accumulation does not support row tensors yet
return (dx, zeros_like(z))
return bprop
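
# _VirtualDiv scales the gradient by 1 / divisor in the backward pass. The bprop below
# handles a plain tensor as well as tuples and lists of tensors, and leaves bool/int
# gradients untouched.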
@bprop_getters.register(_VirtualDiv)
def get_bprop_virtual_div_operator(self):
"""Backpropagator for _VirtualDiv, do Div for the divisor."""
divisor = self.divisor
op = P.RealDiv()
cast = P.Cast()
dtype = P.DType()
def bprop(x, out, dout):
if F.issubclass_(F.typeof(dout), mstype.tensor):
if F.issubclass_(F.dtype(dout), mstype.bool_) or F.issubclass_(F.dtype(dout), mstype.int32) \
or F.issubclass_(F.dtype(dout), mstype.int16):
return (dout,)
dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
return (dx,)
if F.issubclass_(F.typeof(dout), mstype.tuple_):
dx = ()
input_nums = F.tuple_len(dout)
for i in range(input_nums):
ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
dx = dx + (ele_grad,)
return (dx,)
dx = []
input_nums = F.list_len(dout)
for i in range(input_nums):
ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
dx.append(ele_grad)
return (dx,)
return bprop
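
# _GetTensorSlice extracts the local slice described by the device matrix and tensor
# map; its bprop simply returns a zero gradient for the input tensor.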
@bprop_getters.register(_GetTensorSlice)
def get_bprop_get_tensor_slice_operator(self):
"""Backpropagator for _GetTensorSlice"""
def bprop(x, dev_mat, tensor_map, out, dout):
return (zeros_like(x),)
return bprop