我迷路了,想弄清楚如何将
tf.tensor_scatter_nd_update
用于 3d 张量。
使用示例代码,输入数据的形状为
(64,8,768)
第一维是批量大小,第二维是通道/补丁,第三维是特征大小。
我想在给定形状
(64,4)
的索引的情况下更新输入张量。注意批处理索引中的每个条目如何只更新 8 个通道/补丁中的 4 个。
用于更新的特征由
updateTensor
定义,其形状为(64,4,768)
期望的输出与输入大小相同
(64,8,768)
虽然在尝试
tf.tensor_scatter_nd_update
之后,我得到以下错误。
InvalidArgumentError: {{function_node __wrapped__TensorScatterUpdate_device_/job:localhost/replica:0/task:0/device:GPU:0}} Inner dimensions of output shape must match inner dimensions of updates shape. Output: [64,8,768] updates: [64,4,768] [Op:TensorScatterUpdate]
最低限度的工作代码
import tensorflow as tf
featureSize = 768
batchSize = 64
patchCount = 8
toUpdatePatchCount = 4
inputTensor = tf.random.normal([batchSize,patchCount,featureSize])
# TensorShape([64, 8, 768])
indices = tf.argsort(
tf.random.uniform(shape=(batchSize, toUpdatePatchCount)), axis=-1
)
# TensorShape([64, 4])
updateTensor = tf.random.normal([batchSize,toUpdatePatchCount,featureSize])
# TensorShape([64, 4, 768]).shape
tf.tensor_scatter_nd_update(inputTensor,indices,updateTensor)
我要实现的目标是
tf.tensor_scatter_nd_update
吗?