tf.split的输出是什么?

问题描述 投票:2回答:2

假设我有这个:

TensorShape([Dimension(None),Dimension(32)])

我在这个张量_X上使用tf.split,其尺寸如上:

_X = tf.split(_X, 128, 0) 

这种新张量的形状是什么?输出是一个列表,因此很难知道这个新张量的形状。

python arrays tensorflow tensor
2个回答
9
投票

tf.split()返回张量对象列表。你可以知道每个张量对象的形状如下

import tensorflow as tf

X = tf.random_uniform([256, 32]);
Y = tf.split(X,128,0)
Y_shape = tf.shape(Y[1])

sess = tf.Session()
X_v,Y_v,Y_shape_v = sess.run([X,Y,Y_shape]) 
# numpy style
print X_v.shape
print len(Y_v)
print Y_v[100].shape
# TF style
print len(Y)
print Y_shape_v

输出:

(256, 32)
128
(2, 32)
128
[ 2 32]

我希望这有帮助 !


5
投票

tf.split(X, row = n, column = m)用于将变量的数据集分成n行数和m列数。

例如,我们有大小为x的data_set (10,10),然后tf.split(x, 2, 0)将在2组大小x中打破(5, 10)的data_set

但如果我们采取tf.split(x, 2, 2),那么我们将得到4组大小为(5, 5)的数据。

© www.soinside.com 2019 - 2024. All rights reserved.