The operator information describes the data types and data formats supported for the inputs and outputs, the attributes, and the target (platform information) of the operator implementation. The backend uses it to select and map operators. The operator information is defined with the CustomRegOp API and is bound to the operator implementation either with the custom_info_register decorator or by passing it to the reg_info parameter of the Custom primitive. It is finally registered to the operator information library on the MindSpore C++ side. The reg_info parameter takes priority over the custom_info_register decorator.
The target value in the operator information can be "Ascend", "GPU" or "CPU" and indicates the target to which the operator information applies. The same operator implementation may support different data types on different targets, and the target value is used to distinguish these cases. The operator information for a given target is registered only once.
- The number and order of the inputs and outputs defined in the operator information must match those of the parameters of the operator implementation.
- For a custom operator of akg type, operator information must be registered if the operator has attributes, and the attribute names in the operator information must be consistent with the attribute names used in the operator implementation. For a custom operator of tbe type, operator information must be registered. For a custom operator of aot type, the operator implementation has to be compiled into a dynamic library in advance, so operator information cannot be bound with a decorator and can only be passed in through the reg_info parameter.
- If the custom operator supports only specific input and output data types or data formats, operator information must be registered so that the data type and data format can be checked when the backend selects the operator. If no operator information is provided, it is derived from the inputs of the current operator.
If an operator needs to support automatic differentiation, its backpropagation (bprop) function must be defined first and then passed to the bprop parameter of the Custom primitive. The bprop function describes the backward computation that uses the forward input, the forward output, and the output gradients to obtain the input gradients. The backward computation can be composed of built-in operators or custom operators.
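The shape of such a function can be sketched in plain Python, with floats standing in for tensors. The function name `mul_bprop` and the two-input multiplication are hypothetical, chosen only to show how the forward inputs, the forward output, and the output gradient map to a tuple of input gradients. This is an illustration of the idea, not MindSpore code.

```python
def mul_bprop(x, y, out, dout):
    """Backward rule for the forward computation out = x * y.

    x, y : forward inputs
    out  : forward output (unused by this particular rule)
    dout : gradient flowing in from the output
    Returns a tuple with one gradient per forward input.
    """
    dx = dout * y  # d(out)/dx = y
    dy = dout * x  # d(out)/dy = x
    return (dx, dy)


x, y = 3.0, 5.0
out = x * y
grads = mul_bprop(x, y, out, 1.0)
print(grads)  # (5.0, 3.0)
```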
Note the following points when defining the bprop function for operators:
Take test_grad.py as an example to show the usage of the backpropagation function:
import numpy as np
import mindspore as ms
from mindspore.nn import Cell
import mindspore.ops as ops

ms.set_context(mode=ms.GRAPH_MODE, device_target="GPU")

# Forward computation of custom operator
def square(x):
    y = output_tensor(x.shape, x.dtype)
    for i0 in range(x.shape[0]):
        # Square each input element
        y[i0] = x[i0] * x[i0]
    return y

# Backward computation of custom operator
def square_grad(x, dout):
    dx = output_tensor(x.shape, x.dtype)
    for i0 in range(x.shape[0]):
        dx[i0] = 2.0 * x[i0]
    for i0 in range(x.shape[0]):
        dx[i0] = dx[i0] * dout[i0]
    return dx

# Backpropagation function
def bprop():
    op = ops.Custom(square_grad, lambda x, _: x, lambda x, _: x, func_type="akg")

    def custom_bprop(x, out, dout):
        dx = op(x, dout)
        return (dx,)

    return custom_bprop

class Net(Cell):
    def __init__(self):
        super(Net, self).__init__()
        # Define a custom operator of akg type and provide a backpropagation function
        self.op = ops.Custom(square, lambda x: x, lambda x: x, bprop=bprop(), func_type="akg")

    def construct(self, x):
        return self.op(x)

if __name__ == "__main__":
    x = np.array([1.0, 4.0, 9.0]).astype(np.float32)
    dx = ms.grad(Net())(ms.Tensor(x))
    print(dx)
The following points need to be explained in this example:
- The custom operator used for the backward computation is defined outside the custom_bprop function and used inside the custom_bprop function.
Execute case:
python test_grad.py
The execution result is as follows:
[ 2. 8. 18.]
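This result can be cross-checked without MindSpore. The sketch below mirrors square and square_grad with plain Python lists, assumes an output gradient of all ones, and compares the analytic gradient with a central finite difference. The list-based functions and the step size `eps` are illustrative assumptions, not part of test_grad.py.

```python
def square(x):
    # Forward: elementwise square
    return [v * v for v in x]


def square_grad(x, dout):
    # Backward: d(x^2)/dx = 2x, scaled by the incoming output gradient
    return [2.0 * v * g for v, g in zip(x, dout)]


def custom_bprop(x, out, dout):
    # Same argument order as in the example: forward input, forward output, output gradient
    return (square_grad(x, dout),)


x = [1.0, 4.0, 9.0]
out = square(x)
dout = [1.0] * len(x)  # assumed output gradient of ones
(dx,) = custom_bprop(x, out, dout)
print(dx)  # [2.0, 8.0, 18.0]

# Central finite difference as an independent check
eps = 1e-3
numeric = [(square([v + eps])[0] - square([v - eps])[0]) / (2 * eps) for v in x]
max_error = max(abs(a - n) for a, n in zip(dx, numeric))
```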
More examples can be found under tests/st/ops/graph_kernel/custom in the MindSpore source code.