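I want to build a simple model in Keras with only a few dense layers, where the weights and biases of the first layer are produced by another model. This is similar to this paper, but simpler: the CNN is also replaced with dense layers.
I tried this: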
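import numpy as np
# Imports assume the Keras bundled with TensorFlow (tf.keras)
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Nadam

aux_in = np.random.random((2000, 10))   # features to generate weights
data_in = np.random.random((2000, 40))  # main features
target = np.random.random((2000, 4))    # annotations to be predicted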
def create_ws(dim1, dim2):
    # define the weight-generating MLP
    model = Sequential()
    initializer = RandomNormal(mean=0., stddev=1.)
    model.add(Dense(dim2, input_dim=dim1, activation="relu"))
    model.add(Dense(128, activation="relu"))
    return model
def create_FCN(dim, ws):
    # define our MLP network
    model = Sequential()
    initializer = RandomNormal(mean=0., stddev=1.)
    model.add(Dense(128,
                    input_dim=dim,
                    activation="sigmoid",
                    kernel_initializer=initializer
                    ))
    # ws is a list of NumPy arrays, so this only copies values into the layer
    model.set_weights(ws)
    return model
wgen = create_ws(aux_in.shape[1], data_in.shape[1])
# get_weights()[2:] is the kernel (40, 128) and bias (128,) of wgen's second layer
mainf = create_FCN(data_in.shape[1], wgen.get_weights()[2:])
x = Dense(32, activation="sigmoid")(mainf.output)
x = Dense(4, activation="linear")(x)
main_model = Model(inputs=[mainf.input, wgen.input], outputs=x)
opt = Nadam(learning_rate=0.001)
main_model.compile(loss="mean_squared_error", optimizer=opt, metrics=['accuracy'])
history = main_model.fit(x=[data_in, aux_in], y=target,
                         validation_split=0.1, epochs=10, batch_size=32)
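The problem is that the "wgen" model seems to be detached from "main_model", so its weights are the same before and after training "main_model". The connection between "mainf" and "main_model" seems to be correct, because when I check the weights of "main_model" before and after training, they are clearly being updated.
How should I rewire this model so that the weight-generating network ("wgen") is included in the main model pipeline?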
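The before/after check described above can be sketched as a small NumPy helper. This is only an illustration: `weights_changed` is a hypothetical name, and the snapshots are assumed to be the lists returned by `get_weights()` before and after `fit`. Since `get_weights()` returns plain NumPy arrays and `set_weights()` copies values, arrays passed this way carry no gradient path back to the model they came from, which is consistent with "wgen" staying unchanged.

```python
import numpy as np

def weights_changed(before, after, atol=1e-8):
    """Return one bool per weight array: True if it differs between snapshots."""
    if len(before) != len(after):
        raise ValueError("snapshots must contain the same number of arrays")
    return [not np.allclose(b, a, atol=atol) for b, a in zip(before, after)]

# Stand-in snapshots; with Keras these would come from model.get_weights().
before = [np.zeros((10, 40)), np.zeros(40)]
after = [before[0] + 0.1, before[1].copy()]  # only the kernel moved
print(weights_changed(before, after))  # [True, False]
```

Applied to "wgen", a result of all `False` after training would confirm that no gradient reached it.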
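To include the weight-generating network ("wgen") in the main model pipeline and make sure its weights are updated during training, you need to make a few changes to the code. The key is to make sure that the weights generated by "wgen" are connected to the main model and trainable.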
model.add(Dense(dim2, input_dim=dim1, activation="relu", trainable=True))
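from tensorflow.keras.layers import Activation, Add, Dense, Dot, Input, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Nadam

def create_FCN(data_input, ws_input, ws_features, units=128):
    # Define the main model
    dim = data_input.shape[-1]
    # Map the wgen features to one kernel (dim, units) and one bias (units,) per sample
    kernel = Reshape((dim, units))(Dense(dim * units)(ws_features))
    bias = Dense(units)(ws_features)
    # First layer with generated weights: sigmoid(x @ W + b), computed per sample
    x = Dot(axes=(1, 1))([data_input, kernel])
    x = Activation("sigmoid")(Add()([x, bias]))
    x = Dense(32, activation="sigmoid")(x)
    x = Dense(4, activation="linear")(x)
    main_model = Model(inputs=[data_input, ws_input], outputs=x)
    return main_model
Then we feed the output of "wgen" into the main model ("main_model"). This way, the weights generated by "wgen" become part of the main model's computation graph, and the "wgen" layers are updated during training.
# Create the weight generating model and main model
data_input = Input(shape=(data_in.shape[1],))
wgen_input = Input(shape=(aux_in.shape[1],))
wgen = create_ws(aux_in.shape[1], data_in.shape[1])(wgen_input)
main_model = create_FCN(data_input, wgen_input, wgen)
# Compile and train the main model
opt = Nadam(learning_rate=0.001)
main_model.compile(loss="mean_squared_error", optimizer=opt, metrics=['accuracy'])
history = main_model.fit(x=[data_in, aux_in], y=target,
                         validation_split=0.1, epochs=10, batch_size=32)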
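To make concrete what "weights generated by wgen" can mean for the first dense layer, here is a minimal NumPy sketch of the forward pass only. It rests on assumptions that are not in the question: the generator is reduced to a single linear map (`G_kernel`, `G_bias` are hypothetical names), and it emits one 40x128 kernel and one 128-element bias per sample. Every step is an ordinary differentiable operation, so a framework that tracks these operations can send gradients back into the generator's own parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, aux_dim, in_dim, units = 8, 10, 40, 128

aux = rng.random((batch, aux_dim))   # features that drive the weight generator
data = rng.random((batch, in_dim))   # main features

# Hypothetical linear generator: aux features -> flat kernel and bias per sample.
G_kernel = rng.normal(scale=0.1, size=(aux_dim, in_dim * units))
G_bias = rng.normal(scale=0.1, size=(aux_dim, units))

kernel = (aux @ G_kernel).reshape(batch, in_dim, units)  # one (40, 128) kernel per sample
bias = aux @ G_bias                                       # one (128,) bias per sample

# Per-sample first layer: sigmoid(x_b @ W_b + b_b) for every sample b in the batch.
pre_act = np.einsum("bi,biu->bu", data, kernel) + bias
hidden = 1.0 / (1.0 + np.exp(-pre_act))
print(hidden.shape)  # (8, 128)
```

The `einsum` contraction is the batched equivalent of looping over samples and computing `data[b] @ kernel[b]`; the remaining dense layers then consume `hidden` as usual.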