我已经用生成器创建了一个 TensorFlow 数据集,但不知道如何按 batch_size 分批对其进行迭代。
def ds_gen():
    """Yield a single random (x, y) sample.

    Yields:
        A tuple ``(x, y)`` where ``x`` is a float array of shape
        ``(10, 10, 3)`` and ``y`` is a float array of shape ``(2,)``.
    """
    x = np.random.random((10, 10, 3))
    y = np.random.random((2,))
    # Only one sample is ever produced: there is no loop around this yield,
    # so the generator is exhausted after the first element.
    yield x, y
def create_tf_dataset():
    """Build a ``tf.data.Dataset`` backed by the ``ds_gen`` generator.

    Returns:
        A dataset whose elements are ``(x, y)`` pairs of ``tf.float32``
        tensors with shapes ``(10, 10, 3)`` and ``(2,)``.
    """
    # NOTE(review): newer TensorFlow releases prefer `output_signature` over
    # `output_types`/`output_shapes` -- confirm against the installed version.
    dataset = tf.data.Dataset.from_generator(
        ds_gen,
        output_types=(tf.float32, tf.float32),
        output_shapes=((10, 10, 3), (2,)),
    )
    return dataset
# Build the dataset and group its elements into batches of up to 10.
ds = create_tf_dataset()
ds = ds.batch(10)

# Each iteration produces one batch; the leading dimension of the printed
# shapes is the number of samples actually present in that batch.
for x_batch, y_batch in ds:
    print(x_batch.shape, y_batch.shape)
这段代码迭代时得到的批量大小是 1,而不是预期的 10(生成器只 yield 了一个样本,所以只能组成一个大小为 1 的批次)。
生成器需要产生多个样本,batch 才能凑满一个批次。请参阅下面的代码,了解如何按批量大小进行迭代:
def ds():
    """Yield 1000 random (x, y) samples, one at a time.

    Yields:
        A tuple ``(x, y)`` where ``x`` is a float array of shape
        ``(10, 10, 3)`` and ``y`` is a float array of shape ``(2,)``.
    """
    for _ in range(1000):
        x = np.random.rand(10, 10, 3)
        y = np.random.rand(2)
        yield x, y
# Wrap the generator in a dataset. Note that this rebinds the name `ds`
# from the generator function to the resulting dataset object.
ds = tf.data.Dataset.from_generator(
    ds,
    output_types=(tf.float32, tf.float32),
    output_shapes=((10, 10, 3), (2,)),
)
ds = ds.batch(10)

# Walk through every batch; `batch` is the batch index and (x, y) the
# batched tensors.
for batch, (x, y) in enumerate(ds):
    pass

# Runs after the loop, so it reports the shapes of the last batch only.
print("Data shape: ", x.shape, y.shape)
输出:
Data shape: (10, 10, 10, 3) (10, 2)
如果改为 ds = ds.batch(1),则输出将为 Data shape: (1, 10, 10, 3) (1, 2)