我正在尝试使用Dataset API来准备TFRecordDataset
的文本序列。处理完毕后,我为每条记录都有一张张量词典。每条记录包含两个序列。
我正在使用padded_batch
来应用填充
dataset = dataset.padded_batch(batch_size, padded_shapes= {
'seq1': tf.TensorShape([None]),
'seq2': tf.TensorShape([None])
})
这将每个序列填充到批次中的最大序列长度。但是,我想选择一个任意的序列长度,并在真正的序列长度较小时填充此长度,否则截断序列。
当我尝试用None
替换100
时,我遇到了DataLossError
DataLossError:尝试填充到比输入元素更小的大小。
有没有办法在序列上实现与tf.image.resize_image_with_crop_or_pad
类似的功能?
填充或截断没有简单的方法,但您可以使用map
函数来获取包含所需长度元素的数据集。这是一个简单的例子:
k = 4
def pad_or_trunc(t):
dim = tf.size(t)
return tf.cond(tf.equal(dim, k), lambda: t, lambda: tf.cond(tf.greater(dim, k), lambda: tf.slice(t, [0], [k]), lambda: tf.concat([t, tf.zeros(k-dim, dtype=tf.int32)], 0)))
vals = tf.constant([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
dset1 = tf.data.Dataset.from_tensor_slices(vals)
dset2 = dset1.map(pad_or_trunc)
iter = dset2.make_one_shot_iterator()
with tf.Session() as sess:
while True:
try:
print(sess.run(iter.get_next()))
except tf.errors.OutOfRangeError:
break