了解如何使用tf.dataset.map()

问题描述 投票:1回答:1

我正在将一些原本使用JPEG作为输入的代码转换为使用Matlab MAT文件。 代码中包含了这几行。

train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.mat')
train_dataset = train_dataset.shuffle(BUFFER_SIZE) 
train_dataset = train_dataset.map(load_image_train)

如果我在数据集中循环并在map()之前打印()每个元素,我就会得到一组带有可见文件路径的tensors。

然而,在load_image_train函数中,情况并非如此,print()的输出是如此。

Tensor("add:0", shape=(), dtype=string)

我想用scipy.io.loadmat()函数从我的mat文件中获取数据,但它失败了,因为路径是张量而不是字符串。 dataset.map()做了什么让字符串的字面值不再可见? 我如何提取字符串,以便将其作为 scipy.io.loadmat() 的输入?

如果这是一个愚蠢的问题,请原谅,对Tensorflow比较陌生,还在尝试理解。 我所能找到的很多相关问题的讨论都只适用于TF v1,谢谢你的帮助。

python tensorflow tensorflow-datasets
1个回答
0
投票

在下面的代码中,我使用的是 tf.data.Dataset.list_files 读取图像的文件路径。在 map 函数,我正在加载图像并进行 crop_central(基本上是按给定的百分比裁剪图像的中心部分,这里我是按以下方式指定百分比的。np.random.uniform(0.50, 1.00)).

正如你所提到的,由于文件路径为 tf.string 类型和 load_img 或其他任何读取图像文件的功能都需要简单的 string 型。

所以,你可以这样做 -

  1. 你需要用 tf.py_function(load_file_and_process, [x], [tf.float32]). 你可以找到更多关于它的信息 此处.
  2. 你可以检索到 string 来自 tf.string 使用 bytes.decode(path.numpy().

下面是完整的代码,供你参考。当你运行这段代码时,你可以用你的图像路径替换它。

%tensorflow_version 2.x
import tensorflow as tf
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array, array_to_img
from matplotlib import pyplot as plt
import numpy as np

def load_file_and_process(path):
    image = load_img(bytes.decode(path.numpy()), target_size=(224, 224))
    image = img_to_array(image)
    image = tf.image.central_crop(image, np.random.uniform(0.50, 1.00))
    return image

train_dataset = tf.data.Dataset.list_files('/content/bird.jpg')
train_dataset = train_dataset.map(lambda x: tf.py_function(load_file_and_process, [x], [tf.float32]))

for f in train_dataset:
  for l in f:
    image = np.array(array_to_img(l))
    plt.imshow(image)

输出 -

enter image description here

希望能回答你的问题。祝您学习愉快。

© www.soinside.com 2019 - 2024. All rights reserved.