ResNet50 CNN model fails to fit

Problem description (votes: 0, answers: 1)

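Below is my CNN code using ResNet50. I have used it before, but I recently changed the structure of the model input, and now the following error appears. I also printed the shapes of the image array, which you can see below: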

train_inputs.shape = (4727, 224, 224, 3)
and
len(train_labels) = 4727

The error seems to occur in the following line:

history = model.fit(train_inputs_resized, train_labels, epochs=10, batch_size=32, validation_split=0.2, callbacks=[early_stopping, model_checkpoint])

Error message:

Packages Loaded
main started
(4727, 6, 100, 100)
4727
images collected
(4727, 6, 224, 224)
(4727, 6, 224, 224)
(4727, 6, 224, 224, 3)
(4727, 224, 224, 3)
model loaded
model built
(4727, 224, 224, 3)
4727
Epoch 1/10
Traceback (most recent call last):
  File "/oscar/scratch/aagudel1/SenID/Code/CNN_RN503D.py", line 388, in <module>
    main()
  File "/oscar/scratch/aagudel1/SenID/Code/CNN_RN503D.py", line 352, in main
    history = model.fit(train_inputs_resized, train_labels, epochs=10, batch_size=32, validation_split=0.2, callbacks=[early_stopping, model_checkpoint])
  File "/users/aagudel1/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/users/aagudel1/.local/lib/python3.10/site-packages/tensorflow/python/eager/execute.py", line 53, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node LogicalAnd defined at (most recent call last):
  File "/oscar/scratch/aagudel1/SenID/Code/CNN_RN503D.py", line 388, in <module>

  File "/oscar/scratch/aagudel1/SenID/Code/CNN_RN503D.py", line 352, in main

  File "/users/aagudel1/.local/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/users/aagudel1/.local/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 325, in fit

  File "/users/aagudel1/.local/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 118, in one_step_on_iterator

  File "/users/aagudel1/.local/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 106, in one_step_on_data

  File "/users/aagudel1/.local/lib/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 77, in train_step

  File "/users/aagudel1/.local/lib/python3.10/site-packages/keras/src/trainers/trainer.py", line 376, in compute_metrics

  File "/users/aagudel1/.local/lib/python3.10/site-packages/keras/src/traineompile_utils.py", line 330, in update_state

  File "/users/aagudel1/.local/lib/python3.10/site-packages/keras/src/trompile_utils.py", line 17, in update_state

  File "/users/aagudel1/.local/lib/python3.10/site-packages/keras/src/bansorflow/numpy.py", line 1194, in logical_and

Incompatible shapes: [1,64] vs. [1,32]
         [[{{node LogicalAnd}}]] [Op:__inference_one_step_on_iterator_12

Full code of the main function:

# Assumed imports: the script's import block was not included in the question.
import os

import numpy as np
from sklearn.metrics import (accuracy_score, classification_report, f1_score,
                             precision_score, recall_score)
from tensorflow.keras import regularizers
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.metrics import Precision, Recall
from tensorflow.keras.models import Model

# collect_images, ImageCropper, plot_history, create_confusion_matrix and
# visualize_results are helpers defined elsewhere in CNN_RN503D.py (not shown).


def main():
    """
    This script performs the following tasks:
    1. Collects and processes images of cycling and senescent cells.
    2. Splits the data into training and testing sets.
    3. Loads a pre-trained ResNet50 model and adds custom layers.
    4. Trains the model using the processed images.
    5. Evaluates the model's performance on the test set.
    6. Generates a confusion matrix and classification report.
    7. Visualizes the results, including correctly and incorrectly classified examples.
    """
    print("main started")
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(script_dir, "../Data/")
    out1_dir = os.path.join(output_dir, "/users/aagudel1/scratch/SenID/Images/SegmentedImages/Prolif/")
    out2_dir = os.path.join(output_dir, "/users/aagudel1/scratch/SenID/Images/SegmentedImages/Senescent/")
    train_inputs, test_inputs, train_labels, test_labels = collect_images(out1_dir, out2_dir, train_ratio=0.8, random_state=42)
    print(train_inputs.shape)
    print(len(train_labels))
    print("images collected")
    # Assuming images are already 3D
    train_inputs = ImageCropper.crop_images(train_inputs)
    print(train_inputs.shape)
    test_inputs = ImageCropper.crop_images(test_inputs)
    train_inputs = preprocess_input(train_inputs)
    test_inputs = preprocess_input(test_inputs)
    print(train_inputs.shape)
    # Copy each grey slice into 3 channels: (N, 6, 224, 224) -> (N, 6, 224, 224, 3)
    train_inputs_resized = np.repeat(train_inputs[..., np.newaxis], 3, axis=-1)
    print(train_inputs_resized.shape)
    test_inputs_resized = np.repeat(test_inputs[..., np.newaxis], 3, axis=-1)
    # Combine slices using Maximum Intensity Projection (MIP): -> (N, 224, 224, 3)
    train_inputs_resized = np.max(train_inputs_resized, axis=1)
    print(train_inputs_resized.shape)
    test_inputs_resized = np.max(test_inputs_resized, axis=1)
    # Load ResNet50 as a base model and add custom layers
    base_model = ResNet50(weights='imagenet', include_top=False)

    print("model loaded")
    # Unfreeze the last few layers
    for layer in base_model.layers[:-5]:
        layer.trainable = False
    for layer in base_model.layers[-5:]:
        layer.trainable = True

    # Add custom layers
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    predictions = Dense(2, activation='softmax', kernel_regularizer=regularizers.l2(0.01))(x)

    # Build the final model
    model = Model(inputs=base_model.input, outputs=predictions)
    print("model built")
    # Compile the model
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy', Precision(), Recall()])

    # Early stopping and model checkpoints
    early_stopping = EarlyStopping(monitor='val_loss', patience=5)
    model_checkpoint = ModelCheckpoint('best_model.keras', monitor='val_loss', save_best_only=True)

    # Train the model
    print(train_inputs_resized.shape)
    print(len(train_labels))
    history = model.fit(train_inputs_resized, train_labels, epochs=10, batch_size=32, validation_split=0.2, callbacks=[early_stopping, model_checkpoint])
    plot_history(history, "model_loss.png")
    print("model trained")
    #tf.print("Value of output tensor:", output)

    # Evaluate the model. evaluate() returns one value per compiled metric
    # (loss, accuracy, precision, recall), so request a dict instead of
    # unpacking exactly two values.
    results = model.evaluate(test_inputs_resized, test_labels, return_dict=True)
    test_loss, test_acc = results["loss"], results["accuracy"]

    print(f"Test accuracy: {test_acc}")
    print(f"Test loss: {test_loss}")

    y_pred = model.predict(test_inputs_resized)
    y_pred_binary = np.argmax(y_pred, axis=1)
    class_labels = ["Cycling", "Senescent"]
    create_confusion_matrix(test_labels, y_pred_binary, class_labels, "confusion_matrix.png")

    print(classification_report(test_labels, y_pred_binary))
    accuracy = accuracy_score(test_labels, y_pred_binary)
    print("Test Accuracy", accuracy.round(4) * 100)
    precision = precision_score(test_labels, y_pred_binary)
    print("Test Precision", precision.round(4) * 100)
    recall = recall_score(test_labels, y_pred_binary)
    print("Test Recall", recall.round(4) * 100)
    f1 = f1_score(test_labels, y_pred_binary)
    print("Test F1", f1.round(4) * 100)

    test_inputs = (test_inputs_resized - test_inputs_resized.min()) / (test_inputs_resized.max() - test_inputs_resized.min())

    visualize_results(
        image_inputs=test_inputs_resized,
        probabilities=y_pred,
        image_labels=test_labels,
        first_label="Cycling",
        second_label="Senescent")
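The `np.repeat` and `np.max` steps in `main()` can be checked on a small array. The sketch below is minimal, and the helper names `mip_then_rgb` and `rgb_then_mip` are hypothetical. It shows that a maximum over the slice axis gives the same result whether you take it before or after copying the grey values into three channels. Projecting first yields the same `(N, 224, 224, 3)` array without materialising the `(N, 6, 224, 224, 3)` intermediate, which for 4727 images in float32 is about 17 GB.

Separately, check where `preprocess_input` is called. The ResNet50 version treats the last axis as the colour channel, so calling it on the `(N, 6, 224, 224)` stack before that axis exists probably does not normalise the images as intended.

```python
import numpy as np


def mip_then_rgb(stack: np.ndarray) -> np.ndarray:
    """Maximum intensity projection over the slice axis, then copy to 3 channels.

    stack: shape (N, slices, H, W), e.g. (4727, 6, 224, 224).
    Returns shape (N, H, W, 3).
    """
    projected = stack.max(axis=1)                              # (N, H, W)
    return np.repeat(projected[..., np.newaxis], 3, axis=-1)   # (N, H, W, 3)


def rgb_then_mip(stack: np.ndarray) -> np.ndarray:
    """The order used in main(): repeat to 3 channels first, then project."""
    rgb = np.repeat(stack[..., np.newaxis], 3, axis=-1)        # (N, slices, H, W, 3)
    return rgb.max(axis=1)                                     # (N, H, W, 3)


rng = np.random.default_rng(42)
# Tiny stand-in for the real (4727, 6, 224, 224) stack
small = rng.random((4, 6, 8, 8)).astype(np.float32)
a = mip_then_rgb(small)
b = rgb_then_mip(small)
print(a.shape, np.array_equal(a, b))  # (4, 8, 8, 3) True
```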

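Any ideas on how to fix or troubleshoot this? The error message says that shapes [1,64] and [1,32] are incompatible, but I don't know which shapes it could be referring to.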
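The two numbers are consistent with the batch size. With `batch_size=32`, the sparse integer labels hold 32 values per batch, while the `Dense(2, activation='softmax')` output holds 32 × 2 = 64.

Below is a minimal NumPy sketch of that mismatch. It assumes the `Precision`/`Recall` metrics flatten labels and thresholded predictions to shape `[1, -1]` before combining them, which the `LogicalAnd` inside `update_state` in the traceback suggests. `flattened_metric_shapes` is a hypothetical helper, and the data is random.

```python
import numpy as np


def flattened_metric_shapes(batch_size: int, output_units: int, seed: int = 0):
    """Return labels and thresholded predictions flattened to [1, -1].

    Assumes sparse integer labels of shape (batch_size,) and a model output
    of shape (batch_size, output_units), as in the question's setup.
    """
    rng = np.random.default_rng(seed)
    y_true = rng.integers(0, 2, size=batch_size)        # sparse labels, e.g. (32,)
    y_pred = rng.random((batch_size, output_units))     # model output, e.g. (32, 2)
    labels_flat = y_true.astype(bool).reshape(1, -1)    # -> (1, 32)
    preds_flat = (y_pred > 0.5).reshape(1, -1)          # -> (1, 64) with 2 output units
    return labels_flat, preds_flat


labels_flat, preds_flat = flattened_metric_shapes(batch_size=32, output_units=2)
print(labels_flat.shape, preds_flat.shape)  # (1, 32) (1, 64)

try:
    np.logical_and(preds_flat, labels_flat)
except ValueError as err:
    # NumPy's counterpart of "Incompatible shapes: [1,64] vs. [1,32]"
    print("shape mismatch:", err)

# With a single output unit, the flattened sizes agree
labels_flat, preds_flat = flattened_metric_shapes(batch_size=32, output_units=1)
print(labels_flat.shape, preds_flat.shape)  # (1, 32) (1, 32)
```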

python tensorflow keras conv-neural-network resnet
1 Answer
0 votes

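If you only have 2 classes, it is better to use

loss = 'binary_crossentropy'
in compile, and
Dense(1, activation='sigmoid', ...)
as the output layer.

With a single output unit, each batch produces one prediction per image, the same count as the integer labels. The Precision and Recall metrics then no longer compare 64 predictions against 32 labels.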
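If the sigmoid output is adopted, `model.predict` returns shape `(N, 1)`. The existing `y_pred_binary = np.argmax(y_pred, axis=1)` in `main()` would then always yield 0. Below is a minimal sketch of a replacement step. `to_binary_labels` is a hypothetical helper, and the 0.5 threshold is an assumption that may need tuning.

```python
import numpy as np


def to_binary_labels(y_pred: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Convert model.predict output to 0/1 class labels.

    Handles a single sigmoid unit (shape (N, 1)) as well as a 2-unit
    softmax (shape (N, 2)).
    """
    if y_pred.ndim == 2 and y_pred.shape[1] == 1:
        return (y_pred[:, 0] > threshold).astype(int)  # sigmoid: threshold the probability
    return np.argmax(y_pred, axis=1)                   # softmax: pick the larger class score


sigmoid_out = np.array([[0.1], [0.8], [0.6], [0.3]])
print(np.argmax(sigmoid_out, axis=1))  # [0 0 0 0]: argmax over one column is always 0
print(to_binary_labels(sigmoid_out))   # [0 1 1 0]
```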

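To keep the 2-unit softmax head instead, another common option (not part of the answer above) is to one-hot encode the labels and switch to `categorical_crossentropy`. Labels and predictions then both have shape `(N, 2)`.

Below is a minimal sketch. `one_hot_labels` is a hypothetical helper, and the label mapping 0 = Cycling, 1 = Senescent is assumed from the order of `class_labels`. The commented compile call passes `class_id=1`, an existing argument of Keras `Precision`/`Recall`, so the metrics score only the senescent column instead of pooling both. The sklearn metrics later in `main()` would still take the original integer labels.

```python
import numpy as np


def one_hot_labels(labels, num_classes: int = 2) -> np.ndarray:
    """Turn sparse integer labels (N,) into one-hot rows (N, num_classes)."""
    labels = np.asarray(labels, dtype=int)
    return np.eye(num_classes, dtype=np.float32)[labels]


train_labels = np.array([0, 1, 1, 0, 1])  # toy labels (assumed 0 = Cycling, 1 = Senescent)
train_labels_one_hot = one_hot_labels(train_labels)
print(train_labels_one_hot.shape)         # (5, 2): same shape as the 2-unit softmax output

# Matching compile/fit calls for the 2-unit softmax head (sketch, not executed here):
# model.compile(optimizer="adam",
#               loss="categorical_crossentropy",
#               metrics=["accuracy", Precision(class_id=1), Recall(class_id=1)])
# model.fit(train_inputs_resized, train_labels_one_hot, ...)
```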
© www.soinside.com 2019 - 2024. All rights reserved.