如何在张量流中的Estimator的input_fn中读取和处理我的图像文件?

问题描述 投票:1回答:1

我是TF新手。由于某些原因,我必须使用TF1.10,在此我发现.ppm不支持tf.image.decode_image

我的网络的最终目标是读取RGBD输入,并使用它生成更多特征(通过手工制作方法(例如采样法线),并最终利用这些附加特征和地面实况来计算损耗。

由于我的数据集很大,因此我使用tf.data.TextLineDataset获取了input_fn中文件路径列表的数据集,并尝试使用Dataset.map生成要素。当我尝试解码.ppm文件时,我发现了此问题。 (代码如下所示。)

或,还有其他方法可以避免在读取strings中的图像之前将路径Tensor转换为input_fn,然后可以使用cv2.imread吗?但是,如果我这样做,我想我必须使用所有Tensors来构建我的数据集和迭代器,这可能会占用大量内存。 (也许我错了。)

或者,如果您认为我完全误解了datasetEstimator的用法,请告诉我正确的方法。谢谢。

def input_fn(self, dataset, mode="train"):
    self.dict_dataset_lists = {}
    ds_rgb = os.path.expandvars(dataset["rgb"])
    ds_d = os.path.expandvars(dataset["d"])
    ds_gt = os.path.expandvars(dataset["gt"])

    self.dict_dataset_lists["rgb"] = tf.data.TextLineDataset(ds_rgb)
    self.dict_dataset_lists["d"] = tf.data.TextLineDataset(ds_d)
    self.dict_dataset_lists["gt"] = tf.data.TextLineDataset(ds_gt)

    with tf.name_scope("Dataset_API"):
        tf_dataset = tf.data.Dataset.zip(self.dict_dataset_lists)

        # load path to imgs(tensor)
        if mode == "train":
            tf_dataset = tf_dataset.repeat(self.parameters.max_epochs)
            if self.parameters.shuffle:
                tf_dataset = tf_dataset.shuffle(
                        buffer_size=self.parameters.steps_per_epoch * self.parameters.batch_size)
            tf_dataset = tf_dataset.map(load_img, num_parallel_calls=1) 
            tf_dataset = tf_dataset.batch(self.parameters.batch_size)
            tf_dataset = tf_dataset.prefetch(buffer_size=self.parameters.prefetch_buffer_size)

    # make iterator
    iterator = tf_dataset.make_one_shot_iterator()

    dict_tf_input = iterator.get_next()
python tensorflow deep-learning tensorflow-datasets tensorflow-estimator
1个回答
0
投票

这里是如何在tf.data.Dataset.map函数中从张量获取字符串部分的示例。

下面是我在代码中实现的步骤。

  1. 您必须用tf.py_function(get_path, [x], [tf.string])装饰地图功能。您可以找到有关tf.py_function here的更多信息。
  2. 您可以通过使用地图功能中的bytes.decode(file_path.numpy())来获得琴弦部分。

代码-

%tensorflow_version 2.x
import tensorflow as tf
import numpy as np

def get_path(file_path):
    print("file_path: ",bytes.decode(file_path.numpy()),type(bytes.decode(file_path.numpy())))
    return file_path

train_dataset = tf.data.Dataset.list_files('/content/bird.jpg')
train_dataset = train_dataset.map(lambda x: tf.py_function(get_path, [x], [tf.string]))

for one_element in train_dataset:
    print(one_element)

输出-

file_path:  /content/bird.jpg <class 'str'>
(<tf.Tensor: shape=(), dtype=string, numpy=b'/content/bird.jpg'>,)

希望这能回答您的问题。

© www.soinside.com 2019 - 2024. All rights reserved.