使用Tuporflow数据集进行分组,其中每个tupled元素具有不同的形状

问题描述 投票:1回答:1

我正在尝试修改现有的tensorflow代码。首先,将一个单词矩阵从datasetgeneartor函数转换为map_strings_to_ints并转换为词汇索引。然后调用以下函数。

dataset = dataset.apply(tf.contrib.data.bucket_by_sequence_length(element_length_func=lambda d: tf.shape(d)[0],
                                                                     bucket_boundaries=bucket_boundaries,
                                                                     bucket_batch_sizes=bucket_batch_sizes,
                                                                     padded_shapes=dataset.output_shapes,
                                                                     padding_values=constants.PAD_VALUE))

其中每个dataset元素是一个大小的数组[无,无](即2d mat)。

现在,对于每个元素,我想添加另一个文本序列。因此每个元素都是前一个2d mat的元组,每个新数据集元素的相应句子/序列是([None,None],[None])的元组,那么如何修改上述函数呢?

我试过了

dataset = dataset.apply(tf.contrib.data.bucket_by_sequence_length(element_length_func=lambda d,t: tf.shape(d)[0],
                                                                     bucket_boundaries=bucket_boundaries,
                                                                     bucket_batch_sizes=bucket_batch_sizes,
                                                                     padded_shapes=dataset.output_shapes,
                                                                     padding_values=constants.PAD_VALUE))

而且很少有其他技巧

TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class ‘int’>

注意,dataset元素是映射到词汇索引的单词(即int)

tensorflow deep-learning batch-processing tensorflow-datasets bucket
1个回答
0
投票

这应该对您有所帮助:

X = np.array([[[1,2,3],[4,5,6]],[[7,8,9], [1,2,3], [4,5,6], [7,8,9]], [[1,2,3], [4,5,6]]])
Y = np.array([0,1,0])

def elements_gen():
    for x,y in zip(X,Y):
        yield (x,y)

dataset = tf.data.Dataset.from_generator(generator=elements_gen, output_shapes=([None, None], []), output_types=(tf.int32, tf.int32))

dataset = dataset.apply(tf.contrib.data.bucket_by_sequence_length(element_length_fun =lambda x,y: tf.shape(x)[0], bucket_boundaries=[4,7], bucket_batch_sizes=[2,2,2], padding_values=(0,0)))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

问题正是错误所说的,因为你填充的结构是一个序列,用于填充结构的值也必须是一个序列。

© www.soinside.com 2019 - 2024. All rights reserved.