如何收集未知的第一(批)维度的张量?

问题描述 投票:2回答:2

我有一个形状(?, 3, 2, 5)的张量。我想提供成对的索引来从张量的第一维和第二维中进行选择,其形状为(3, 2)

如果我提供4对这样的对,我希望得到的形状是(?, 4, 5)。我认为这就是batch_gather的用途:在第一个(批量)维度上“广播”收集指数。但这不是它正在做的事情:

import tensorflow as tf
data = tf.placeholder(tf.float32, (None, 3, 2, 5))


indices = tf.constant([
    [2, 1],
    [2, 0],
    [1, 1],
    [0, 1]
], tf.int32)

tf.batch_gather(data, indices)

这导致<tf.Tensor 'Reshape_3:0' shape=(4, 2, 2, 5) dtype=float32>而不是我期待的形状。

如果没有明确索引批次(具有未知大小),我该怎么做我想要的?

python tensorflow
2个回答
0
投票

我想避免使用transpose和Python循环,我认为这很有效。这是设置:

import numpy as np
import tensorflow as tf

shape = None, 3, 2, 5
data = tf.placeholder(tf.int32, shape)
idxs_list = [
    [2, 1],
    [2, 0],
    [1, 1],
    [0, 1]
]
idxs = tf.constant(idxs_list, tf.int32)

这允许我们收集结果:

batch_size, num_idxs, num_channels = tf.shape(data)[0], tf.shape(idxs)[0], shape[-1]

batch_idxs = tf.math.floordiv(tf.range(0, batch_size * num_idxs), num_idxs)[:, None]
nd_idxs = tf.concat([batch_idxs, tf.tile(idxs, (batch_size, 1))], axis=1)

gathered = tf.reshape(tf.gather_nd(data, nd_idxs), (batch_size, num_idxs, num_channels))

当我们运行批量大小的4时,我们得到一个形状为(4, 4, 5)的结果,即(batch_size, num_idxs, num_channels)

vals_shape = 4, *shape[1:]
vals = np.arange(int(np.prod(vals_shape))).reshape(vals_shape)

with tf.Session() as sess:
    result = gathered.eval(feed_dict={data: vals})

numpy索引的关系:

x, y = zip(*idxs_list)
assert np.array_equal(result, vals[:, x, y])

基本上,gather_nd想要第一维中的批量索引,并且对于每个索引对必须重复一次(即,如果有4个索引对,则为[0, 0, 0, 0, 1, 1, 1, 1, 2, ...])。

由于似乎没有tf.repeat,我使用了rangefloordiv,然后qazxs使用所需的(x,y)索引(它们本身平铺concat次)的批次索引。


0
投票

使用batch_sizetf.batch_gather形状的主要尺寸应与tensor张量形状的前导尺寸相匹配。

indice

你最想要的是使用import tensorflow as tf data = tf.placeholder(tf.float32, (2, 3, 2, 5)) print(data.shape) // (2, 3, 2, 5) # shape of indices, [2, 3] indices = tf.constant([ [1, 1, 1], [0, 0, 1] ]) print(tf.batch_gather(data, indices).shape) # (2, 3, 2, 5) # if shape of indice was (2, 3, 1) the output would be 2, 3, 1, 5 如下

tf.gather_nd
© www.soinside.com 2019 - 2024. All rights reserved.