如何查看批量数据集的x,y数组?

问题描述 投票:0回答:1

我使用下面的代码使用张量流制作了一个批处理数据集。

dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

数据集形状如下。

<_BatchDataset element_spec=(TensorSpec(shape=(20, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(20, 256, 256, 1), dtype=tf.float32, name=None))>

我想用下面的代码打印第一个 x 数组,但它不起作用,其中 x 和 y 是输入和目标数据。

dataset[0][0] 
tensorflow keras
1个回答
0
投票

您可以使用以下方法之一来获取数据集的第一部分:

x = tf.random.uniform((100,2))
y = tf.random.uniform((100, 1))
ds = tf.data.Dataset.from_tensor_slices((x, y))  # dataset creation
ds = ds.batch(3, drop_remainder=True)

""" THIS """
for x,y in ds:
  print(x,y)
  break
""" OR THIS """
print(next(iter(ds)))

我认为 for 循环在幕后或多或少地执行了第二个解决方案。它使用

iter
创建一个可迭代对象,并且
next
给出顺序中的下一个元素。

© www.soinside.com 2019 - 2024. All rights reserved.