我使用下面的代码使用张量流制作了一个批处理数据集。
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
数据集形状如下。
<_BatchDataset element_spec=(TensorSpec(shape=(20, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(20, 256, 256, 1), dtype=tf.float32, name=None))>
我想用下面的代码打印第一个 x 数组,但它不起作用,其中 x 和 y 是输入和目标数据。
dataset[0][0]
您可以使用以下方法之一来获取数据集的第一部分:
x = tf.random.uniform((100,2))
y = tf.random.uniform((100, 1))
ds = tf.data.Dataset.from_tensor_slices((x, y)) # dataset creation
ds = ds.batch(3, drop_remainder=True)
""" THIS """
for x,y in ds:
print(x,y)
break
""" OR THIS """
print(next(iter(ds)))
我认为 for 循环在幕后或多或少地执行了第二个解决方案。它使用
iter
创建一个可迭代对象,并且 next
给出顺序中的下一个元素。