感谢您在以下方面的帮助。为方便起见,请考虑一个模拟问题,其中子类模型接受输入(None,3,3,1)并返回第四元标签0-3。但是,假设算法首先需要将批处理数据细分为两个互斥的集合(例如,具有标签0-1和标签2-3的那些),然后使用更精细的网络处理每个集合并将其组合回(无,4)形状。根据下面给出的代码,我使用了一个嵌套的交换层来处理批处理中仅包含i)0-1,ii)2-3和iii)0-1和2-3的标签的情况。 。
该模型确实给出了适当的摘要,也可以按预期进行预测,但是由于不存在渐变值,因此我无法为其拟合数据。如果您能提供任何帮助,我将不胜感激。
import tensorflow as tf
class CustomModel(tf.keras.Model):
def __init__(self, nClasses = 2):
super(CustomModel, self).__init__()
self.nClasses = nClasses
self.Net = tf.keras.Sequential()
self.Net.add(tf.keras.layers.Conv2D(filters=4, kernel_size=3, activation="relu", padding="same"))
self.Net.add(tf.keras.layers.Dense(5, activation='relu'))
self.Net.add(tf.keras.layers.Flatten())
self.Net.add(tf.keras.layers.Dense(self.nClasses, activation='softmax'))
def call(self, x):
y = self.Net(x)
return y
class CustomModelSwitch(tf.keras.Model):
def __init__(self):
super(CustomModelSwitch, self).__init__()
self.Net0 = CustomModel(2) # this governs intermediate binary lables (0-1)
self.Net1 = CustomModel(4) # this governs final quaternary lables (0-1-2-3)
self.Net1.build((None, 3, 3, 1)) # Gives shape errors without this
def call(self, x):
def MyGetSubInput(x,Index,Val) : # Returns a subset of x with the lables "Val"
y = tf.keras.backend.squeeze(tf.where(tf.keras.backend.equal(Index,Val)),axis=1)
y = tf.keras.backend.gather(x, y)
return y
def MyIsSubBatchInputEmpty(x,Argmax,Val):
SubZeroOne = tf.keras.layers.Lambda(MyGetSubInput,arguments={'Index': Argmax, 'Val': Val})(x)
return tf.keras.backend.equal(tf.keras.backend.shape(SubZeroOne)[0],0)
def MyLayer01(x):
Sub0 = self.Net1(x)
Sub0 = tf.keras.layers.Lambda(lambda t : 2*t)(Sub0) # This is just a mock lambda layer to be replaced with an intended functionality.
return Sub0 # shape(None,4)
def MyLayer23(x):
Sub1 = self.Net1(x)
Sub1 = tf.keras.layers.Lambda(lambda t : 5*t)(Sub1) # This is just a mock lambda layer ...
return Sub1 # shape(None,4)
def MyLayer0123(x,Argmax):
Sub0 = tf.keras.layers.Lambda(MyGetSubInput,arguments={'Index': Argmax, 'Val': 0})(x) # this extracts 0 lables
Sub1 = tf.keras.layers.Lambda(MyGetSubInput,arguments={'Index': Argmax, 'Val': 1})(x) # this extracts 1 lables
Sub0 = self.Net1(Sub0)
Sub0 = tf.keras.layers.Lambda(lambda t : 2*t)(Sub0) # This is just a mock lambda layer ...
Sub1 = self.Net1(Sub1)
Sub1 = tf.keras.layers.Lambda(lambda t : 5*t)(Sub1) # This is just a mock lambda layer ...
Sub01 = tf.keras.layers.Lambda(lambda t: tf.concat(t,axis=0))([Sub0, Sub1])
return Sub01 # shape(None,4)
y = self.Net0(x) # this governs intermediate binary lables (0-1)
Argmax = tf.keras.layers.Lambda(lambda t: tf.keras.backend.cast(tf.keras.backend.argmax(t),'float32'))(y)
# The following is a nested switch: i) If Sub0 is empty or not and ii) if Sub1 is empty or not.
SubFinal = tf.keras.layers.Lambda(lambda t: tf.keras.backend.switch(MyIsSubBatchInputEmpty(t,Argmax,0),
MyLayer23(t),
tf.keras.backend.switch(MyIsSubBatchInputEmpty(t,Argmax,1),MyLayer01(t), MyLayer0123(t,Argmax))))(x)
return SubFinal # shape(None,4)
#---------------------------------------------------------------------------------------
def MyDefineData(DataCount):
import numpy as np
a = np.empty((DataCount,3,3,1),dtype='uint8')
b = np.zeros((DataCount),dtype='uint8')
for i in range(DataCount) :
a[i,:,:,:] = np.random.randint(low = 0, high = 2, size = 9).reshape((3,3,1))
from keras.utils import to_categorical
a = a.astype('float32') / 2
b[0] = 0
b[1] = 1
b[2] = 2
b[3] = 3
b[4] = 3
b[5] = 2
b = to_categorical(b,dtype=np.float32)
return a,b
inputData,inputLables = MyDefineData(6)
#------------------------------------------------------------------------------------
model = CustomModelSwitch()
model.build((None, 3, 3, 1))
model.summary()
model.compile(optimizer='adamax', loss=['categorical_crossentropy'], loss_weights=[1], metrics=['accuracy'])
model.fit(inputData, inputLables, epochs=3, batch_size=3, verbose=1)
<< img src =“ https://image.soinside.com/eyJ1cmwiOiAiaHR0cHM6Ly9pLnN0YWNrLmltZ3VyLmNvbS9wZzFTRy5qcGcifQ==” alt =“在此处输入图像描述”>
我不确定这里是否百分之一百,但在我看来,问题可能出在以下情况:您首先使用一个似乎要训练的网络,以便从x(y = net0 (X))。然后,对结果使用不可微操作(argmax)决定进一步的计算。但是,这意味着net0参数的微小变化可能导致发生完全不同的计算,因此,给定输入x作为其参数的函数的模型结果似乎不是连续的,更不用说是可区分的了。