用于分类补丁的Tensorflow数据集流水线

问题描述 投票:1回答:1

我试图用以下方法写一个数据集流水线。tensorflow 用于标记图像的补丁。现在我从一堆 tfrecord 文件中读取,每个文件有多个补丁,但只有一个标签。这个标签有四个类。

Tensorflow似乎不喜欢我通过管道传递一个单子。我得到以下错误。

ValueError: Value Tensor("args_1:0", shape=(), dtype=int32) has insufficient rank for batching.

我正在想办法让这个用例正常工作。这基本上是我想做的事情。我想请教一下我应该怎么做 y 以便我在流水线的最后得到每个补丁的一个标签。如果我需要改变 tfrecord 文件的结构,使之成为 y 是一个一热编码的向量,我只是不知道这是否有必要。

def parse_func(proto):
    features = tf.io.parse_single_example(
        serialized=proto,
        features={'X': tf.io.FixedLenFeature([], tf.string),
                  'length': tf.io.FixedLenFeature([], tf.int64),
                  'y': tf.io.FixedLenFeature([], tf.int64)})

    y = tf.cast(features['y'], tf.int32)  # this is just an integer, but maybe it should be a one-hot encoded vector

    X = tf.io.decode_raw(features['X'], tf.float32)
    length = tf.cast(features['length'], tf.int32)
    shape = tf.stack([length, 60, 1])
    return tf.reshape(X, shape), y


def get_patches(X, y):
    X = X[tf.newaxis, ...]

    patches = tf.image.extract_patches(X,
                                       sizes=[1,  128, 60, 1],
                                       strides=[1, 4, 1, 1],
                                       rates=[1, 1, 1, 1],
                                       padding='VALID')
    patches = tf.reshape(patches, [-1, 128, 60, 1])
    y = repeat_so_that_there_is_one_label_per_patch(y)
    return patches, y


dataset = (tf.data.Dataset.from_tensor_slices('tf_record_file_paths')
           .shuffle(100)
           .interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=4)
           .map(parse_func)
           .map(get_patches)
           .unbatch()
           .shuffle(100)
           .repeat()
           .batch(64, drop_remainder=True)
           .prefetch(1))
python tensorflow tensorflow-datasets tensorflow2.0
1个回答
0
投票

我解决了这个问题,具体如下。

def repeat_so_that_there_is_one_label_per_patch(y, patches):
    num_patches = tf.shape(patches)[0]
    tiled_y = tf.tile(y, multiples=[num_patches])
    return tf.reshape(tiled_y, tf.shape(y) * num_patches)
© www.soinside.com 2019 - 2024. All rights reserved.