tensorflow.keras.layers.Flatten() throws an INVALID_ARGUMENT error

Question

I built this computer vision model with tensorflow.keras, using grouped convolutions:

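    import tensorflow as tf
    from tensorflow.keras import Sequential, layers

    def build_model():
        n_filters = [8, 16, 32]
        img_dims = (128, 572, 1)

        # calculate_dims() is the asker's own helper (not shown in the question).
        # It only prints the expected per-layer sizes and does not affect the model.
        # layer_dims = calculate_dims(img_dims, conv_filters=n_filters)
        # print('Layer dimensions')
        # for l in layer_dims:
        #     print(l)
        # last_neurons = layer_dims[-1][0] * layer_dims[-1][1] * layer_dims[-1][2]

        model = Sequential()
        model.add(tf.keras.Input(shape=img_dims))

        # First convolution: ungrouped, not followed by pooling
        model.add(layers.Conv2D(n_filters[0], kernel_size=3, padding="same", activation="relu"))

        # Remaining convolutions: grouped (groups=2), each followed by 2x2 max pooling
        for n in n_filters[1:]:
            model.add(layers.Conv2D(n, kernel_size=3, padding="same", activation="relu", groups=2))
            model.add(layers.MaxPool2D(2))

        model.add(layers.Flatten())
        model.add(layers.Dense(64, activation="relu"))
        model.add(layers.Dense(8, activation="relu"))
        # ReLU on the regression output restricts predictions to values >= 0
        model.add(layers.Dense(1, activation="relu"))

        return model

    model = build_model()
    model.summary()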
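
The question does not include calculate_dims(). The following is a hypothetical stand-in (an assumption, not the asker's code) that derives each layer's (height, width, channels) from the layer settings: padding="same" keeps height and width unchanged, and MaxPool2D(2) with its default "valid" padding halves them, rounding down.

```python
def calculate_dims(img_dims, conv_filters):
    """Hypothetical stand-in for the asker's helper.

    Tracks (height, width, channels) through the model, assuming
    padding="same" on every Conv2D and MaxPool2D(2) after every
    Conv2D except the first one.
    """
    h, w, _ = img_dims
    dims = [(h, w, conv_filters[0])]  # first Conv2D, not followed by pooling
    for n in conv_filters[1:]:
        dims.append((h, w, n))        # grouped Conv2D, padding="same"
        h, w = h // 2, w // 2         # MaxPool2D(2): size halved, rounded down
        dims.append((h, w, n))
    return dims


layer_dims = calculate_dims((128, 572, 1), conv_filters=[8, 16, 32])
print('Layer dimensions')
for l in layer_dims:
    print(l)
last_neurons = layer_dims[-1][0] * layer_dims[-1][1] * layer_dims[-1][2]
print(last_neurons)  # 146432
```

These sizes match the Output Shape column of the model.summary() output, including the 32 * 143 * 32 = 146432 values fed into Flatten.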

Here is the model.summary() output:

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 128, 572, 8)    │            80 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_1 (Conv2D)               │ (None, 128, 572, 16)   │           592 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 64, 286, 16)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_2 (Conv2D)               │ (None, 64, 286, 32)    │         2,336 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 32, 143, 32)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ flatten (Flatten)               │ (None, 146432)         │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 64)             │     9,371,712 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 8)              │           520 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 1)              │             9 │
└─────────────────────────────────┴────────────────────────┴───────────────┘

Total params: 9,375,249 (35.76 MB)
Trainable params: 9,375,249 (35.76 MB)
Non-trainable params: 0 (0.00 B)
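
The Param # column can be checked by hand: a Conv2D layer with groups=g has k * k * (C_in / g) * C_out weights plus C_out biases, and a Dense layer has n_in * n_out weights plus n_out biases. The helper names below are ours, chosen for illustration.

```python
def conv2d_params(kernel_size, in_channels, filters, groups=1):
    # Each filter only sees in_channels / groups input channels; one bias per filter.
    assert in_channels % groups == 0 and filters % groups == 0
    return kernel_size * kernel_size * (in_channels // groups) * filters + filters


def dense_params(in_features, units):
    # Full weight matrix plus one bias per unit.
    return in_features * units + units


counts = [
    conv2d_params(3, 1, 8),              # conv2d
    conv2d_params(3, 8, 16, groups=2),   # conv2d_1
    conv2d_params(3, 16, 32, groups=2),  # conv2d_2
    dense_params(32 * 143 * 32, 64),     # dense (input is the Flatten output)
    dense_params(64, 8),                 # dense_1
    dense_params(8, 1),                  # dense_2
]
print(counts)       # [80, 592, 2336, 9371712, 520, 9]
print(sum(counts))  # 9375249

# float32 weights take 4 bytes each; dividing by 1024**2 gives the summary's figure
size_mb = sum(counts) * 4 / 1024**2
print(round(size_mb, 2))  # 35.76
```

Almost all of the parameters sit in the first Dense layer, because it connects every one of the 146432 flattened values to 64 units.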

I train it like this:

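    # train_dataset and val_dataset are the asker's tf.data.Dataset pipelines (not shown).
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss="mean_squared_error",
        metrics=["root_mean_squared_error", "mean_absolute_error"],
    )

    model_history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        batch_size=4,  # Keras docs: don't pass batch_size for datasets, batch the dataset itself
        epochs=2,
        verbose=2,
    )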
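
The question does not show how train_dataset and val_dataset are built. A minimal sketch, assuming in-memory arrays (the dummy x and y below are hypothetical), that batches inside the tf.data pipeline and inspects element_spec to confirm which dimensions are statically known:

```python
import numpy as np
import tensorflow as tf

# Hypothetical dummy data with the model's input shape (128, 572, 1);
# the real training and validation data are not shown in the question.
x = np.random.rand(8, 128, 572, 1).astype(np.float32)
y = np.random.rand(8).astype(np.float32)

# Batch inside the pipeline; fit() then consumes ready-made batches.
train_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(4)
val_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(4)

# Only the batch dimension should be unknown (None); height, width and
# channels stay static.
print(train_dataset.element_spec)
```

With pipelines like these, the call becomes model.fit(train_dataset, validation_data=val_dataset, epochs=2, verbose=2).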

and I get this error at the Flatten layer:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
Cell In[16], line 1
----> 1 model_history = model.fit(
      2     train_dataset,
      3     validation_data=val_dataset,
      4     batch_size=4,
      5     epochs=2,
      6     # callbacks=callbacks,
      7     verbose=2
      8 )
      9 completed = dt.now()

File ~/torch_cuda12_env/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/torch_cuda12_env/lib/python3.10/site-packages/tensorflow/python/eager/execute.py:53, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     51 try:
     52   ctx.ensure_initialized()
---> 53   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     54                                       inputs, attrs, num_outputs)
     55 except core._NotOkStatusException as e:
     56   if name is not None:

InvalidArgumentError: Graph execution error:

Detected at node sequential_1/flatten_1/Reshape defined at (most recent call last):
<stack traces unavailable>
only one input size may be -1, not both 0 and 1

Stack trace for op definition: 
File "usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
File "usr/lib/python3.10/runpy.py", line 86, in _run_code
File "/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
File "/python3.10/site-packages/traitlets/config/application.py", line 1075, in launch_instance
File "/python3.10/site-packages/ipykernel/kernelapp.py", line 739, in start
File "/python3.10/site-packages/tornado/platform/asyncio.py", line 205, in start
File "usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
File "usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
File "usr/lib/python3.10/asyncio/events.py", line 80, in _run
File "/python3.10/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
File "/python3.10/site-packages/ipykernel/kernelbase.py", line 534, in process_one
File "/python3.10/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
File "/python3.10/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
File "/python3.10/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
File "/python3.10/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
File "/python3.10/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
File "/python3.10/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
File "/python3.10/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
File "/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
File "/python3.10/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
File "/python3.10/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
File "/python3.10/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
File "tmp/ipykernel_97254/668273277.py", line 1, in <module>
File "/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 329, in fit
File "/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 122, in one_step_on_iterator
File "/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 110, in one_step_on_data
File "/python3.10/site-packages/keras/src/backend/tensorflow/trainer.py", line 57, in train_step
File "/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/python3.10/site-packages/keras/src/layers/layer.py", line 826, in __call__
File "/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/python3.10/site-packages/keras/src/ops/operation.py", line 48, in __call__
File "/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler
File "/python3.10/site-packages/keras/src/models/sequential.py", line 206, in call
File "/python3.10/site-packages/keras/src/models/functional.py", line 199, in call
File "/python3.10/site-packages/keras/src/ops/function.py", line 151, in _run_through_graph
File "/python3.10/site-packages/keras/src/models/functional.py", line 583, in call
File "/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/python3.10/site-packages/keras/src/layers/layer.py", line 826, in __call__
File "/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/python3.10/site-packages/keras/src/ops/operation.py", line 48, in __call__
File "/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler
File "/python3.10/site-packages/keras/src/layers/reshaping/flatten.py", line 54, in call
File "/python3.10/site-packages/keras/src/ops/numpy.py", line 4527, in reshape
File "/python3.10/site-packages/keras/src/backend/tensorflow/numpy.py", line 1618, in reshape

     [[{{node sequential_1/flatten_1/Reshape}}]]
    tf2xla conversion failed while converting __inference_one_step_on_data_2376[]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions.
     [[StatefulPartitionedCall]] [Op:__inference_one_step_on_iterator_2471]
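
The message comes from the Reshape op that Flatten calls (flatten.py, then ops.reshape in the stack trace above). tf.reshape accepts at most one -1 ("infer this size"), and "not both 0 and 1" means the target shape had -1 at index 0 (batch) and at index 1 (flattened size). As far as we can tell, Keras 3's Flatten falls back to -1 for the flattened size when a non-batch dimension is statically unknown, so one plausible trigger is a dataset whose height and width are unknown. That trigger is our assumption; the question does not confirm it, and the sketch below is something to check, not a confirmed fix.

```python
import numpy as np
import tensorflow as tf

# 1) What the message means: tf.reshape allows at most one -1.
x = tf.zeros((4, 32, 143, 32))
reshape_error = None
try:
    tf.reshape(x, [-1, -1])  # two sizes to infer: ambiguous, so it fails
except tf.errors.InvalidArgumentError as e:
    reshape_error = e.message
print("reshape failed:", reshape_error)

flat = tf.reshape(x, [-1, 32 * 143 * 32])  # a single -1 is fine
print(flat.shape)  # (4, 146432)


# 2) Assumed trigger: a dataset whose height and width are unknown,
#    e.g. images yielded by a Python generator (hypothetical example).
def gen():
    for _ in range(8):
        yield np.zeros((128, 572, 1), np.float32), np.float32(0.0)


ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.float32),
    ),
).batch(4)
print(ds.element_spec[0].shape)  # (None, None, None, 1)

# 3) Pinning the static shape leaves only the batch dimension unknown.
ds_fixed = ds.map(lambda img, label: (tf.ensure_shape(img, (None, 128, 572, 1)), label))
print(ds_fixed.element_spec[0].shape)  # (None, 128, 572, 1)
```

tf.ensure_shape also checks the shape at run time, so a mismatched image raises an error instead of silently reaching the model.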

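I tried reading more about Keras layers, but found nothing suggesting a bug in the Flatten layer. I also tried passing input_shape=(32, 143, 32) to the Flatten layer, but even that did not fix the problem.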

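I'm not sure whether the library versions have anything to do with this error, which I'm seeing for the first time. These are the libraries I'm using: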

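  • keras==3.2.1
  • numpy==1.26.4
  • nvidia-cudnn-cu12==8.9.2.26
  • tensorflow==2.16.1
  • torch==2.2.2
  • torchaudio==2.2.2
  • torchvision==0.17.2

I don't actually use the torch libraries for this model; I list them here in case they interact with the other library versions.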
Answer

This is an issue with TensorFlow v2.16. After downgrading to v2.14, the model compiles and trains without errors.
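
TensorFlow 2.14 bundles Keras 2 as tf.keras, whereas from TF 2.16 on tf.keras resolves to Keras 3 (the keras==3.2.1 in the question's list), so the downgrade also changes the Keras major version. The TF 2.16 release notes describe keeping Keras 2 on newer TensorFlow by installing the tf-keras package and setting TF_USE_LEGACY_KERAS=1. A minimal sketch of a version guard, assuming you want a training script to flag the version the answer reports as problematic; parse_version is a hypothetical helper:

```python
import tensorflow as tf


def parse_version(version):
    """Hypothetical helper: "2.16.1" -> (2, 16); patch level and suffixes are ignored."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


tf_version = parse_version(tf.__version__)
if tf_version == (2, 16):
    # The answer reports the Flatten error on TF 2.16 and a clean run on TF 2.14.
    print(f"TensorFlow {tf.__version__}: the answer reports this version as affected")
else:
    print(f"TensorFlow {tf.__version__}")
```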
