```text
tf.compat.v1.scatter_mul(
    ref,
    indices,
    updates,
    use_locking=False,
    name=None
) -> Tensor
```
For more information, see tf.compat.v1.scatter_mul.
```text
mindspore.ops.scatter_mul(
    input_x,
    indices,
    updates
) -> Tensor
```
For more information, see mindspore.ops.scatter_mul.
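TensorFlow: Performs an in-place scatter multiplication. Each slice of `ref` selected by `indices` (along the first dimension) is multiplied by the corresponding slice of `updates`.

MindSpore: Implements the same functionality as TensorFlow. In TensorFlow, the `use_locking` parameter controls whether the update is protected by a lock. Locking helps ensure that the tensor is updated correctly when several threads write to it concurrently. The default is False. MindSpore has no such parameter and always performs an unlocked update, which is equivalent to TensorFlow's default.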
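As a rough mental model, both APIs multiply `ref[indices[i], ...]` by `updates[i, ...]` in place. This is a sketch, not either library's implementation. The hypothetical helper `scatter_mul_reference` below spells that out with NumPy for the 1-D `indices` case used on this page. `np.multiply.at` performs an unbuffered in-place multiply, so an index that appears twice is multiplied once per occurrence. This mirrors the duplicate-index handling described in TensorFlow's documentation. MindSpore's duplicate-index behavior is not assumed here.

```python
import numpy as np

def scatter_mul_reference(ref, indices, updates):
    """Hypothetical reference: ref[indices[i], ...] *= updates[i, ...], in place."""
    # np.multiply.at is unbuffered: a row listed twice in `indices`
    # is multiplied once per occurrence, not just once.
    np.multiply.at(ref, indices, updates)
    return ref

# Same data as the code example on this page.
ref = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], dtype=np.float32)
indices = np.array([0, 1], dtype=np.int32)
updates = np.array([[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]], dtype=np.float32)
print(scatter_mul_reference(ref, indices, updates))
# [[ 1.  6. 15.]
#  [ 2.  8. 18.]]

# Duplicate index: row 0 is multiplied by both rows of `updates`.
ref2 = np.ones((2, 3), dtype=np.float32)
dup = np.array([0, 0], dtype=np.int32)
print(scatter_mul_reference(ref2, dup, updates))
# [[ 2. 12. 30.]
#  [ 1.  1.  1.]]
```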
| Categories | Subcategories | TensorFlow | MindSpore | Differences |
| --- | --- | --- | --- | --- |
| Parameters | Parameter 1 | ref | input_x | Same function, different parameter names |
| | Parameter 2 | indices | indices | - |
| | Parameter 3 | updates | updates | - |
| | Parameter 4 | use_locking | - | MindSpore does not have this parameter and always performs an unlocked update |
| | Parameter 5 | name | - | Not involved |
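When porting call sites, the table reduces to a small keyword translation. The sketch below uses a hypothetical helper, `tf_to_ms_scatter_mul_kwargs`, which is not part of either library. It renames `ref` to `input_x`, passes `indices` and `updates` through, and drops `use_locking` and `name`. It rejects `use_locking=True` rather than silently ignoring it, because `mindspore.ops.scatter_mul` has no matching parameter and the caller may rely on the locked behavior.

```python
def tf_to_ms_scatter_mul_kwargs(tf_kwargs):
    """Hypothetical helper: map tf.compat.v1.scatter_mul keyword arguments
    onto mindspore.ops.scatter_mul keyword arguments, per the table above."""
    renames = {"ref": "input_x", "indices": "indices", "updates": "updates"}
    ms_kwargs = {}
    for key, value in tf_kwargs.items():
        if key in renames:
            ms_kwargs[renames[key]] = value
        elif key == "use_locking":
            # No MindSpore counterpart: only the unlocked (False) case maps over.
            if value:
                raise ValueError("use_locking=True has no mindspore.ops.scatter_mul equivalent")
        elif key == "name":
            # Op names are a TensorFlow graph concept; nothing to pass on.
            continue
        else:
            raise TypeError(f"unexpected argument: {key}")
    return ms_kwargs

print(tf_to_ms_scatter_mul_kwargs(
    {"ref": "x", "indices": "i", "updates": "u", "use_locking": False, "name": "op"}))
# {'input_x': 'x', 'indices': 'i', 'updates': 'u'}
```

In a real port, the resulting dictionary would be unpacked into the MindSpore call, for example `ops.scatter_mul(**ms_kwargs)`.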
When use_locking is False in TensorFlow, the two APIs implement the same function.
```python
# TensorFlow
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

ref = tf.Variable(np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), dtype=tf.float32)
indices = tf.constant(np.array([0, 1]), dtype=tf.int32)
updates = tf.constant(np.array([[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]), dtype=tf.float32)
op = tf.compat.v1.scatter_mul(ref, indices, updates, use_locking=False)
init = tf.compat.v1.global_variables_initializer()
with tf.compat.v1.Session() as sess:
    sess.run(init)
    out = sess.run(op)
print(out)
# [[ 1.  6. 15.]
#  [ 2.  8. 18.]]
```
```python
# MindSpore
import numpy as np
import mindspore
from mindspore import Tensor, Parameter
import mindspore.ops as ops

input_x = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]), mindspore.float32), name="x")
indices = Tensor(np.array([0, 1]), mindspore.int32)
updates = Tensor(np.array([[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]), mindspore.float32)
output = ops.scatter_mul(input_x, indices, updates)
print(output)
# [[ 1.  6. 15.]
#  [ 2.  8. 18.]]
```