How to keep track of hidden states for inputs of different shapes


I have defined an RNN "by hand": it is made of several linear layers with pruned connections.

To keep track of the hidden states, I have a variable next_hidden_states in which I store the hidden states at time t so that they can be reused at time t+1. This variable has shape (batch_size, N).

During training/evaluation, I would like to be able to evaluate the model on inputs that have a batch dimension (training the agent) as well as on inputs that do not (running one episode in the environment). With standard PyTorch modules this is usually possible, since the batch dimension is implicit...
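For instance, a plain nn.Linear accepts both shapes out of the box, because it only constrains the last dimension:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
print(layer(torch.zeros(8, 4)).shape)  # torch.Size([8, 4]) -- batched input
print(layer(torch.zeros(4)).shape)     # torch.Size([4])    -- unbatched input, the batch dim is implicit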

I thought about making next_hidden_states both an argument and an output of the network, but that is rather inelegant.
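For reference, this is roughly what that alternative would look like, in the spirit of nn.RNNCell where the caller owns the state (the class below is purely illustrative, not my actual network):

import torch
import torch.nn as nn

class StatelessCell(nn.Module):
    # Illustrative only: the hidden state is passed in and returned,
    # so its shape always matches the input's (batched or not).
    def __init__(self, n_in=4, n_neurons=12):
        super().__init__()
        self.in2h = nn.Linear(n_in, n_neurons)
        self.h2h = nn.Linear(n_neurons, n_neurons, bias=False)

    def forward(self, x, hidden_states):
        next_hidden_states = torch.sigmoid(self.in2h(x) + self.h2h(hidden_states))
        return next_hidden_states, next_hidden_states  # output, state to feed back at t+1

cell = StatelessCell()
out, h = cell(torch.zeros(8, 4), torch.zeros(8, 12))  # batched
out, h = cell(torch.zeros(4), torch.zeros(12))        # unbatched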

Edit

Here is a minimal version of my code:

import numpy as np

import torch
import torch.nn.utils.prune as prune
import torch.nn as nn


class BrainRNN(nn.Module):
    def __init__(self, activation=torch.sigmoid, batch_size=8):
        super(BrainRNN, self).__init__()
        self.n_neurons = 3*4
        self.activation = activation
        self.batch_size = batch_size
        self.reset_hidden_states()

        # Create the input layer
        self.input_layer = nn.Linear(4, 4)

        # Create forward hidden layers
        self.hidden_layers = nn.ModuleList([])
        new_layer = nn.Linear(4,4)
        mask = np.ones((4,4))-np.eye(4)
        prune.custom_from_mask(new_layer, name='weight', mask=torch.tensor(mask.T)) # delete fictive connections
        self.hidden_layers.append(new_layer)

        # Create the backward weights
        self.recurrent_layers = nn.ModuleList([]) # recurrent_layers[i](hidden_states) = layer j>i to i

        new_layer = nn.Linear(self.n_neurons, 4, bias=False) # no bias for backward connection
        mask = np.zeros((12,4))
        mask[1,0] = 1
        prune.custom_from_mask(new_layer, name='weight', mask=torch.tensor(mask.T)) # delete fictive connections
        self.recurrent_layers.append(new_layer)

        # Create the output layer
        self.output_layer = nn.Linear(4,4)

    def forward(self, x):
        next_hidden_states = torch.empty(x.shape[0], self.n_neurons) if x.dim() > 1 else torch.empty(self.n_neurons)
        skips = [] # list of current states for skip connections

        # Input layer
        x = self.activation(self.input_layer(x) + self.recurrent_layers[0](self.hidden_states))
        next_hidden_states[...,[0,1,2,3]] = x

        # Hidden layers
        x = self.hidden_layers[0](x)
        x = self.activation(x)
        next_hidden_states[...,[4,5,6,7]] = x

        # Output layer
        x = self.output_layer(x) # no activation nor recurrent/skip connection for the last one
        
        self.hidden_states = next_hidden_states

        return x

    def reset_hidden_states(self, hidden_states=None):
        if self.batch_size > 0:
            self.hidden_states = nn.init.normal_(torch.empty(self.n_neurons), std=1).repeat(self.batch_size,1) # same hidden states for all batches
        else:
            self.hidden_states = nn.init.normal_(torch.empty(self.n_neurons), std=1)

nn = BrainRNN()
nn(torch.zeros(8,4)) # works well
nn(torch.zeros(4)) # shape issue at next_hidden_states[...,[0,1,2,3]] = x

There are 3 layers of 4 nodes each, a recurrent connection from the hidden layer back to the input layer, and some pruned connections.

The goal is to be able (with nn = BrainRNN(...)) to evaluate both nn(torch.zeros((B,4))) and nn(torch.zeros(4)). Ideally I would like to reproduce the behaviour of standard nn.Modules, but I don't really see how to do that while saving the states...
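To make the target behaviour concrete, here is a hypothetical sketch (not my actual model) of one pattern that achieves it: add a fake batch dimension to unbatched inputs inside forward, keep the stored state batched, and drop the fake dimension again before returning.

import torch
import torch.nn as nn

class ShapeAgnosticRNN(nn.Module):
    # Hypothetical sketch: every input is normalised to a batched shape inside
    # forward(), and the stored hidden state is (re)created to match the batch size.
    def __init__(self, n_in=4, n_neurons=12):
        super().__init__()
        self.n_neurons = n_neurons
        self.input_layer = nn.Linear(n_in, n_in)
        self.recurrent_layer = nn.Linear(n_neurons, n_in, bias=False)
        self.hidden_states = None

    def forward(self, x):
        unbatched = x.dim() == 1
        if unbatched:
            x = x.unsqueeze(0)  # (4,) -> (1, 4)
        if self.hidden_states is None or self.hidden_states.shape[0] != x.shape[0]:
            self.hidden_states = torch.zeros(x.shape[0], self.n_neurons)
        out = torch.sigmoid(self.input_layer(x) + self.recurrent_layer(self.hidden_states))
        next_hidden_states = torch.zeros(x.shape[0], self.n_neurons)
        next_hidden_states[:, :out.shape[1]] = out
        self.hidden_states = next_hidden_states
        return out.squeeze(0) if unbatched else out  # restore the caller's shape

net = ShapeAgnosticRNN()
print(net(torch.zeros(8, 4)).shape)  # torch.Size([8, 4])
print(net(torch.zeros(4)).shape)     # torch.Size([4]) -- state re-created with batch size 1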

python pytorch recurrent-neural-network
1 Answer

The simple RNN below accepts data shaped either (sequence_length, n_features) or (batch_size, sequence_length, n_features). It steps through the whole sequence and returns the output and hidden state at every step (it also stores them as attributes you can access). Is this the functionality you want? There is no pruning, but you can add it in the same way as in your original code.

import torch
from torch import nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size=4, output_size=2, activation='tanh', batch_first=True):
        super().__init__()
        
        #Only supports batch_first=True (as per OP's test data)
        assert batch_first, 'This model assumes batch_first=True for simplicity'

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation_fn = getattr(torch.nn.functional, activation)
        
        self.Wxh = nn.Linear(self.input_size, self.hidden_size)
        self.Whh = nn.Linear(self.hidden_size, self.hidden_size)
        self.Why = nn.Linear(self.hidden_size, output_size)
        
    def forward(self, x):
        x = x.clone()
        
        x_ndim_orig = x.ndim
        
        #If it's 2D, assume that means (sequence_length, n_features,)
        # and prepend batch
        if x.ndim == 2:
            print('X.ndim is 2 | Assuming X.shape is (sequence_length, n_features)')
            x = x.unsqueeze(dim=0)
        elif x.ndim == 3:
            print('X.ndim is 3 | Assuming X.shape is (batch_size, sequence_length, n_features)')
        
        #Record the hidden state and y at each step for input x
        hidden_states = []
        outputs = []
        
        batch_size, sequence_len, n_features = x.shape
        assert self.input_size == n_features, f'Expected input features size of {self.input_size}'
        
        #Initialise hidden_state to 0, and step through the sequence recurrently
        hidden_state = torch.zeros(batch_size, self.hidden_size)
        for frame_idx in range(sequence_len):
            frame = x[:, frame_idx, :] #(batch, n_features) for this timestep
            
            hidden_state = self.activation_fn(
                self.Wxh(frame) + self.Whh(hidden_state)
            )
            output = self.activation_fn(self.Why(hidden_state))
            
            #Record the hidden state and y for this frame
            hidden_states.append(hidden_state)
            outputs.append(output)
        
        #Stack into (batch_size, sequence_length, output_size/hidden_size)
        # Available as attributes
        self.outputs = torch.stack(outputs, dim=1)
        self.hidden_states = torch.stack(hidden_states, dim=1)
        
        #Optionally drop the batch dim that we added
        if x_ndim_orig == 2:
            self.outputs, self.hidden_states = self.outputs[0], self.hidden_states[0]
        
        return self.outputs, self.hidden_states

Testing the shapes:

#Input:  (sequence_length=12, n_features=4)
#Output: (sequence_length=12, hidden_size)
x = torch.rand(12, 4)
outputs, hidden_states = SimpleRNN(input_size=4)(x)
print(hidden_states.shape)

#Input: (batch_size=32, sequence_length=12, n_features=4)
#Output: (batch_size=32, sequence_length=12, hidden_size)
x = torch.rand(32, 12, 4)
outputs, hidden_states = SimpleRNN(input_size=4)(x)
print(hidden_states.shape)
Output:

X.ndim is 2 | Assuming X.shape is (sequence_length, n_features)
torch.Size([12, 4])

X.ndim is 3 | Assuming X.shape is (batch_size, sequence_length, n_features)
torch.Size([32, 12, 4])

The RNN is untested and is intended to illustrate how to do the recurrence inside the class and store the hidden states.
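If you do want the pruning back, here is a minimal sketch of adding it to the recurrent weights, following the same prune.custom_from_mask pattern as your original code (the mask below is only an example):

import torch
import torch.nn.utils.prune as prune

model = SimpleRNN(input_size=4, hidden_size=4)
mask = torch.ones(4, 4) - torch.eye(4)  # e.g. remove the self-connections
prune.custom_from_mask(model.Whh, name='weight', mask=mask)
print(model.Whh.weight)  # the masked entries stay at zero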
