张量流指标

问题描述 投票:0回答:0

我对 TensorFlow 指标有一些疑问。

  1. 如果我在模型上调用 compile() 方法时将基于函数的自定义指标传递给 metrics 参数,例如以下 R 平方函数:
def r_squared(y_true, y_pred):
    ss_residual = tf.math.reduce_sum(tf.math.square(y_true - y_pred))
    ss_total = tf.math.reduce_sum(tf.math.square(y_true - tf.math.reduce_mean(y_true)))
    return 1 - ss_residual / (ss_total + tf.keras.backend.epsilon())

我知道每个批次都会调用这个函数,但我不确定 TensorFlow 是否会在每个 epoch 之后自动计算每批次的 R 平方平均值。

  1. 如果我使用基于类的自定义指标,我知道 result() 方法将在每个批次后调用,但我不确定 TensorFlow 是否会在每个时期后自动计算每批次指标值的平均值。

  2. 当我运行以下代码时:

model.compile(
    optimizer=Adam(),
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=[
        tf.keras.metrics.MeanAbsoluteError(),
        tfa.metrics.RSquare(),
        tf.keras.metrics.MeanRelativeError(
            normalizer="y_true",
            name="residual_per_unit_observation"
        )
    ]
)

我收到以下错误:

File "C:\Users\user\Documents\GMM_AI\input_selection.py", line 94, in train_pipeline
    tf.keras.metrics.MeanRelativeError(
  File "C:\Users\user\anaconda3\envs\ai\lib\site-packages\keras\dtensor\utils.py", line 144, in _wrap_function
    init_method(instance, *args, **kwargs)
  File "C:\Users\user\anaconda3\envs\ai\lib\site-packages\keras\metrics\metrics.py", line 95, in __init__
    normalizer = tf.cast(normalizer, self._dtype)
  File "C:\Users\user\anaconda3\envs\ai\lib\site-packages\tensorflow\python\util\traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\user\anaconda3\envs\ai\lib\site-packages\tensorflow\python\framework\ops.py", line 7209, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.UnimplementedError: {{function_node __wrapped__Cast_device_/job:localhost/replica:0/task:0/device:CPU:0}} Cast string to float is not supported [Op:Cast]

问题似乎是 MeanRelativeError 中的 normalizer 参数不能使用字符串“y_true”作为参数。我如何使用 y_true 来传递规范器参数?

提前感谢您的帮助。

python tensorflow metrics
© www.soinside.com 2019 - 2024. All rights reserved.