Tensorflow数据集,papped_batch,为什么允许使用不同的output_shapes,还有更好的方法吗?

问题描述 投票:-1回答:1

我正在尝试编写足以与他人共享的Tensorflow 2.0代码。我遇到了tf.data.Dataset问题。我已经解决了,但是我不喜欢我的解决方案。

这里是有效的Python代码,可以通过两种不同的方式从不规则数据生成填充批处理。在一种情况下,我重新使用了全局变量来提供形状信息。我不喜欢全局变量,特别是因为我知道数据集知道其自己的输出形状,并且将来我可能会使用具有几种不同输出形状的数据集对象。

在另一种情况下,我从数据集对象本身提取形状信息。但是我必须跳过障碍才能做到这一点。

import numpy as np
import tensorflow as tf

print("""
Create a data set with the desired shape: 1 input per sub-element,
3 targets per sub-element, 8 elements of varying lengths.
""")

def gen():
    lengths = np.tile(np.arange(4,8), 2)
    np.random.shuffle(lengths)
    for length in lengths:
        inp = np.random.randint(1, 51, length)
        tgt = np.random.random((length, 3))
        yield inp, tgt

output_types = (tf.int64, tf.float64)
output_shapes = ([None], [None, 3])
dataset = tf.data.Dataset.from_generator(gen, output_types, output_shapes)

print("""
Using the global variable, output_shapes, allows the retrieval
of padded batches.
""")

for inp, tgt in dataset.padded_batch(3, output_shapes):
    print(inp)
    print(tgt)
    print()

print("""
Obtaining the shapes supplied to Dataset.from_generator()
is possible, but hard.
""")

default_shapes = tuple([[y.value for y in x.shape.dims] for x in dataset.element_spec]) # Crazy!
for inp, tgt in dataset.padded_batch(3, default_shapes):
    print(inp)
    print(tgt)

我不太明白为什么人们可能希望将一批大小不一的元素中的数据填充到最初用于定义Dataset元素的输出形状以外的任何形状。有人知道用例吗?

而且,pshaped_shapes参数没有默认值。我展示了如何检索我认为是padded_shapes的明智默认值的方法。单线工作...但是为什么这么难?

我目前正在尝试对数据集进行子类化,以提供数据集默认形状作为Python属性。 Tensorflow在与我抗争,可能是因为在我使用Python时底层的数据集是C ++对象。

所有这些麻烦使我想知道是否有比我尝试过的方法更干净的方法。

感谢您的建议。

python tensorflow tensorflow-datasets
1个回答
0
投票

回答我自己的问题。 I asked this same question on Reddit。 Tensorflow贡献者回答了TF 2.2 will provide a default value for the padded_shapes argument。我很高兴看到开发团队已经认识到我确定的相同需求。

© www.soinside.com 2019 - 2024. All rights reserved.