Unexpected dimensions when using a TensorFlow dataset

Question (votes: 0, answers: 1)

I'm trying to do transfer learning with InceptionV3 on the MNIST dataset.

The plan is to load MNIST, convert and resize the images, and then train on them, as follows:

import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image

import tensorflow.compat.v2 as tf
import tensorflow.compat.v1 as tfv1
from tensorflow.keras.applications import InceptionV3  # public path, not the private tensorflow.python.keras

tfv1.enable_v2_behavior()

print(tf.version.VERSION)

img_size = 299

def preprocess_tf_image(image, label):
  image = tf.image.grayscale_to_rgb(image)
  image = tf.image.resize(image, [img_size, img_size])
  return image, label

#Acquire MNIST data
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
#Convert data to [0,1] range
x_train, x_test = x_train / 255.0, x_test / 255.0

#Add extra dimension to images so that they can be converted to RGB
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test  = x_test.reshape (x_test.shape[0],  28, 28, 1)

x_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
x_test  = tf.data.Dataset.from_tensor_slices((x_test, y_test))

#Convert images to RGB space and resize
x_train = x_train.map(preprocess_tf_image)
x_test  = x_test.map(preprocess_tf_image)

img_shape = (img_size, img_size, 3)

#Get trained model, but leave off the head
base_model = InceptionV3(input_shape = img_shape, weights='imagenet', include_top=False)
base_model.trainable = False

#Make a model with a new head
model = tf.keras.Sequential([
  base_model,
  tf.keras.layers.GlobalAveragePooling2D(),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

#Compile model
model.compile(
    optimizer='adam', #tf.keras.optimizers.RMSprop(lr=BASE_LEARNING_RATE),
    loss='sparse_categorical_crossentropy',  # integer labels 0-9 with a 10-way softmax head
    metrics=['accuracy']
)

model.fit(x_train, epochs=5)

model.evaluate(x_test)

However, when I run this, it stops at model.fit() with the error:

ValueError: Error when checking input: expected inception_v3_input to have 4 dimensions, but got array with shape (299, 299, 3)
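The shape in the message is the shape of a single dataset element. A minimal sketch, using a few random 28x28 arrays as a stand-in for MNIST (an assumption so it runs without downloading data or weights), inspects the element spec of the same pipeline:

```python
import numpy as np
import tensorflow as tf

img_size = 299

def preprocess_tf_image(image, label):
    # Same preprocessing as in the question: 1 channel -> 3 channels, then resize.
    image = tf.image.grayscale_to_rgb(image)
    image = tf.image.resize(image, [img_size, img_size])
    return image, label

# Stand-in for MNIST (assumption): 8 random 28x28 single-channel images, integer labels.
x = np.random.rand(8, 28, 28, 1).astype("float32")
y = np.random.randint(0, 10, size=8)

ds = tf.data.Dataset.from_tensor_slices((x, y)).map(preprocess_tf_image)

# from_tensor_slices splits along the first axis, so each element is ONE image.
image_spec, label_spec = ds.element_spec
print(tuple(image_spec.shape))  # (299, 299, 3): no batch dimension
```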

What is going on?

python tensorflow machine-learning tensorflow-datasets mnist
1 answer (votes: 0)

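InceptionV3 expects 4-D input of shape [batch, height, width, channels], but the dataset is never batched, so each element it yields is a single image of shape (299, 299, 3). Your preprocessing function therefore has to add the batch dimension, like this: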

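def preprocess_tf_image(image, label):
    # Add a leading batch dimension: (28, 28, 1) -> (1, 28, 28, 1)
    image = tf.reshape(image, [-1, image.shape[0], image.shape[1], image.shape[2]])
    # Both ops accept 4-D input, so the result is (1, 299, 299, 3)
    image = tf.image.grayscale_to_rgb(image)
    image = tf.image.resize(image, [img_size, img_size])
    return image, label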
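A common alternative (not part of the answer above) is to keep the per-image function unchanged and let tf.data add the batch dimension with Dataset.batch. The sketch below again uses random arrays in place of MNIST and an arbitrary batch size of 4; both are assumptions for illustration:

```python
import numpy as np
import tensorflow as tf

img_size = 299

def preprocess_tf_image(image, label):
    # Per-image preprocessing, unchanged from the question.
    image = tf.image.grayscale_to_rgb(image)
    image = tf.image.resize(image, [img_size, img_size])
    return image, label

# Stand-in for MNIST (assumption): 8 random 28x28 single-channel images, integer labels.
x = np.random.rand(8, 28, 28, 1).astype("float32")
y = np.random.randint(0, 10, size=8)

train_ds = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .map(preprocess_tf_image)
    .batch(4)  # groups consecutive images into [batch, height, width, channels]
)

images, labels = next(iter(train_ds))
print(tuple(images.shape))  # (4, 299, 299, 3)
print(tuple(labels.shape))  # (4,)
```

With a batched dataset, model.fit(train_ds, epochs=5) receives 4-D batches, and the batch size is no longer fixed at 1 as it is when each element is reshaped to (1, 299, 299, 3).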
© www.soinside.com 2019 - 2024. All rights reserved.