通过tf.data.Dataset将大型numpy数组输入TensorFlow估算器

问题描述 投票:1回答:1

TensorFlow的tf.data.Dataset documentation on consuming numpy arrays指出,为了将numpy数组与Dataset API结合使用,数组必须足够小(总共<2 GB)才能用作张量,或者它们可以通过占位符输入到数据集中。

但是,如果将Dataset与估算器一起使用(其中占位符不可用),则文档不提供使用没有占位符的大型数组的解决方案。

是否有其他选项可以将占位符值传递给可以使用的估算器,或者是以tfrecordcsv格式提供数据的解决方案?

python arrays numpy tensorflow tensorflow-estimator
1个回答
1
投票

您可以在创建数据集对象之前使用np.splitfrom_generator

chunks = list(np.split(array, 1000))

def gen():
    for i in chunks:
        yield i

dataset = tf.data.Dataset.from_generator(gen, tf.float32)
dataset = dataset.shuffle(shuffle_buffer_size)
...

您可以使用shuffle控制数据集的大小。它一次只加载指定的数量。

© www.soinside.com 2019 - 2024. All rights reserved.