如何在Tensorflow 2.0中创建填充批处理?

问题描述 投票:0回答:1

我正在尝试从tensorflow数据集创建填充批处理,这使我出错。

样本数据:

A = [
 ([[101,9385,13302,102], [1,1,1,1], [0,0,0,0]]),
 ([[101,9385,13302], [1,1,1], [0,0,0]]),

 ([[101,9385,13302,102], [1,1,1,1], [0,0,0,0]]),
 ([[101,9385,13302], [1,1,1], [0,0,0]])     
 ]

预期输出:

Output = [
 ([[101,9385,13302,102], [1,1,1,1], [0,0,0,0]]),
 ([[101,9385,13302,-100], [1,1,1,-100], [0,0,0,-100]]),

 ([[101,9385,13302,102], [1,1,1,1], [0,0,0,0]]),
 ([[101,9385,13302,-100], [1,1,1,-100], [0,0,0,-100]])     
 ]

下面是代码:

chk = tf.data.Dataset.from_generator(lambda: A,
                                         output_types = (tf.int32))
BATCH_SIZE = 2
chk_batched = chk.padded_batch(BATCH_SIZE ,
                                   padded_shapes=(4),
                                   padding_values=(0))

for elem in chk_batched.as_numpy_iterator():
   print(elem)

错误在下面给出:

InvalidArgumentError: All elements in a batch must have the same rank as the padded shape for component0: expected rank 1 but got element with rank 2
python tensorflow-datasets
1个回答
0
投票

因此它是二维的,papped_shapes =(4,3),无需编写padding_values =(0)。默认情况下,填充的值为0

© www.soinside.com 2019 - 2024. All rights reserved.