我试图用以下方法写一个数据集流水线。tensorflow
用于标记图像的补丁。现在我从一堆 tfrecord 文件中读取,每个文件有多个补丁,但只有一个标签。这个标签有四个类。
Tensorflow似乎不喜欢我通过管道传递一个单子。我得到以下错误。
ValueError: Value Tensor("args_1:0", shape=(), dtype=int32) has insufficient rank for batching.
我正在想办法让这个用例正常工作。这基本上是我想做的事情。我想请教一下我应该怎么做 y
以便我在流水线的最后得到每个补丁的一个标签。如果我需要改变 tfrecord 文件的结构,使之成为 y
是一个一热编码的向量,我只是不知道这是否有必要。
def parse_func(proto):
features = tf.io.parse_single_example(
serialized=proto,
features={'X': tf.io.FixedLenFeature([], tf.string),
'length': tf.io.FixedLenFeature([], tf.int64),
'y': tf.io.FixedLenFeature([], tf.int64)})
y = tf.cast(features['y'], tf.int32) # this is just an integer, but maybe it should be a one-hot encoded vector
X = tf.io.decode_raw(features['X'], tf.float32)
length = tf.cast(features['length'], tf.int32)
shape = tf.stack([length, 60, 1])
return tf.reshape(X, shape), y
def get_patches(X, y):
X = X[tf.newaxis, ...]
patches = tf.image.extract_patches(X,
sizes=[1, 128, 60, 1],
strides=[1, 4, 1, 1],
rates=[1, 1, 1, 1],
padding='VALID')
patches = tf.reshape(patches, [-1, 128, 60, 1])
y = repeat_so_that_there_is_one_label_per_patch(y)
return patches, y
dataset = (tf.data.Dataset.from_tensor_slices('tf_record_file_paths')
.shuffle(100)
.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=4)
.map(parse_func)
.map(get_patches)
.unbatch()
.shuffle(100)
.repeat()
.batch(64, drop_remainder=True)
.prefetch(1))
我解决了这个问题,具体如下。
def repeat_so_that_there_is_one_label_per_patch(y, patches):
num_patches = tf.shape(patches)[0]
tiled_y = tf.tile(y, multiples=[num_patches])
return tf.reshape(tiled_y, tf.shape(y) * num_patches)