如何将numpy数组保存/加载到TF2中的TFRecordDataset中?

问题描述 投票:0回答:1

我正在使用Tensorflow 2.0将numpy数组保存/加载到TFRecordDataset中。似乎完全缺少2.0的文档,而且api也不简单。

我在这里创建了一个可复制的最小示例作为笔记本:https://gist.github.com/vicpara/3b4ea00553a1990620a2df77d8b6aa1f

感谢您的任何建议。

tensorflow tensorflow-datasets tensorflow2.0
1个回答
0
投票

一种方法是遍历数据集并保存特征,示例中提供的标签对:

import numpy as np
import tensorflow as tf
import random
import os

# For Tensorflow <2.0:
#tf.enable_eager_execution()


SAVE_PATH = "."

def _make_tf_example(feature, labels):
    bytes_feature = tf.io.serialize_tensor(feature)
    bytes_labels = tf.io.serialize_tensor(labels)

    feature_mapping = {
            'feature': _bytes_feature(bytes_feature),
            'labels': _bytes_feature(bytes_labels),
        }

    return tf.train.Example(features=tf.train.Features(feature=feature_mapping))

def _bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _save_single_entry(feature, label, path_to_save, idx):
    tfrecord_full_filename = os.path.join(path_to_save, f"my_tfrecord_{idx:05}.tfrecord")
    with tf.io.TFRecordWriter(tfrecord_full_filename) as writer:
        tf_example = _make_tf_example(feature, label)
        writer.write(tf_example.SerializeToString())

然后:(假设orig是您的tf.data.Dataset

for idx, (feature, label) in enumerate(orig):
  _save_single_entry(feature, label, SAVE_PATH, idx)
© www.soinside.com 2019 - 2024. All rights reserved.