在Pytorch中进行交互式训练模型

问题描述 投票:2回答:1

我需要同时训练两个模型。每个模型都有一个带有可训练参数的不同激活函数。我想训练模型1和模型2,使模型1的激活函数参数(例如alpha1)与模型2的参数(例如alpha2)相距2;即| alpha_1-alpha_2 | > 2.我想知道如何将其包含在损失函数中进行训练。

parameters pytorch backpropagation custom-training
1个回答
1
投票

示例模块定义

我将使用torch.nn.PReLU作为您讨论的参数激活。为方便起见,创建了get_weight

import torch


class Module(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.input = torch.nn.Linear(in_features, 2 * in_features)
        self.activation = torch.nn.PReLU()
        self.output = torch.nn.Linear(2 * in_features, out_features)

    def get_weight(self):
        return self.activation.weight

    def forward(self, inputs):
        return self.output(self.activation(self.inputs(inputs)))

模块和设置

这里,我正在使用一个优化器来优化您所讨论的两个模块的参数。 criterion可以是mean squared errorcross entropy或您需要的任何其他内容。

module1 = Module(20, 1)
module2 = Module(20, 1)

optimizer = torch.optim.Adam(
    itertools.chain(module1.parameters(), module2.parameters())
)
critertion = ...

培训

这里仅一步之遥,您应该像通常那样将其打包在数据中的for循环中,希望对您来说足以理解它:

inputs = ...
targets = ...

output1 = module1(inputs)
output2 = module2(inputs)

loss1 = criterion(output1, targets)
loss2 = criterion(output2, targets)

total_loss = loss1 + loss2
total_loss += torch.nn.functional.relu(
    2 - torch.abs(module1.get_weight() - module2.get_weight()).sum()
)
total_loss.backward()

optimizer.step()

此行是您在此情况下要遵循的内容:

total_loss += torch.nn.functional.relu(
    2 - torch.abs(module1.get_weight() - module2.get_weight()).sum()
)

relu被使用,因此网络不会仅因创建不同的权重而获得无限收益。如果没有,则权重之间的差异越大,损失将变为负数。在这种情况下,差异越大越好,但是在差距大于或等于2之后,差异就没有了。

您可能需要将2增加到2.1,或者如果您必须通过2的阈值,则在接近2.0时优化值的动机会很小。

编辑

没有显式指定阈值可能会很困难,但也许这样会起作用:

total_loss = (
    (torch.abs(module1) + torch.abs(module2)).sum()
    + (1 / torch.abs(module1) + 1 / torch.abs(module2)).sum()
    - torch.abs(module1 - module2).sum()
)

对于网络来说有点黑,但是可能值得一试(如果您应用其他L2正则化)。

本质上,此损失将在相应位置的-inf, +inf对权重处具有最佳值,并且永远不会小于零。

对于那些重量

weights_a = torch.tensor([-1000.0, 1000, -1000, 1000, -1000])
weights_b = torch.tensor([1000.0, -1000, 1000, -1000, 1000])

每个部分的损失将是:

(torch.abs(module1) + torch.abs(module2)).sum() # 10000
(1 / torch.abs(module1) + 1 / torch.abs(module2)).sum() # 0.0100
torch.abs(module1 - module2).sum() # 10000

在这种情况下,网络可以通过在两个模块中使用相反的符号来增加权重,而忽略您要优化的内容(两个模块的权重较大L2可能会有所帮助,我认为最佳值为[C0 ] / 1(如果-1L2等于alpha),并且我怀疑网络可能非常不稳定。

使用此丢失功能,如果网络错误地发现了较大的权重,将受到严厉的处罚。

在这种情况下,您将需要使用1 alpha参数进行调整以使其正常运行,这并不严格,但仍需要选择超参数。

© www.soinside.com 2019 - 2024. All rights reserved.