我需要同时训练两个模型。每个模型都有一个带有可训练参数的不同激活函数。我想训练模型1和模型2,使模型1的激活函数参数(例如alpha1)与模型2的参数(例如alpha2)相距2;即| alpha_1-alpha_2 | > 2.我想知道如何将其包含在损失函数中进行训练。
我将使用torch.nn.PReLU
作为您讨论的参数激活。为方便起见,创建了get_weight
。
import torch
class Module(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.input = torch.nn.Linear(in_features, 2 * in_features)
self.activation = torch.nn.PReLU()
self.output = torch.nn.Linear(2 * in_features, out_features)
def get_weight(self):
return self.activation.weight
def forward(self, inputs):
return self.output(self.activation(self.inputs(inputs)))
这里,我正在使用一个优化器来优化您所讨论的两个模块的参数。 criterion
可以是mean squared error
,cross entropy
或您需要的任何其他内容。
module1 = Module(20, 1)
module2 = Module(20, 1)
optimizer = torch.optim.Adam(
itertools.chain(module1.parameters(), module2.parameters())
)
critertion = ...
这里仅一步之遥,您应该像通常那样将其打包在数据中的for循环中,希望对您来说足以理解它:
inputs = ...
targets = ...
output1 = module1(inputs)
output2 = module2(inputs)
loss1 = criterion(output1, targets)
loss2 = criterion(output2, targets)
total_loss = loss1 + loss2
total_loss += torch.nn.functional.relu(
2 - torch.abs(module1.get_weight() - module2.get_weight()).sum()
)
total_loss.backward()
optimizer.step()
此行是您在此情况下要遵循的内容:
total_loss += torch.nn.functional.relu(
2 - torch.abs(module1.get_weight() - module2.get_weight()).sum()
)
relu
被使用,因此网络不会仅因创建不同的权重而获得无限收益。如果没有,则权重之间的差异越大,损失将变为负数。在这种情况下,差异越大越好,但是在差距大于或等于2
之后,差异就没有了。
您可能需要将2
增加到2.1
,或者如果您必须通过2
的阈值,则在接近2.0
时优化值的动机会很小。
没有显式指定阈值可能会很困难,但也许这样会起作用:
total_loss = (
(torch.abs(module1) + torch.abs(module2)).sum()
+ (1 / torch.abs(module1) + 1 / torch.abs(module2)).sum()
- torch.abs(module1 - module2).sum()
)
对于网络来说有点黑,但是可能值得一试(如果您应用其他L2
正则化)。
本质上,此损失将在相应位置的-inf, +inf
对权重处具有最佳值,并且永远不会小于零。
对于那些重量
weights_a = torch.tensor([-1000.0, 1000, -1000, 1000, -1000])
weights_b = torch.tensor([1000.0, -1000, 1000, -1000, 1000])
每个部分的损失将是:
(torch.abs(module1) + torch.abs(module2)).sum() # 10000
(1 / torch.abs(module1) + 1 / torch.abs(module2)).sum() # 0.0100
torch.abs(module1 - module2).sum() # 10000
在这种情况下,网络可以通过在两个模块中使用相反的符号来增加权重,而忽略您要优化的内容(两个模块的权重较大L2
可能会有所帮助,我认为最佳值为[C0 ] / 1
(如果-1
的L2
等于alpha
),并且我怀疑网络可能非常不稳定。
使用此丢失功能,如果网络错误地发现了较大的权重,将受到严厉的处罚。
在这种情况下,您将需要使用1
alpha参数进行调整以使其正常运行,这并不严格,但仍需要选择超参数。