我在本地有一个非常庞大的图像数据库,像每个文件夹一样的数据分布包含了一个类的图像。
我想使用tensorflow数据集API来获取批量数据而不必将所有图像加载到内存中。
我尝试过这样的事情:
def _parse_function(filename, label):
image_string = tf.read_file(filename, "file_reader")
image_decoded = tf.image.decode_jpeg(image_string, channels=3)
image = tf.cast(image_decoded, tf.float32)
return image, label
image_list, label_list, label_map_dict = read_data()
dataset = tf.data.Dataset.from_tensor_slices((tf.constant(image_list), tf.constant(label_list)))
dataset = dataset.shuffle(len(image_list))
dataset = dataset.repeat(epochs).batch(batch_size)
dataset = dataset.map(_parse_function)
iterator = dataset.make_one_shot_iterator()
image_list是一个列表,其中附加了图像的路径(和名称),label_list是一个列表,其中每个图像的类以相同的顺序附加。
但_parse_function不起作用,我所接受的错误是:
ValueError:Shape必须为0级,但对于'file_reader'(op:'ReadFile'),其输入形状为[?]。
我用Google搜索了错误,但对我来说没有任何作用。
如果我不使用map函数,我只是重新思考图像的路径(存储在image_list中),所以我认为我需要map函数来读取图像,但我无法使其工作。
先感谢您。
编辑:
def read_data():
image_list = []
label_list = []
label_map_dict = {}
count_label = 0
for class_name in os.listdir(base_path):
class_path = os.path.join(base_path, class_name)
label_map_dict[class_name]=count_label
for image_name in os.listdir(class_path):
image_path = os.path.join(class_path, image_name)
label_list.append(count_label)
image_list.append(image_path)
count_label += 1
错误在此行中dataset = dataset.repeat(epochs).batch(batch_size)
您的管道将batchsize添加为输入维。
您需要在映射函数之后批量处理数据集
dataset = tf.data.Dataset.from_tensor_slices((tf.constant(image_list), tf.constant(label_list)))
dataset = dataset.shuffle(len(image_list))
dataset = dataset.repeat(epochs)
dataset = dataset.map(_parse_function).batch(batch_size)