diff --git a/tutorials/source_en/custom_program/hook_program.md b/tutorials/source_en/custom_program/hook_program.md index b7dcb4fb61c5bb889f36d16a62014a79a4be9786..3211052ad7a256b3edbe4ac7218d6c2d2c454ec1 100644 --- a/tutorials/source_en/custom_program/hook_program.md +++ b/tutorials/source_en/custom_program/hook_program.md @@ -45,7 +45,7 @@ For more descriptions of the HookBackward operator, refer to the [API documentat ## register_forward_pre_hook Function in Cell Object -The user can use the `register_forward_pre_hook` function on the Cell object to register a custom Hook function to capture data that is passed to that Cell object. This function does not work in static graph mode and inside functions modified with `@jit`. The `register_forward_pre_hook` function takes the Hook function as an input and returns a `handle` object that corresponds to the Hook function. The user can remove the corresponding Hook function by calling the `remove()` function of the `handle` object. Each call to the `register_forward_pre_hook` function returns a different `handle` object. Hook functions should be defined in the following way. +The user can use the `register_forward_pre_hook` function on the Cell object to register a custom Hook function to capture data that is passed to that Cell object. The `register_forward_pre_hook` function takes the Hook function as an input and returns a `handle` object that corresponds to the Hook function. The user can remove the corresponding Hook function by calling the `remove()` function of the `handle` object. Each call to the `register_forward_pre_hook` function returns a different `handle` object. Hook functions should be defined in the following way. ```python def forward_pre_hook_fn(cell, inputs): @@ -149,7 +149,7 @@ For more information about the `register_forward_pre_hook` function of the Cell ## register_forward_hook Function of Cell Object -The user can use the `register_forward_hook` function on the Cell object to register a custom Hook function that captures the data passed forward to the Cell object and the output data of the Cell object. This function does not work in static graph mode and inside functions modified with `@jit`. The `register_forward_hook` function takes the Hook function as an input and returns a `handle` object that corresponds to the Hook function. The user can remove the corresponding Hook function by calling the `remove()` function of the `handle` object. Each call to the `register_forward_hook` function returns a different `handle` object. Hook functions should be defined in the following way. +The user can use the `register_forward_hook` function on the Cell object to register a custom Hook function that captures the data passed forward to the Cell object and the output data of the Cell object. The `register_forward_hook` function takes the Hook function as an input and returns a `handle` object that corresponds to the Hook function. The user can remove the corresponding Hook function by calling the `remove()` function of the `handle` object. Each call to the `register_forward_hook` function returns a different `handle` object. Hook functions should be defined in the following way. The sample code is as follows: @@ -214,7 +214,7 @@ For more information about the `register_forward_hook` function of the Cell obje ## register_backward_pre_hook Function of Cell Object -The user can use the `register_backward_pre_hook` function on the Cell object to register a custom Hook function that captures the gradient associated with the Cell object when the network is back propagated. This function does not work in graph mode or inside functions modified with `@jit`. The `register_backward_pre_hook` function takes the Hook function as an input and returns a `handle` object that corresponds to the Hook function. The user can remove the corresponding Hook function by calling the `remove()` function of the `handle` object. Each call to the `register_backward_pre_hook` function will return a different `handle` object. +The user can use the `register_backward_pre_hook` function on the Cell object to register a custom Hook function that captures the gradient associated with the Cell object when the network is back propagated. The `register_backward_pre_hook` function takes the Hook function as an input and returns a `handle` object that corresponds to the Hook function. The user can remove the corresponding Hook function by calling the `remove()` function of the `handle` object. Each call to the `register_backward_pre_hook` function will return a different `handle` object. Unlike the custom Hook function used by the HookBackward operator, the inputs of the Hook function used by `register_backward_pre_hook` contains `cell`, which represents the information of the Cell object, the gradient passed to the Cell object in reverse of the Cell object. @@ -279,7 +279,7 @@ For more information about the `register_backward_pre_hook` function of the Cell ## register_backward_hook Function of Cell Object -The user can use the `register_backward_hook` function on the Cell object to register a custom Hook function that captures the gradient associated with the Cell object when the network is back propagated. This function does not work in graph mode or inside functions modified with `@jit`. The `register_backward_hook` function takes the Hook function as an input and returns a `handle` object that corresponds to the Hook function. The user can remove the corresponding Hook function by calling the `remove()` function of the `handle` object. Each call to the `register_backward_hook` function will return a different `handle` object. +The user can use the `register_backward_hook` function on the Cell object to register a custom Hook function that captures the gradient associated with the Cell object when the network is back propagated. The `register_backward_hook` function takes the Hook function as an input and returns a `handle` object that corresponds to the Hook function. The user can remove the corresponding Hook function by calling the `remove()` function of the `handle` object. Each call to the `register_backward_hook` function will return a different `handle` object. Unlike the custom Hook function used by the HookBackward operator, the inputs of the Hook function used by `register_backward_hook` contains `cell`, which represents the information of the Cell object, the gradient passed to the Cell object in reverse, and the gradient of the reverse output of the Cell object. diff --git a/tutorials/source_zh_cn/custom_program/hook_program.ipynb b/tutorials/source_zh_cn/custom_program/hook_program.ipynb index c4bd88b29aa60cdfafc1f4f4ad8aeb2005368145..b2bf57061e1388468567ab9a2cc0cc0ca1034c48 100644 --- a/tutorials/source_zh_cn/custom_program/hook_program.ipynb +++ b/tutorials/source_zh_cn/custom_program/hook_program.ipynb @@ -26,12 +26,23 @@ }, { "cell_type": "code", + "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2024-08-15T03:32:04.585336Z", "start_time": "2024-08-15T03:32:04.578481Z" } }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hook_fn print grad_out: (Tensor(shape=[], dtype=Float32, value= 2),)\n", + "output: (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))\n" + ] + } + ], "source": [ "import mindspore as ms\n", "from mindspore import ops\n", @@ -54,18 +65,7 @@ "\n", "output = net(ms.Tensor(1, ms.float32), ms.Tensor(2, ms.float32))\n", "print(\"output:\", output)" - ], - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "hook_fn print grad_out: (Tensor(shape=[], dtype=Float32, value= 2),)\n", - "output: (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))\n" - ] - } - ], - "execution_count": 15 + ] }, { "cell_type": "markdown", @@ -75,7 +75,7 @@ "\n", "## Cell对象的register_forward_pre_hook功能\n", "\n", - "用户可以对Cell对象使用`register_forward_pre_hook`函数来注册一个自定义的Hook函数,用来捕获正向传入该Cell对象的数据。该功能在静态图模式下和在使用`@jit`修饰的函数内不起作用。`register_forward_pre_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_forward_pre_hook`函数,都会返回一个不同的`handle`对象。Hook函数应该按照以下的方式进行定义。" + "用户可以对Cell对象使用`register_forward_pre_hook`函数来注册一个自定义的Hook函数,用来捕获正向传入该Cell对象的数据。`register_forward_pre_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_forward_pre_hook`函数,都会返回一个不同的`handle`对象。Hook函数应该按照以下的方式进行定义。" ] }, { @@ -227,7 +227,7 @@ "\n", "## Cell对象的register_forward_hook功能\n", "\n", - "用户可以在Cell对象上使用`register_forward_hook`函数来注册一个自定义的Hook函数,用来捕获正向传入Cell对象的数据和Cell对象的输出数据。该功能在静态图模式下和在使用`@jit`修饰的函数内不起作用。`register_forward_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_forward_hook`函数,都会返回一个不同的`handle`对象。Hook函数应该按照以下的方式进行定义。\n", + "用户可以在Cell对象上使用`register_forward_hook`函数来注册一个自定义的Hook函数,用来捕获正向传入Cell对象的数据和Cell对象的输出数据。`register_forward_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_forward_hook`函数,都会返回一个不同的`handle`对象。Hook函数应该按照以下的方式进行定义。\n", "\n", "示例代码如下:" ] @@ -321,7 +321,7 @@ "\n", "## Cell对象的register_backward_pre_hook功能\n", "\n", - "用户可以在Cell对象上使用`register_backward_pre_hook`函数来注册一个自定义的Hook函数,用来捕获网络反向传播时与Cell对象相关联的梯度。该功能在图模式下或者在使用`@jit`修饰的函数内不起作用。`register_backward_pre_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_backward_pre_hook`函数,都会返回一个不同的`handle`对象。\n", + "用户可以在Cell对象上使用`register_backward_pre_hook`函数来注册一个自定义的Hook函数,用来捕获网络反向传播时与Cell对象相关联的梯度。`register_backward_pre_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_backward_pre_hook`函数,都会返回一个不同的`handle`对象。\n", "\n", "与HookBackward算子所使用的自定义Hook函数有所不同,`register_backward_pre_hook`使用的Hook函数的入参中,包含了表示Cell对象信息`cell`以及反向传入到Cell对象的梯度。\n", "示例代码如下:" @@ -414,7 +414,7 @@ "\n", "## Cell对象的register_backward_hook功能\n", "\n", - "用户可以在Cell对象上使用`register_backward_hook`函数来注册一个自定义的Hook函数,用来捕获网络反向传播时与Cell对象相关联的梯度。该功能在图模式下或者在使用`@jit`修饰的函数内不起作用。`register_backward_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_backward_hook`函数,都会返回一个不同的`handle`对象。\n", + "用户可以在Cell对象上使用`register_backward_hook`函数来注册一个自定义的Hook函数,用来捕获网络反向传播时与Cell对象相关联的梯度。`register_backward_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_backward_hook`函数,都会返回一个不同的`handle`对象。\n", "\n", "与HookBackward算子所使用的自定义Hook函数有所不同,`register_backward_hook`使用的Hook函数的入参中,包含了表示Cell对象信息`cell`、反向传入到Cell对象的梯度、以及Cell对象的反向输出的梯度。\n", "\n",