访问Tensor的元素

问题描述 投票:2回答:1

我有以下TensorFlow张量。

tensor1 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') #All elements in range [0,255]
tensor2 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') #All elements in range [0,255]

tensor3 = tf.keras.backend.flatten(tensor1)
tensor4 = tf.keras.backend.flatten(tensor2)

tensor5 = tf.constant(np.random.randint(0,255, (255,255)), dtype='int32') #All elements in range [0,255]

我希望使用存储在张量3和张量4中的值来制作一个元组,并查询张量5中元组给出的位置的元素。例如,我们说张量3中的第0个元素,即张量3[0]=5,张量4[0]=99.所以元组变成了(5,99)。我希望查找张量5中元素(5,99)的值。我希望以批处理的方式对Tensor3和Tensor4中的所有元素进行查询。也就是说,我不想循环处理(len(Tensor3))范围内的所有值。我做了下面的工作来实现这个目的。

tensor6 = tensor5[tensor3[0],tensor4[0]]

但是tensor6的形状是(255,255),我希望得到一个形状为(len(tensor3),len(tensor3))的张量。我希望在len(tensor3)中所有可能的位置评估tensor5。这是在 (0,0),...(1000,1000),....(2000,2000),.... 我使用的是TensorFlow 1.12.0版本。我如何实现这个目标?

python tensorflow keras tensor array-broadcasting
1个回答
2
投票

我已经设法在Tensorflow v 1.12中得到了一些工作,但请告诉我这是否是预期的代码。

import tensorflow as tf
print(tf.__version__)
import numpy as np

tensor1 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') #All elements in range [0,255]
tensor2 = tf.constant(np.random.randint(0,255, (2,512,512,1)), dtype='int32') #All elements in range [0,255]

tensor3 = tf.keras.backend.flatten(tensor1)
tensor4 = tf.keras.backend.flatten(tensor2)

tensor5 = tf.constant(np.random.randint(0,255, (255,255)), dtype='int32') #All elements in range [0,255]

elems = (tensor3, tensor4)
a = tf.map_fn(lambda x: tensor5[x[0], x[1]], elems, dtype=tf.int32)

print(tf.Session().run(a))

基于下面的评论,我想补充一个解释 map_fn 在代码中使用。由于 for 如果没有eager_execution,则不支持循环。map_fn 等于 for 循环。

A map_fn 有以下参数。operation_performed, input_arguments, optional_dtype. 引擎盖下发生的事情是 for 中的值的长度运行循环。input_arguments (其中必须包含一个可迭代对象),然后对每一个获得的值进行 operation_performed 是执行。如需进一步说明,请参考 文件.

函数参数的名称是我对它们的解释,因为我想理解它们,而不是在官方文档中给出的。)

© www.soinside.com 2019 - 2024. All rights reserved.