Tensorflow - 使用数据集API填充或截断序列

问题描述 投票:2回答:1

我正在尝试使用Dataset API来准备TFRecordDataset的文本序列。处理完毕后,我为每条记录都有一张张量词典。每条记录包含两个序列。

我正在使用padded_batch来应用填充

dataset = dataset.padded_batch(batch_size, padded_shapes= {
    'seq1': tf.TensorShape([None]),
    'seq2': tf.TensorShape([None])
})

这将每个序列填充到批次中的最大序列长度。但是,我想选择一个任意的序列长度,并在真正的序列长度较小时填充此长度,否则截断序列。

当我尝试用None替换100时,我遇到了DataLossError

DataLossError:尝试填充到比输入元素更小的大小。

有没有办法在序列上实现与tf.image.resize_image_with_crop_or_pad类似的功能?

tensorflow sequence tensorflow-datasets
1个回答
0
投票

填充或截断没有简单的方法,但您可以使用map函数来获取包含所需长度元素的数据集。这是一个简单的例子:

k = 4
def pad_or_trunc(t):
    dim = tf.size(t)
    return tf.cond(tf.equal(dim, k), lambda: t, lambda: tf.cond(tf.greater(dim, k), lambda: tf.slice(t, [0], [k]), lambda: tf.concat([t, tf.zeros(k-dim, dtype=tf.int32)], 0)))

vals = tf.constant([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
dset1 = tf.data.Dataset.from_tensor_slices(vals)
dset2 = dset1.map(pad_or_trunc)
iter = dset2.make_one_shot_iterator()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(iter.get_next()))
        except tf.errors.OutOfRangeError:
            break
© www.soinside.com 2019 - 2024. All rights reserved.