standardize_input_data()的ValueError

问题描述 投票:1回答:1

我正在尝试在 TensorFlow 2.0 后端上使用 Keras 实现一个简单的 U-Net 网络。

我使用一个不做数据增强的自定义图像生成器。我的模板(即输入图像)和蒙版都是 1536x1536 的 RGB 图像(蒙版是黑白的)。

def data_gen(templates_folder, masks_folder, image_width, batch_size):
    """Yield ``(templates, masks)`` batches forever, reshuffling after each pass.

    Every file name listed in *templates_folder* must have a same-named mask
    file in *masks_folder*. Templates are read with ``cv2.imread`` in its
    default (colour, 3-channel, BGR-ordered) mode; masks are read in
    grayscale. Both are scaled from ``[0, 255]`` to ``[0, 1]``.

    Only full batches are produced. When fewer than *batch_size* unused files
    remain in a pass, the remainder is skipped, the file list is reshuffled,
    and a new pass starts.

    Args:
        templates_folder: Directory containing the input images.
        masks_folder: Directory containing the masks, named like the templates.
        image_width: Height and width of every (square) image, in pixels.
        batch_size: Number of image/mask pairs per yielded batch.

    Yields:
        ``(templates_pack, masks_pack)``: float arrays of shape
        ``(batch_size, image_width, image_width, 3)`` and
        ``(batch_size, image_width, image_width, 1)``.

    Raises:
        ValueError: When first advanced, if *batch_size* is less than 1 or
            *templates_folder* holds fewer than *batch_size* files.
        OSError: If an image or mask file cannot be decoded by OpenCV.
    """

    def load_scaled(path, *flags):
        """Read *path* with ``cv2.imread`` and scale pixel values to [0, 1]."""
        image = cv2.imread(path, *flags)
        # cv2.imread does not raise on a missing/unreadable file; it returns None.
        if image is None:
            raise OSError('could not read image file: {}'.format(path))
        return image / 255.

    images_list = os.listdir(templates_folder)
    if batch_size < 1 or len(images_list) < batch_size:
        raise ValueError(
            'need at least batch_size={} files in {!r}, found {}'.format(
                batch_size, templates_folder, len(images_list)))

    counter = 0
    random.shuffle(images_list)
    while True:
        templates_pack = np.zeros((batch_size, image_width, image_width, 3), dtype='float')
        masks_pack = np.zeros((batch_size, image_width, image_width, 1), dtype='float')
        for slot, name in enumerate(images_list[counter:counter + batch_size]):
            templates_pack[slot] = load_scaled(os.path.join(templates_folder, name))
            mask = load_scaled(os.path.join(masks_folder, name), cv2.IMREAD_GRAYSCALE)
            # Grayscale masks are 2-D; add the trailing channel axis of masks_pack.
            masks_pack[slot] = mask.reshape(image_width, image_width, 1)

        counter += batch_size
        # Restart only when another full batch no longer fits. Using ">" rather
        # than ">=" keeps the last full batch of each pass.
        if counter + batch_size > len(images_list):
            counter = 0
            random.shuffle(images_list)
        yield templates_pack, masks_pack


callbacks = [
    EarlyStopping(patience=10, verbose=1),
    ReduceLROnPlateau(factor=0.1, patience=3, min_lr=0.00001, verbose=1),
    # save_best_only: overwrite the weights file only when the monitored
    # quantity improves.
    ModelCheckpoint("model-prototype.h5", verbose=1, save_best_only=True,
                    save_weights_only=True),
]
train_templates_path = "E:/train/templates"
train_masks_path = "E:/train/masks"
valid_templates_path = "E:/valid/templates"
valid_masks_path = "E:/valid/masks"
TRAIN_SET_SIZE = len(os.listdir(train_templates_path))
VALID_SET_SIZE = len(os.listdir(valid_templates_path))
BATCH_SIZE = 1
EPOCHS = 100
# Keras expects integer step counts. data_gen yields only full batches, so
# floor division gives the number of full batches in each set.
STEPS_PER_EPOCH = TRAIN_SET_SIZE // BATCH_SIZE
VALIDATION_STEPS = VALID_SET_SIZE // BATCH_SIZE
IMAGE_WIDTH = 1536

train_generator = data_gen(train_templates_path, train_masks_path, IMAGE_WIDTH, batch_size=BATCH_SIZE)
val_generator = data_gen(valid_templates_path, valid_masks_path, IMAGE_WIDTH, batch_size=BATCH_SIZE)

# NOTE: data_gen feeds templates of shape (IMAGE_WIDTH, IMAGE_WIDTH, 3), so
# input_img (defined elsewhere) must be created with 3 channels, e.g.
# Input((IMAGE_WIDTH, IMAGE_WIDTH, 3), name='img'). A 1-channel input layer
# produces "expected img to have shape (1536, 1536, 1) but got array with
# shape (1536, 1536, 3)".
model = get_unet(input_img, n_filters=16, dropout=0.05, batchnorm=True)

results = model.fit_generator(
    train_generator,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=val_generator,
    validation_steps=VALIDATION_STEPS,
    callbacks=callbacks,
)

由于某种原因,我收到以下错误:

Epoch 1/100
Traceback (most recent call last):
  File "E:/Explorium/python/unet_trainer.py", line 83, in <module>
    results = model.fit_generator(train_generator, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, validation_data=val_generator, validation_steps=VALIDATION_STEPS, callbacks=callbacks)
  File "C:\Users\E-soft\Anaconda3\envs\Explorium\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 1297, in fit_generator
    steps_name='steps_per_epoch')
  File "C:\Users\E-soft\Anaconda3\envs\Explorium\lib\site-packages\tensorflow_core\python\keras\engine\training_generator.py", line 265, in model_iteration
    batch_outs = batch_function(*batch_data)
  File "C:\Users\E-soft\Anaconda3\envs\Explorium\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 973, in train_on_batch
    class_weight=class_weight, reset_metrics=reset_metrics)
  File "C:\Users\E-soft\Anaconda3\envs\Explorium\lib\site-packages\tensorflow_core\python\keras\engine\training_v2_utils.py", line 253, in train_on_batch
    extract_tensors_from_dataset=True)
  File "C:\Users\E-soft\Anaconda3\envs\Explorium\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 2472, in _standardize_user_data
    exception_prefix='input')
  File "C:\Users\E-soft\Anaconda3\envs\Explorium\lib\site-packages\tensorflow_core\python\keras\engine\training_utils.py", line 574, in standardize_input_data
    str(data_shape))
ValueError: Error when checking input: expected img to have shape (1536, 1536, 1) but got array with shape (1536, 1536, 3)

似乎Keras无法使用standardize_input_data()标准化数据,但我不知道为什么会这样。

python machine-learning keras image-segmentation
1个回答
0
投票

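错误是因为您正在使用 `cv2.IMREAD_GRAYSCALE` 读取图像。如下所示替换该行:`mask = cv2.imread(masks_folder + '/' + images_list[i], cv2.IMREAD_COLOR) / 255.`

(注:报错中的 "Error when checking input" 检查的是送入模型输入层 `img` 的模板数组,而不是蒙版。模板由 `cv2.imread` 以默认的彩色模式读取,形状为 `(1536, 1536, 3)`,而输入层被定义为 1 通道 `(1536, 1536, 1)`。因此按上面的方式修改蒙版的读取并不能消除该错误;蒙版变为 3 通道后,`mask.reshape(image_width, image_width, 1)` 反而会失败。应将输入层定义为 3 通道,例如 `input_img = Input((1536, 1536, 3), name='img')`;或者改用 `cv2.IMREAD_GRAYSCALE` 读取模板,使其与 1 通道输入层一致。)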

© www.soinside.com 2019 - 2024. All rights reserved.