假设我有这个:
TensorShape([Dimension(None),Dimension(32)])
我在这个张量_X上使用tf.split,其尺寸如上:
_X = tf.split(_X, 128, 0)
这种新张量的形状是什么?输出是一个列表,因此很难知道这个新张量的形状。
tf.split()返回张量对象列表。你可以知道每个张量对象的形状如下
import tensorflow as tf
X = tf.random_uniform([256, 32]);
Y = tf.split(X,128,0)
Y_shape = tf.shape(Y[1])
sess = tf.Session()
X_v,Y_v,Y_shape_v = sess.run([X,Y,Y_shape])
# numpy style
print X_v.shape
print len(Y_v)
print Y_v[100].shape
# TF style
print len(Y)
print Y_shape_v
输出:
(256, 32)
128
(2, 32)
128
[ 2 32]
我希望这有帮助 !
tf.split(X, row = n, column = m)
用于将变量的数据集分成n
行数和m
列数。
例如,我们有大小为x
的data_set (10,10)
,然后tf.split(x, 2, 0)
将在2组大小x
中打破(5, 10)
的data_set
但如果我们采取tf.split(x, 2, 2)
,那么我们将得到4组大小为(5, 5)
的数据。