我想为具有主目标和辅助目标的数据集构建Keras模型。我有数据集中所有条目的辅助目标数据,但对于主目标,我只有所有数据点的子集数据。请考虑以下示例,该示例应该预测
max(min(x1, x2), x3)
但是对于某些值,它只给出了我的辅助目标min(x1, x2)
。
from keras.models import Model
from keras.optimizers import Adadelta
from keras.losses import mean_squared_error
from keras.layers import Input, Dense
import tensorflow as tf
import numpy
input = Input(shape=(3,))
hidden = Dense(2)(input)
min_pred = Dense(1)(hidden)
max_min_pred = Dense(1)(hidden)
model = Model(inputs=[input],
outputs=[min_pred, max_min_pred])
model.compile(
optimizer=Adadelta(),
loss=mean_squared_error,
loss_weights=[0.2, 1.0])
def random_values(n, missing=False):
for i in range(n):
x = numpy.random.random(size=(4, 3))
_min = numpy.minimum(x[..., 0], x[..., 1])
if missing:
_max_min = numpy.full((len(x), 1), numpy.nan)
else:
_max_min = numpy.maximum(_min, x[..., 2]).reshape((-1, 1))
yield x, [numpy.array(_min).reshape((-1, 1)), numpy.array(_max_min)]
model.fit_generator(random_values(50, False),
steps_per_epoch=50)
model.fit_generator(random_values(5, True),
steps_per_epoch=5)
model.fit_generator(random_values(50, False),
steps_per_epoch=50)
显然,上面的代码不起作用 - 具有NaN的目标意味着NaN的损失,这意味着NaN的权重适应,因此权重转向NaN并且模型变得无用。 (另外,实例化整个NaN阵列是浪费的,但原则上我的缺失数据可能是存在数据的任何批次的一部分,因此为了拥有同质数组,这似乎是合理的。)
我的代码不必与所有keras
后端一起使用,tensorflow
-only代码没问题。我试过改变损失功能,
def loss_0_where_nan(loss_function):
def filtered_loss_function(y_true, y_pred):
with_nans = loss_function(y_true, y_pred)
nans = tf.is_nan(with_nans)
return tf.where(nans, tf.zeros_like(with_nans), with_nans)
return filtered_loss_function
并使用loss_0_where_nan(mean_squared_error)
作为新的损失函数,但it still introduces NaNs。
我应该如何处理主要预测输出的缺失目标数据,其中我有辅助目标数据? masking会帮忙吗?
在您的问题中,您将展示数据集中缺少数据可预测块的情况。如果您可以将缺失数据和现有数据分开,则可以使用
truncated_model = Model(inputs=[input],
outputs=[min_pred])
truncated_model.compile(
optimizer=Adadelta(),
loss=[mean_squared_error])
定义与您的完整模型共享某些图层的模型,然后替换
model.fit_generator(random_values(5, True),
steps_per_epoch=5)
同
def partial_data(entry):
x, (y0, y1) = entry
return x, y0
truncated_model.fit_generator(map(partial_data, random_values(5, True)),
steps_per_epoch=5)
在非缺失数据上训练截断模型。
鉴于对输入数据提供者的这种控制水平,您显然可以调整您的random_values
方法,使其甚至不会生成partial_data
立即丢弃的数据,但我认为这将是更明确的方式来呈现必要的更改。