关于tf.repeat()。batch(batch_size)

问题描述 投票:0回答:1

我正在研究张量流。关于tensorflow.data.Dataset中的repeat函数,如果repeat函数repeat()中没有参数,则应无限期重复张量。但是,当不带参数的重复功能与循环语句下的批处理功能结合使用时,它会产生无休止重复的结果,如下所示。我无法理解该过程。您可以使用以下示例解释重复功能吗?谢谢!

for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())

[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
tensorflow-datasets
1个回答
0
投票

由于您使用.take(10)作为最后一种链接方法,因此,因此,结果数据集仅限于10个样本。这里的单个sample将是单个批次中的所有元素。有10个批次的无限重复,您可以使用.take(10)从其中初始提取10个批次。将您的代码更改为以下代码应该可以得到预期的结果。

ds_counter = tf.data.Dataset.range(25)
for count_batch in ds_counter.repeat().batch(10):
    print(count_batch.numpy())
© www.soinside.com 2019 - 2024. All rights reserved.