使用from_generator创建的Tensorflow数据集未按batch_size进行迭代

问题描述 投票:0回答:1

我已经从生成器创建了一个tensorflow数据集,但无法弄清楚如何通过batch_size对其进行迭代

def ds_gen():
    x = np.random.random((10, 10, 3))
    y = np.random.random((2))
    yield x, y

def create_tf_dataset():
    dataset = tf.data.Dataset.from_generator(ds_gen, output_types=(tf.float32, tf.float32), output_shapes=((10, 10, 3), (2,)))
    return dataset

ds = create_tf_dataset()
ds = ds.batch(10)
for x_batch, y_batch in ds:
  print(x_batch.shape, y_batch.shape)

此代码将循环遍历1而不是10的批量大小

tensorflow tensorflow-datasets
1个回答
0
投票

请参阅下面的代码以按批处理大小进行迭代

def ds():
    for i in range(1000):
      x = np.random.rand(10,10,3)
      y = np.random.rand(2)
      yield x,y

ds = tf.data.Dataset.from_generator(ds, output_types=(tf.float32, tf.float32), output_shapes=((10, 10, 3), (2,)))

ds = ds.batch(10)

for batch, (x,y) in enumerate(ds):
  pass
print("Data shape: ", x.shape, y.shape)

输出:

Data shape:  (10, 10, 10, 3) (10, 2)

如果更改ds = ds.batch(1),则输出将为Data shape: (1, 10, 10, 3) (1, 2)

© www.soinside.com 2019 - 2024. All rights reserved.