tensorflow.data.Dataset.from_generator() 将字符串参数转换为字节

问题描述 投票:0回答:1

我有一个自定义生成器,它采用字符串格式的两个日期作为参数,并从日期范围内生成(特征、标签)。我想用它创建一个数据集,但是

tf.data.Dataset.from_generator()
顽固地将日期字符串转换为字节,导致生成器函数失败。让我演示一下这种行为:

def some_generator(date1, date2):
    print(date1, date2)
    yield [1, 2], [3]

feats, label = next(some_generator('2010-01-01', '2017-12-31'))
signature = (tf.type_spec_from_value(tf.convert_to_tensor(feats)),
             tf.type_spec_from_value(tf.convert_to_tensor(label)))
ds = Dataset.from_generator(some_generator, args=('2010-01-01', '2017-12-31'), output_signature=signature)

for feats, label in ds.take(1):
    print(feats, label)

此代码的输出是:

2010-01-01 2017-12-31
b'2010-01-01' b'2017-12-31'
tf.Tensor([1 2], shape=(2,), dtype=int32) tf.Tensor([3], shape=(1,), dtype=int32)

第一行是第一次调用

some_generator
的结果,其中为签名生成功能和标签,这里日期以字符串形式打印出来。

第二行是 for 循环中迭代数据集的结果,其中日期被打印为字节字符串。整数不会出现此问题,但我需要字符串。如果有人知道如何解决此问题,请分享。

python tensorflow dataset
1个回答
0
投票

我想出的一个可能的解决方案是解码参数,如果它们不是字符串,即:

def some_generator(date1, date2):
    date1 = date1 if type(date1) == str else date1.decode('ASCII')
    date2 = date2 if type(date2) == str else date2.decode('ASCII')
    print(date1, date2)
    yield [1, 2], [3]

与上面示例中的其余代码一起生成:

2010-01-01 2017-12-31
2010-01-01 2017-12-31
tf.Tensor([1 2], shape=(2,), dtype=int32) tf.Tensor([3], shape=(1,), dtype=int32)

所以它确实解决了问题,但增加了额外的计算,因为条件检查结束类型转换是针对每个生成的数据样本完成的,这并不理想。

© www.soinside.com 2019 - 2024. All rights reserved.