如何将组件名称添加/更改为现有Tensorflow数据集对象?

问题描述 投票:1回答:1

从Tensorflow数据集指南中可以看出

为元素的每个组件指定名称通常很方便,例如,如果它们代表训练示例的不同特征。除了元组之外,您还可以使用collections.namedtuple或将字符串映射到张量的字典来表示数据集的单个元素。

dataset = tf.data.Dataset.from_tensor_slices(
   {"a": tf.random_uniform([4]),
    "b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types)  # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes)  # ==> "{'a': (), 'b': (100,)}"

https://www.tensorflow.org/guide/datasets

这在Keras非常有用。如果将数据集对象传递给model.fit,则可以使用组件的名称来匹配Keras模型的输入。例:

image_input = keras.Input(shape=(32, 32, 3), name='img_input')
timeseries_input = keras.Input(shape=(None, 10), name='ts_input')

x1 = layers.Conv2D(3, 3)(image_input)
x1 = layers.GlobalMaxPooling2D()(x1)

x2 = layers.Conv1D(3, 3)(timeseries_input)
x2 = layers.GlobalMaxPooling1D()(x2)

x = layers.concatenate([x1, x2])

score_output = layers.Dense(1, name='score_output')(x)
class_output = layers.Dense(5, activation='softmax', name='class_output')(x)

model = keras.Model(inputs=[image_input, timeseries_input],
                    outputs=[score_output, class_output])

train_dataset = tf.data.Dataset.from_tensor_slices(
    ({'img_input': img_data, 'ts_input': ts_data},
     {'score_output': score_targets, 'class_output': class_targets}))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

model.fit(train_dataset, epochs=3)

因此,查找,添加和更改tf数据集对象中组件的名称会很有用。做这些任务的最佳方法是什么?

python tensorflow tensorflow-datasets
1个回答
3
投票

如果您正在寻找,可以使用map对数据集进行修改。例如,要将普通的tuple输出转换为具有有意义名称的dict

import tensorflow as tf

# dummy example
ds_ori = tf.data.Dataset.zip((tf.data.Dataset.range(0, 10), tf.data.Dataset.range(10, 20)))
ds_renamed = ds_ori.map(lambda x, y: {'input': x, 'output': y})

batch_ori = ds_ori.make_one_shot_iterator().get_next()
batch_renamed = ds_renamed.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  print(sess.run(batch_ori))
  print(sess.run(batch_renamed))
  # (0, 10)
  # {'input': 0, 'output': 10}
© www.soinside.com 2019 - 2024. All rights reserved.