Dropout and BatchNormalization layers throw TypeError: Incompatible types: <dtype: 'variant'> vs. int32. Value is 1 — the model works without them

Question (votes: 0, answers: 1)

When using a custom estimator in TensorFlow 2, if the model contains a BatchNorm or Dropout layer, tf fails while building the graph with the error below. When I comment out the Dropout and BatchNorm layers, it works fine.

The model I am using is a simple CNN with two conv blocks followed by a dense layer:

from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization, ReLU,
                                     MaxPool2D, Dropout, Flatten, Dense)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


def build_conv_block(x: Model, filter_map_count: int, name: str):
    x = Conv2D(filter_map_count, (3, 3), name=f'{name}_conv_2d')(x)
    x = BatchNormalization(name=f'{name}_bn')(x)               <------- Error when not commented out
    x = ReLU(name=f'{name}_relu')(x)
    x = MaxPool2D((2, 2), name=f'{name}_max_pool_2d')(x)
    x = Dropout(0.25, name=f'{name}_dropout')(x)               <------- Error when not commented out
    return x


def get_model(params):
    input_image = Input(shape=params.input_shape)
    x = build_conv_block(input_image, filter_map_count=64, name='layer_1')
    x = build_conv_block(x, filter_map_count=128, name='layer_2')
    x = Flatten(name='flatten_conv')(x)
    output_pred = Dense(10, activation='softmax', name='output')(x)

    model = Model(inputs=input_image, outputs=output_pred)
    model.optimizer = Adam(learning_rate=params.learning_rate)
    return model

I have a standard model_fn with a train_op; it takes MNIST images and labels as input and outputs the class:

# Calculate gradients
with tf.GradientTape() as tape:
    y_pred = model(features, training=training)
    loss = tf.losses.categorical_crossentropy(labels, y_pred)

if mode == tf.estimator.ModeKeys.TRAIN:
    gradients = tape.gradient(loss, model.trainable_variables)
    train_op = model.optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

Here is the traceback of the error I get:

Traceback (most recent call last):
  File "F:/Projects/python/my_project/train.py", line 38, in <module>
    tf.estimator.train_and_evaluate(estimator, train_spec=train_spec, eval_spec=eval_spec)
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_estimator\python\estimator\training.py", line 473, in train_and_evaluate
    return executor.run()
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_estimator\python\estimator\training.py", line 613, in run
    return self.run_local()
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_estimator\python\estimator\training.py", line 714, in run_local
    saving_listeners=saving_listeners)
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 370, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1160, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1190, in _train_model_default
    features, labels, ModeKeys.TRAIN, self.config)
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1148, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "F:\Projects\python\my_project\model.py", line 62, in model_fn
    gradients = tape.gradient(loss, model.trainable_variables)
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_core\python\eager\backprop.py", line 1014, in gradient
    unconnected_gradients=unconnected_gradients)
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_core\python\eager\imperative_grad.py", line 76, in imperative_grad
    compat.as_str(unconnected_gradients.value))
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_core\python\eager\backprop.py", line 138, in _gradient_function
    return grad_fn(mock_op, *out_grads)
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_core\python\ops\cond_v2.py", line 120, in _IfGrad
    true_graph, grads, util.unique_grad_fn_name(true_graph.name))
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_core\python\ops\cond_v2.py", line 395, in _create_grad_func
    func_graph=_CondGradFuncGraph(name, func_graph))
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_core\python\framework\func_graph.py", line 915, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_core\python\ops\cond_v2.py", line 394, in <lambda>
    lambda: _grad_fn(func_graph, grads), [], {},
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_core\python\ops\cond_v2.py", line 373, in _grad_fn
    src_graph=func_graph)
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_core\python\ops\gradients_util.py", line 550, in _GradientsHelper
    gradient_uid)
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_core\python\ops\gradients_util.py", line 175, in _DefaultGradYs
    constant_op.constant(1, dtype=y.dtype, name="grad_ys_%d" % i)))
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_core\python\framework\constant_op.py", line 227, in constant
    allow_broadcast=True)
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_core\python\framework\constant_op.py", line 265, in _constant_impl
    allow_broadcast=allow_broadcast))
  File "F:\Python\envs\tf2\lib\site-packages\tensorflow_core\python\framework\tensor_util.py", line 484, in make_tensor_proto
    (dtype, nparray.dtype, values))
TypeError: Incompatible types: <dtype: 'variant'> vs. int32. Value is 1

It looks similar to the error mentioned in TF Issue #31894, but that does not seem to resolve this problem. The TypeError gives no hint about where or why the error occurs, and googling it directly does not help either.

python tensorflow tensorflow2.0 tensorflow-estimator
1 Answer

0 votes

Although it may not be obvious from the TypeError (variant vs. int32), if we inspect the log carefully we can see that the error occurs while computing the gradients:

  File "F:\Projects\python\my_project\model.py", line 62, in model_fn
    gradients = tape.gradient(loss, model.trainable_variables)

Also note that we get the same error even if only one of the two layers is present. So if we look for what the BatchNormalization and Dropout layers have in common, it is not that they share some base layer class; on closer inspection, these are the only two layers in the model that behave differently between the training and test phases: Dropout doesn't zero out values during the test phase, and BatchNormalization uses the moving mean and variance during the test phase.

The problem is now narrowed down to any layer that behaves differently between the training and test phases. This happens because TensorFlow uses the training argument passed into the model call to decide whether training mode is on.
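
As a quick illustration (a minimal standalone sketch, not part of the original code), the same Dropout layer acts as an identity at test time but zeroes out and rescales values at training time, depending solely on this flag:

import tensorflow as tf

x = tf.ones((1, 4))
dropout = tf.keras.layers.Dropout(0.5)

# training=False: Dropout is a no-op, the input passes through unchanged
print(dropout(x, training=False).numpy())   # [[1. 1. 1. 1.]]
# training=True: roughly half the entries are zeroed, the rest scaled by 1/(1-0.5)
print(dropout(x, training=True).numpy())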

This issue can be resolved by using

y_pred = model(features, training=True)

when computing the gradients, i.e. during the training phase, and by using

y_pred = model(features, training=False)

otherwise, i.e. during the prediction and evaluation phases.
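
Putting it together, a minimal sketch of a model_fn that derives the flag from the estimator mode could look like the following (get_model is the function from the question; the loss is reduced to a scalar here so EstimatorSpec can report it):

def model_fn(features, labels, mode, params):
    model = get_model(params)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # inference: Dropout is a no-op, BatchNorm uses the moving mean/variance
        y_pred = model(features, training=False)
        return tf.estimator.EstimatorSpec(mode, predictions=y_pred)

    # True only while training, so both layers switch to their training behaviour
    training = (mode == tf.estimator.ModeKeys.TRAIN)

    with tf.GradientTape() as tape:
        y_pred = model(features, training=training)
        loss = tf.reduce_mean(tf.losses.categorical_crossentropy(labels, y_pred))

    if mode == tf.estimator.ModeKeys.TRAIN:
        gradients = tape.gradient(loss, model.trainable_variables)
        train_op = model.optimizer.apply_gradients(
            zip(gradients, model.trainable_variables))
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    # evaluation: training=False, loss is still reported
    return tf.estimator.EstimatorSpec(mode, loss=loss)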

Related: errors where the moving mean is not updating have also been reported, and they can be resolved by passing the same argument.
