从tf.io.parse_single_sequence_example获取字符串

问题描述 投票:0回答:1

我有一些tfrecord文件,其中包含图像和边界框。每张图像的边界框数量是可变的。我按如下方式创建SequenceExample:

def image_example(image_string, obj_vectors):
    """Build a ``tf.train.SequenceExample`` for one image and its boxes.

    The image is stored in the example's context under the key ``'image'``;
    each entry of ``obj_vectors`` becomes one float-list step of the
    ``'Box Vectors'`` feature list.

    Args:
        image_string: Image payload handed to ``_bytes_feature``
            (presumably the encoded image bytes -- confirm against caller).
        obj_vectors: Iterable of float sequences, one per bounding box.

    Returns:
        The assembled ``tf.train.SequenceExample``.
    """
    # One float-list Feature per box vector, so the number of steps in the
    # feature list equals the number of entries in obj_vectors.
    per_box = [
        tf.train.Feature(float_list=tf.train.FloatList(value=vector))
        for vector in obj_vectors
    ]
    feature_lists = tf.train.FeatureLists(
        feature_list={'Box Vectors': tf.train.FeatureList(feature=per_box)}
    )
    context = tf.train.Features(
        feature={'image': _bytes_feature(image_string)}
    )
    return tf.train.SequenceExample(
        context=context,
        feature_lists=feature_lists,
    )

然后我用下面的函数读取它们:

def _parse_image_function(example):
    """Parse one serialized SequenceExample and print what was parsed.

    The ``'image'`` context feature is read as a scalar string and the
    ``'Box Vectors'`` sequence feature as a variable-length float feature.
    The JPEG-decoded image is only printed for inspection; it is not part
    of the return value.

    Args:
        example: Serialized ``SequenceExample`` to parse.

    Returns:
        ``(context_data, sequence_data)`` exactly as produced by
        ``tf.io.parse_single_sequence_example``.
    """
    context_spec = {'image': tf.io.FixedLenFeature([], dtype=tf.string)}
    sequence_spec = {'Box Vectors': tf.io.VarLenFeature(dtype=tf.float32)}
    context_data, sequence_data = tf.io.parse_single_sequence_example(
        serialized=example,
        context_features=context_spec,
        sequence_features=sequence_spec,
    )

    # Debug output: the parsed context dict, then the raw image entry.
    print(context_data)
    encoded_image = context_data['image']
    print(encoded_image)

    # Debug output: the decoded image and the parsed box vectors.
    decoded_image = tf.image.decode_jpeg(context_data['image'])
    print(decoded_image)
    print(sequence_data['Box Vectors'])

    return context_data, sequence_data

当我打印context_data时,它打印{'image': <tf.Tensor 'ParseSingleSequenceExample/ParseSingleSequenceExample:0' shape=() dtype=string>},而当我打印context_data['image']时,打印Tensor("ParseSingleSequenceExample/ParseSingleSequenceExample:0", shape=(), dtype=string)。我原本希望context_data['image']能给出原始字符串,但实际上并没有。

我使用_parse_image_function作为dataset.map的输入,如

# Read the serialized records from FILENAME and parse each one.
dataset = tf.data.TFRecordDataset(FILENAME).map(_parse_image_function)

然后我可以通过下面的操作获取原始图像字符串:

for x, y in dataset:
    # Reshape the image entry to a scalar, then decode the JPEG bytes.
    image = tf.image.decode_jpeg(tf.reshape(x['image'], []))
    # 'Box Vectors' was parsed as a VarLenFeature, so densify it here.
    vecs = tf.sparse.to_dense(y['Box Vectors'])

但是我想在映射函数_parse_image_function中就把数据转换为张量。这样做是为了能够对数据进行批处理,我打算在映射后的数据集上使用dataset.padded_batch。我的做法是不是不对?

我有tfrecord文件,其中包含图像和边框。每个图像都具有可变数量的边界框。我正在创建我的SequenceExamples,例如def image_example(...

python tensorflow tensorflow-datasets
1个回答
0
投票

我用下面的代码解决了这个问题:

    def _parse_image_function(example):
        """Parse one serialized SequenceExample into (image, box vectors).

        Args:
            example: Serialized ``SequenceExample`` to parse.

        Returns:
            A pair of the ``'image'`` context feature (parsed as a scalar
            string) and the ``'Box Vectors'`` sequence feature (parsed as a
            variable-length float feature).
        """
        context_spec = {'image': tf.io.FixedLenFeature([], dtype=tf.string)}
        sequence_spec = {'Box Vectors': tf.io.VarLenFeature(dtype=tf.float32)}
        context, sequences = tf.io.parse_single_sequence_example(
            serialized=example,
            context_features=context_spec,
            sequence_features=sequence_spec,
        )
        return context['image'], sequences['Box Vectors']

    def format_data(image, labels):
        """Convert a parsed (image, labels) pair into dense tensors.

        Args:
            image: Encoded JPEG data accepted by ``tf.image.decode_jpeg``.
            labels: Sparse tensor accepted by ``tf.sparse.to_dense``.

        Returns:
            ``(image, vecs)``: the decoded image converted to float32 and
            the densified labels.
        """
        dense_labels = tf.sparse.to_dense(labels)
        decoded = tf.image.decode_jpeg(image)
        # Per the original author's note, the float32 conversion should
        # also normalize the pixel values.
        as_float = tf.image.convert_image_dtype(decoded, dtype=tf.float32)
        return as_float, dense_labels

    def train():
        """Run the training loop over padded batches of (image, boxes).

        Reads the ``dataset`` defined outside this function (expected to
        yield serialized SequenceExample records -- confirm against the
        caller), parses and formats each record, pads every batch to a
        common shape, and applies one optimizer step per batch for
        ``params.epochs`` epochs.
        """
        # Bind the pipeline to its own local name. Assigning to ``dataset``
        # inside this function would make that name function-local, and the
        # first read of it would raise UnboundLocalError.
        batches = (
            dataset
            .map(_parse_image_function)
            .map(format_data)
            .padded_batch(
                params.batch_size,
                # Images pad over height/width with 3 channels; box tensors
                # pad over both of their dimensions.
                padded_shapes=([None, None, 3], [None, None]),
            )
        )
        for epoch in range(params.epochs):
            for x, y in batches:
                loss_value, grads = grad(model, x, y)
                optimizer.apply_gradients(zip(grads, model.trainable_variables))
© www.soinside.com 2019 - 2024. All rights reserved.