如何使用 tf.tensor_scatter_nd_update 更新 3d 张量?

问题描述 投票:0回答:0

我迷路了,想弄清楚如何将

tf.tensor_scatter_nd_update
用于 3d 张量。

使用示例代码,输入数据的形状为

(64,8,768)
第一维是批量大小,第二维是通道/补丁,第三维是特征大小。

我想在给定形状

(64,4)
的索引的情况下更新输入张量。注意批处理索引中的每个条目如何只更新 8 个通道/补丁中的 4 个。

用于更新的特征由

updateTensor
定义,其形状为
(64,4,768)

期望的输出与输入大小相同

(64,8,768)

虽然在尝试

tf.tensor_scatter_nd_update
之后,我得到以下错误。

InvalidArgumentError: {{function_node __wrapped__TensorScatterUpdate_device_/job:localhost/replica:0/task:0/device:GPU:0}} Inner dimensions of output shape must match inner dimensions of updates shape. Output: [64,8,768] updates: [64,4,768] [Op:TensorScatterUpdate]

最低限度的工作代码

import tensorflow as tf

featureSize = 768
batchSize = 64
patchCount = 8
toUpdatePatchCount = 4

inputTensor = tf.random.normal([batchSize,patchCount,featureSize])
# TensorShape([64, 8, 768])

indices = tf.argsort(
    tf.random.uniform(shape=(batchSize, toUpdatePatchCount)), axis=-1
)
# TensorShape([64, 4])

updateTensor = tf.random.normal([batchSize,toUpdatePatchCount,featureSize])
# TensorShape([64, 4, 768]).shape

tf.tensor_scatter_nd_update(inputTensor,indices,updateTensor)

我要实现的目标是

tf.tensor_scatter_nd_update
吗?

tensorflow2.0 tensor
© www.soinside.com 2019 - 2024. All rights reserved.