如何从生成器开发一个只输出数组列表的tf.data对象?

问题描述 投票:1回答:1

我试图开发一个产生数组列表的tf.data对象,但我得到一个数据不匹配的错误。以下是我的尝试

def labelGen():
    yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64)

Labeldataset = tf.data.Dataset.from_generator(
     labelGen, (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64), ([], [], [], [], []) )

list(Labeldataset.take(1))

而这是我得到的错误

InvalidArgumentError。TypeError: generator 产生了一个不符合预期结构的元素。预期的结构是(tf.int64, tf.int64, tf.int64, tf.int64, tf.int64, tf.int64),但产生的元素是(, , , )。 回溯(最近一次调用)。

tensorflow tensorflow-datasets
1个回答
1
投票

首先,.from_generator代码中的项数不匹配。第二,生成器的调用不应该是().下面是在TF 2.1中测试的工作代码。

def labelGen():
    yield tf.constant([1, 0], dtype=tf.int64), tf.constant([1, 0], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64), tf.constant([0, 1], dtype=tf.int64)

Labeldataset = tf.data.Dataset.from_generator(
    labelGen, # without ()
    (tf.int64, tf.int64, tf.int64, tf.int64), # should match number of items
    (tf.TensorShape([2]), tf.TensorShape([2]), tf.TensorShape([2]), tf.TensorShape([2]))) # use tf.TensorShape

list(Labeldataset.take(1))

结果是

[(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 0])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>)]
© www.soinside.com 2019 - 2024. All rights reserved.