PyTorch LSTM - initializing the hidden state during training

Problem description

I have a class containing an LSTM model, and a training loop over some data (pendulum trajectories). When I train the model, I have to initialize the hidden state for every time step. This confuses me, because I thought the strength of an LSTM (an RNN) is precisely that the previous hidden state is used for the next computation... yet I set it to zero every time. The model still works very well for predicting the pendulum. (This code is "heavily inspired" by an article that used it on a very similar problem.)

Here is the model class:

import torch
import torch.nn as nn

class LSTMmodel(nn.Module):
    
    def __init__(self,input_size,hidden_size_1,hidden_size_2,out_size):
        
        super().__init__()
        self.hidden_size_1 = hidden_size_1
        self.hidden_size_2 = hidden_size_2
        self.input_size = input_size
        self.lstm_1 = nn.LSTM(input_size,hidden_size_1)
        self.lstm_2 = nn.LSTM(hidden_size_1,hidden_size_2)
        self.linear = nn.Linear(hidden_size_2,out_size)
        self.hidden_1 = (
            torch.zeros(1,1,hidden_size_1),
            torch.zeros(1,1,hidden_size_1)
        )
        self.hidden_2 = (
            torch.zeros(1,1,hidden_size_2),
            torch.zeros(1,1,hidden_size_2)
        )
        
    def forward(self,seq):
        lstm_out_1 , self.hidden_1 = self.lstm_1(seq.view(-1,1,self.input_size),self.hidden_1) 
        lstm_out_2 , self.hidden_2 = self.lstm_2(lstm_out_1,self.hidden_2)  
        pred = self.linear(lstm_out_2.view(len(seq),-1))
        return pred

Here is the training loop:

def train(model, ddt):

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=0.001)

    model.train()
    # Set the number of epochs
    epochs = 50

    for epoch in range(epochs):
        
        # Running each batch separately 
        
        for bat in range(0, len(training_data), data[0].size(dim=0)):
            #model.hidden_1 = (torch.zeros(1,1,model.hidden_size_1), torch.zeros(1,1,model.hidden_size_1))
            #model.hidden_2 = (torch.zeros(1,1,model.hidden_size_2), torch.zeros(1,1,model.hidden_size_2))
        
            for seq,label in training_data[bat:bat+data[0].size(dim=0)]:
                model.hidden_1 = (torch.zeros(1,1,model.hidden_size_1), torch.zeros(1,1,model.hidden_size_1))
                model.hidden_2 = (torch.zeros(1,1,model.hidden_size_2), torch.zeros(1,1,model.hidden_size_2))
        
                seq=seq.to(device)
                label=label.to(device)

                # set the optimization gradient to zero
                optimizer.zero_grad()

                model.zero_grad()
                # (the hidden states are re-initialized above, at the start of each sequence)
                
                # Make predictions on the current sequence
                if ddt: 
                    y_pred = model(seq) + seq # learn derivative?
                else:
                    y_pred = model(seq)

                # Compute the loss
                loss = loss_fn(y_pred, label)         
                # Perform backpropagation and gradient descent

                loss.backward(retain_graph=True)

                optimizer.step()

The model is simply supposed to predict the pendulum angle at the next time step, given the current position.
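
As an aside, the way I read the ddt branch in the loop above (my own reading, not stated in the post; step_prediction is a hypothetical helper used only for illustration) is that the network then learns the per-step change of the angle rather than the next angle directly:

def step_prediction(model, seq, ddt):
    # ddt == True:  the network outputs the per-step change, which is added back
    #               to the input: next_angle ≈ current_angle + f(current_angle)
    # ddt == False: the network outputs the next angle directly
    return model(seq) + seq if ddt else model(seq)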

If I try to initialize the hidden states only at the start of each batch (the commented-out lines), I get the following error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Trying to set

loss.backward(retain_graph=True)
does not fix the problem either. I then get the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.DoubleTensor [15, 40]], which is output 0 of AsStridedBackward0, is at version 7206; expected version 7205 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I assume this is because some tensor inside the model is modified after the gradient update? Any suggestions on how to fix this would be great!
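
For what it's worth, a pattern that is commonly used to carry the hidden state across sequences without backpropagating through the entire history (and without retain_graph=True) is to detach the state after each sequence instead of recreating it. A rough sketch of that idea against the model above (my own sketch, not from the original post; detach_hidden is a hypothetical helper):

def detach_hidden(model):
    # Keep the current values of the hidden states, but cut them out of the
    # autograd graph, so the next backward() stops at this point.
    model.hidden_1 = tuple(h.detach() for h in model.hidden_1)
    model.hidden_2 = tuple(h.detach() for h in model.hidden_2)

# inside the inner loop, instead of re-creating zero tensors for every sequence:
for seq, label in training_data[bat:bat + data[0].size(dim=0)]:
    detach_hidden(model)              # state carries over, the graph does not
    optimizer.zero_grad()
    y_pred = model(seq.to(device))
    loss = loss_fn(y_pred, label.to(device))
    loss.backward()                   # retain_graph=True is no longer needed
    optimizer.step()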

I am somewhat confused, because I thought the whole point of an LSTM (and of RNNs in general) is to keep the hidden state and use it at the next time step. But now I set the hidden state to 0 and it still works just fine.
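
One detail that may resolve part of this confusion (my note, not in the original post): nn.LSTM already runs the recurrence over the time dimension of the sequence you pass in, so the hidden state is still carried from step to step inside each call. Resetting it to zeros between sequences only makes the sequences independent of each other; if no state is passed at all, PyTorch defaults to zeros anyway:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=2, hidden_size=8)
seq = torch.randn(30, 1, 2)                       # (seq_len, batch, input_size)
zeros = (torch.zeros(1, 1, 8), torch.zeros(1, 1, 8))

out_default, _ = lstm(seq)                        # no initial state -> zeros are used
out_explicit, _ = lstm(seq, zeros)                # explicitly passing zero states

print(torch.allclose(out_default, out_explicit))  # True: identical computation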

Thanks in advance!

python pytorch lstm recurrent-neural-network
1 Answer

Don't reset these layers during training. Your goal is to train these layers! You also seem to be reassigning the entire layers, hence the memory error. You should initialize these layers when the model is created.

Also, you want to initialize them to random values; I am not sure initializing to 0 is appropriate.

Take a look at this tutorial: https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html
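
One way to read that advice (my interpretation, not necessarily what the answerer had in mind; the class and attribute names below are hypothetical) is to create the initial hidden state once in __init__ as randomly initialized, learnable parameters, so it is trained along with the rest of the model instead of being rebuilt inside the training loop. A minimal sketch:

import torch
import torch.nn as nn

class LSTMWithLearnedInit(nn.Module):
    # Sketch: the initial (h_0, c_0) are trainable parameters, created once at
    # model construction rather than reset for every sequence in the loop.
    def __init__(self, input_size, hidden_size, out_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.linear = nn.Linear(hidden_size, out_size)
        self.h0 = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.01)
        self.c0 = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.01)

    def forward(self, seq):
        # seq: (seq_len, input_size) -> (seq_len, batch=1, input_size)
        out, _ = self.lstm(seq.view(len(seq), 1, -1), (self.h0, self.c0))
        return self.linear(out.view(len(seq), -1))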
