TensorFlow的tf.data.Dataset
documentation on consuming numpy arrays指出,为了将numpy数组与Dataset
API结合使用,数组必须足够小(总共<2 GB)才能用作张量,或者它们可以通过占位符输入到数据集中。
但是,如果将Dataset
与估算器一起使用(其中占位符不可用),则文档不提供使用没有占位符的大型数组的解决方案。
是否有其他选项可以将占位符值传递给可以使用的估算器,或者是以tfrecord
或csv
格式提供数据的解决方案?
您可以在创建数据集对象之前使用np.split
和from_generator
。
chunks = list(np.split(array, 1000))
def gen():
for i in chunks:
yield i
dataset = tf.data.Dataset.from_generator(gen, tf.float32)
dataset = dataset.shuffle(shuffle_buffer_size)
...
您可以使用shuffle控制数据集的大小。它一次只加载指定的数量。