Tensorflow MapDataset迭代器失败

问题描述 投票:0回答:1

我正在尝试实现此处的tensorflow文档(https://www.tensorflow.org/tutorials/load_data/images)建议的方法,以从本地目录加载图像作为tensorflow数据集。特别是我对使用tf.data作为tf.data.Dataset对象进行加载感兴趣,因为这样可以提高性能。我几乎从文档页面获取了确切的代码,并确保tensorflow版本与文档中的版本匹配

当我尝试使用take()遍历MapDataset对象时发生问题。

import os
import sys
import pathlib

import IPython.display as display
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
STEPS_PER_EPOCH = np.ceil(3670/BATCH_SIZE)
CLASS_NAMES = None

#https://www.tensorflow.org/tutorials/load_data/images

def get_label(file_path):
    # convert the path to a list of path components
    #parts = tf.strings.split(file_path, result_type = 'RaggedTensor')
    parts = tf.strings.split(file_path)

    # The second to last is the class-directory
    return parts[-2] == CLASS_NAMES

def decode_img(img):
    # convert the compressed string to a 3D uint8 tensor
    img = tf.image.decode_jpeg(img, channels=3)

    # Use `convert_image_dtype` to convert to floats in the [0,1] range.
    img = tf.image.convert_image_dtype(img, tf.float32)

    # resize the image to the desired size.
    return tf.image.resize(img, [IMG_WIDTH, IMG_HEIGHT])

def process_path(file_path):
    label = get_label(file_path)

    # load the raw data from the file as a string
    img = tf.io.read_file(file_path)
    img = decode_img(img)

    return img, label

def test():

    data_dir = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
                                         fname='flower_photos', untar=True)

    data_dir = pathlib.Path(data_dir)

    global CLASS_NAMES
    CLASS_NAMES = np.array([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"])

    list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))

    labeled_ds = list_ds.map(process_path)
    print('type: ', type(labeled_ds))

    for image, label in labeled_ds.take(1):
        print("Image shape: ", image.numpy().shape)
        print("Label: ", label.numpy())

def main():
    test()  

if __name__ == '__main__':
    main()

我收到以下错误,不知道如何解决此问题

2020-04-17 09:47:53.816123: W tensorflow/core/framework/op_kernel.cc:1655] OP_REQUIRES failed at strided_slice_op.cc:108 : Invalid argument: slice index -1 of dimension 0 out of bounds.
2020-04-17 09:47:53.820082: W tensorflow/core/framework/op_kernel.cc:1655] OP_REQUIRES failed at iterator_ops.cc:941 : Invalid argument: slice index -1 of dimension 0 out of bounds.
         [[{{node strided_slice}}]]
Traceback (most recent call last):
  File "C:\Users\VVJ3281\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow_core\python\eager\context.py", line 1897, in execution_mode
    yield
  File "C:\Users\VVJ3281\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow_core\python\data\ops\iterator_ops.py", line 659, in _next_internal
    output_shapes=self._flat_output_shapes)
  File "C:\Users\VVJ3281\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow_core\python\ops\gen_dataset_ops.py", line 2478, in iterator_get_next_sync
    _ops.raise_from_not_ok_status(e, name)
  File "C:\Users\VVJ3281\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow_core\python\framework\ops.py", line 6606, in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: slice index -1 of dimension 0 out of bounds.
         [[{{node strided_slice}}]] [Op:IteratorGetNextSync]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File ".\img_sub_model.py", line 150, in <module>
    main()
  File ".\img_sub_model.py", line 145, in main
    test()
  File ".\img_sub_model.py", line 136, in test
    for image, label in labeled_ds.take(1):
  File "C:\Users\VVJ3281\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow_core\python\data\ops\iterator_ops.py", line 630, in __next__
    return self.next()
  File "C:\Users\VVJ3281\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow_core\python\data\ops\iterator_ops.py", line 674, in next
    return self._next_internal()
  File "C:\Users\VVJ3281\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow_core\python\data\ops\iterator_ops.py", line 665, in _next_internal
    return structure.from_compatible_tensor_list(self._element_spec, ret)
  File "C:\Users\VVJ3281\AppData\Local\Continuum\anaconda3\lib\contextlib.py", line 130, in __exit__
    self.gen.throw(type, value, traceback)
  File "C:\Users\VVJ3281\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow_core\python\eager\context.py", line 1900, in execution_mode
    executor_new.wait()
  File "C:\Users\VVJ3281\AppData\Local\Continuum\anaconda3\lib\site-packages\tensorflow_core\python\eager\executor.py", line 67, in wait
    pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle)
tensorflow.python.framework.errors_impl.InvalidArgumentError: slice index -1 of dimension 0 out of bounds.
         [[{{node strided_slice}}]]

通过一些偶然的巧合,我发现当CLASS_NAMES设置为None时,代码将运行,并且label_ds的lebel对象的值为'False'

请参见下面的输出

type:  <class 'tensorflow.python.data.ops.dataset_ops.MapDataset'>
Image shape:  (224, 224, 3)
Label:  False
python tensorflow tensorflow-datasets
1个回答
0
投票
发生错误是因为get_label执行越界列表访问

def get_label(file_path): # convert the path to a list of path components parts = tf.strings.split(file_path) # The second to last is the class-directory return parts[-2] == CLASS_NAMES

parts的大小为1。这是因为除非指定要使用的定界符,否则tf.strings.split将按空格分隔。要拆分为路径组件,应为parts = tf.strings.split(file_path, "/")

要调试此类问题,您可以在函数中添加tf.print语句,例如

def get_label(file_path): # convert the path to a list of path components parts = tf.strings.split(file_path) # The second to last is the class-directory tf.print(file_path) tf.print(len(parts)) return parts[-2] == CLASS_NAMES

© www.soinside.com 2019 - 2024. All rights reserved.