我有一个由多个较小模型组成的模型。我想确认所有部分都已正确连接,于是查看了由 `keras.utils.plot_model` 生成的图。在图中,我发现有些部分看起来不正确:缺少某些连接,并且在输入处多画了一些额外的连接(可能是模型堆叠造成的)。
为什么会缺少连接?模型是正确的吗?图是正确的吗?检查信息流是否符合预期的最佳方法是什么?
from tensorflow.keras.utils import plot_model
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, concatenate, Add
# Feature widths of every tensor that crosses a sub-model boundary.
# Input width of model A.
n_input_a0 = 10
# Model A's output has n_input_b0 features; it is fed to model B (ip_b0) and
# to model C (ip_c0), hence n_input_b0 == n_input_c0.
n_input_b0 = 20
n_input_b1 = 21
# Model B emits two tensors, each n_input_c1 wide; their sum feeds ip_c1.
n_input_c0 = 20
n_input_c1 = 31
# Model C's output (and the combined model's output) has n_output features.
n_output = 3
# --- Model A: ip_a0 -> Dense(10) -> Dense(n_input_b0) ---
ip_a0 = Input(shape=(n_input_a0,), name='ip_a0')
x_a = Dense(units=10)(ip_a0)
x_a = Dense(units=n_input_b0)(x_a)
model_a = Model(inputs=ip_a0, outputs=x_a, name='model_a')
# NOTE(review): model_a is applied here to ip_a0, the very Input tensor it was
# built from, so ip_a0 serves both as model_a's inner input and as an outer input.
x_a2 = model_a(ip_a0)
# --- Model B: concat(ip_b0, ip_b1) -> Dense(10) -> Dense(n_output) -> two heads ---
ip_b0 = Input(shape=(n_input_b0,), name='ip_b0')
ip_b1 = Input(shape=(n_input_b1,), name='ip_b1')
ip_b0b1 = concatenate([ip_b0, ip_b1])
x_b = Dense(units=10)(ip_b0b1)
x_b_temp = Dense(units=n_output)(x_b)
# Two separate Dense heads that both read x_b_temp.
x_b_left = Dense(units=n_input_c1)(x_b_temp)
x_b_right = Dense(units=n_input_c1)(x_b_temp)
model_b = Model(inputs=[ip_b0, ip_b1], outputs=[x_b_left, x_b_right], name='model_b')
# NOTE(review): the second argument, ip_b1, is one of model_b's own Input
# tensors; only the first argument (x_a2) comes from outside model_b.
x_b_left2, x_b_right2 = model_b([x_a2, ip_b1])
# Element-wise sum of model B's two outputs, performed outside model_b.
x_b2 = Add()([x_b_left2, x_b_right2])
# --- Model C: concat(ip_c0, ip_c1) -> Dense(10) -> Dense(n_output) ---
ip_c0 = Input(shape=(n_input_c0,), name='ip_c0')
ip_c1 = Input(shape=(n_input_c1,), name='ip_c1')
ip_c0c1 = concatenate([ip_c0, ip_c1])
x_c = Dense(units=10)(ip_c0c1)
x_c = Dense(units=n_output)(x_c)
model_c = Model(inputs=[ip_c0, ip_c1], outputs=[x_c], name='model_c')
# Model C is called on tensors produced outside of it: A's output and the B sum.
x_c2 = model_c([x_a2, x_b2])
# --- Combined model: (ip_a0, ip_b1) -> x_c2, rendered to model.png ---
model_total = Model(inputs=[ip_a0, ip_b1], outputs=[x_c2], name='model_total', )
plot_model(model_total, expand_nested=True, show_shapes=True, to_file='model.png', dpi=64)
我还查看了 TensorBoard 的输出。结果要好一点,但同样缺少 x_b_left + x_b_right 的加法。
import numpy as np
import tensorflow as tf
# Write the graph to the 'logs' directory so it can be inspected in TensorBoard.
# The TF1-style Session API is used, so eager execution is switched off first.
tf.compat.v1.disable_eager_execution()
with tf.compat.v1.Session() as sess:
    # Everything below is the body of the `with` statement and must be indented.
    writer = tf.compat.v1.summary.FileWriter('logs', sess.graph)
    # Call the combined model once on dummy all-ones inputs, batch size 1.
    model_total([np.ones((1, n_input_a0)), np.ones((1, n_input_b1))])
    # Close the writer explicitly, as the original snippet does.
    writer.close()
# The two calls from the question's script that the answer objects to:
# model_a is applied to ip_a0 and model_b to ip_b1, i.e. to Input tensors
# that those same models were built from.
x_a2 = model_a(ip_a0)
x_b_left2, x_b_right2 = model_b([x_a2, ip_b1])
`ip_a0` 和 `ip_b1` 这两个输入已经是各自模型的一部分,不应再把模型调用在它们上面。我会这样做(我没有检查图表是否因此得到纠正):
from tensorflow.keras.utils import plot_model
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, concatenate, Add
# ---------------------------------------------------------------------------
# Feature widths of the tensors exchanged between the sub-models.
#   A: n_input_a0 -> n_input_b0   (A's output feeds B's ip_b0 and C's ip_c0,
#                                  so n_input_b0 == n_input_c0)
#   B: (n_input_b0, n_input_b1) -> two tensors of width n_input_c1
#   C: (n_input_c0, n_input_c1) -> n_output
# ---------------------------------------------------------------------------
n_input_a0 = 10
n_input_b0 = 20
n_input_b1 = 21
n_input_c0 = 20
n_input_c1 = 31
n_output = 3

# ------------------------------- Model A ----------------------------------
ip_a0 = Input(shape=(n_input_a0,), name='ip_a0')
hidden_a = Dense(units=10)(ip_a0)
x_a = Dense(units=n_input_b0)(hidden_a)
model_a = Model(
    inputs=ip_a0,
    outputs=x_a,
    name='model_a',
)
# Never call a model on its own Input tensor. Reuse the output tensor the
# model already exposes instead.
x_a2 = model_a.output
# Alternative, mirroring what is done for model B below: give the outer model
# its own Input and call model_a on that.
#ip_a0_external = Input(shape=(n_input_a0,), name='ip_a0_external')
#x_a2 = model_a(ip_a0_external)

# ------------------------------- Model B ----------------------------------
ip_b0 = Input(shape=(n_input_b0,), name='ip_b0')
ip_b1 = Input(shape=(n_input_b1,), name='ip_b1')
ip_b0b1 = concatenate([ip_b0, ip_b1])
x_b = Dense(units=10)(ip_b0b1)
x_b_temp = Dense(units=n_output)(x_b)
# Two separate heads reading the same intermediate tensor.
x_b_left = Dense(units=n_input_c1)(x_b_temp)
x_b_right = Dense(units=n_input_c1)(x_b_temp)
model_b = Model(
    inputs=[ip_b0, ip_b1],
    outputs=[x_b_left, x_b_right],
    name='model_b',
)
# Again: do not feed model_b one of its own Inputs. A dedicated outer Input
# takes the place of ip_b1 when the sub-model is called.
ip_b1_external = Input(shape=(n_input_b1,), name='ip_b1_external')
x_b_left2, x_b_right2 = model_b([x_a2, ip_b1_external])
x_b2 = Add()([x_b_left2, x_b_right2])

# ------------------------------- Model C ----------------------------------
ip_c0 = Input(shape=(n_input_c0,), name='ip_c0')
ip_c1 = Input(shape=(n_input_c1,), name='ip_c1')
ip_c0c1 = concatenate([ip_c0, ip_c1])
hidden_c = Dense(units=10)(ip_c0c1)
x_c = Dense(units=n_output)(hidden_c)
model_c = Model(
    inputs=[ip_c0, ip_c1],
    outputs=[x_c],
    name='model_c',
)
# Fine as is: both arguments are tensors coming from outside model_c.
x_c2 = model_c([x_a2, x_b2])

# ---------------------------- Combined model ------------------------------
# Distinguish the real inputs of the combined model from the Inputs that only
# served as helpers while defining the sub-models.
model_total = Model(
    inputs=[model_a.input, ip_b1_external],
    outputs=[x_c2],
    name='model_total',
)
# Alternative, if ip_a0_external was created above:
#model_total = Model(inputs=[ip_a0_external, ip_b1_external], outputs=[x_c2])
plot_model(
    model_total,
    expand_nested=True,
    show_shapes=True,
    to_file='model.png',
    dpi=64,
)