pytorch中不同nn.Moule中的共享参数

问题描述 投票:3回答:1

我有下面可以看到的模型,但是我需要创建两个共享x2h和h2h的实例。有人知道怎么做吗?

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size
        self.x2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)

        #self.softmax = nn.LogSoftmax(dim=1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input, hidden):

        hidden1 = self.x2h(input)
        hidden2 = self.h2h(hidden)
        hidden = hidden1 + hidden2
        output = self.h2o(hidden)
        output = self.softmax(output)

        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)
python pytorch static-variables
1个回答
1
投票

这是我认为的Python问题。

在类内部而不是方法内部声明的变量是类或静态变量。

参考:https://radek.io/2011/07/21/static-variables-and-methods-in-python/

© www.soinside.com 2019 - 2024. All rights reserved.