想象一下以下设置:
我的数据加载器产生以下形状:(100,2,302,482,3) - 目标是将通道轴上的两个输入图像堆叠到(100,302,482,6)。
没有批量维度(因此x具有形状(2,302,482,3)),它非常容易:
# x.shape = (2, 302, 482, 3)
stacked = tf.concat(x, axis=-1)
# stacked.shape = (302, 482, 6)
但是,当添加批量维度时,我不知道要做同样的操作。
在我看来,最好的方法是在输入网络之前连接到2个图像(使用numpy),为网络提供尺寸(302,482,6),除非你想在网络中处理它更高。这取决于你的目标。编写图层时,批量大小无关紧要。无论批量大小,tf.concat
都将继续相同。