我“手动”定义了一个 RNN,它由具有修剪连接的多个线性层组成。
为了跟踪隐藏状态,我使用了一个变量 `next_hidden_states`,用来保存时间 t 的隐藏状态,以便在时间 t+1 时重新使用。该变量的形状为 `(batch_size, N)`。
在训练/评估期间,我希望模型既能处理带批量维度的输入(训练智能体时),也能处理不带批量维度的输入(在环境中运行一个回合时)。对于经典的 PyTorch 模块来说,这通常是可行的,因为批量维度是隐式处理的......
我考虑过把 `next_hidden_states` 同时作为网络的输入参数和输出,但这种做法相当不优雅。
编辑
这是我的代码的最小版本
import numpy as np
import torch
import torch.nn.utils.prune as prune
import torch.nn as nn
class BrainRNN(nn.Module):
    """Hand-written recurrent network built from pruned linear layers.

    The network has three layers of four units each (input, hidden, output).
    The activations of the input and hidden layers at step ``t`` are stored in
    ``self.hidden_states`` and fed back into the input layer at step ``t + 1``
    through a pruned recurrent connection.

    ``forward`` accepts inputs with or without leading batch dimensions, i.e.
    ``(4,)`` as well as ``(B, 4)``. The stored state follows the batch layout
    of the most recent input:

    * a state without batch dimensions is broadcast over any batched input;
    * a state whose batch layout differs from the input's is discarded and
      re-initialised, because there is no meaningful way to carry it over.

    Args:
        activation: Element-wise activation applied after the input and hidden
            layers.
        batch_size: Batch size used for the initial hidden state. A value
            ``<= 0`` creates an initial state without a batch dimension.
    """

    def __init__(self, activation=torch.sigmoid, batch_size=8):
        super().__init__()
        self.n_neurons = 3*4
        self.activation = activation
        self.batch_size = batch_size
        self.reset_hidden_states()

        # Create the input layer
        self.input_layer = nn.Linear(4, 4)

        # Create forward hidden layers
        self.hidden_layers = nn.ModuleList([])
        new_layer = nn.Linear(4, 4)
        # All connections except the diagonal ones are kept. The mask is built
        # directly as a float32 tensor so it matches the weight dtype.
        mask = torch.ones(4, 4) - torch.eye(4)
        prune.custom_from_mask(new_layer, name='weight', mask=mask)  # delete fictive connections
        self.hidden_layers.append(new_layer)

        # Create the backward weights
        self.recurrent_layers = nn.ModuleList([])  # recurrent_layers[i](hidden_states) = layer j>i to i
        new_layer = nn.Linear(self.n_neurons, 4, bias=False)  # no bias for backward connection
        # Weight layout is (out_features, in_features): only the connection
        # from stored neuron 1 to unit 0 of the input layer is kept.
        mask = torch.zeros(4, self.n_neurons)
        mask[0, 1] = 1
        prune.custom_from_mask(new_layer, name='weight', mask=mask)  # delete fictive connections
        self.recurrent_layers.append(new_layer)

        # Create the output layer
        self.output_layer = nn.Linear(4, 4)

    def forward(self, x):
        """Advance the network by one time step.

        Args:
            x: Input of shape ``(..., 4)``; the leading dimensions (possibly
                none) are treated as batch dimensions.

        Returns:
            The output-layer values, with the same leading dimensions as ``x``.

        Side effects:
            ``self.hidden_states`` is replaced by the new state of shape
            ``(..., n_neurons)``. The new state stays attached to the autograd
            graph, as in the original implementation.
        """
        batch_shape = tuple(x.shape[:-1])
        stored_shape = tuple(self.hidden_states.shape[:-1])
        # A state without batch dimensions broadcasts over any input; any other
        # mismatch means the state belongs to a different batch layout.
        if stored_shape not in ((), batch_shape):
            self.reset_hidden_states(batch_shape=batch_shape)
        previous = self.hidden_states.to(device=x.device, dtype=x.dtype)

        # Input layer
        x = self.activation(self.input_layer(x) + self.recurrent_layers[0](previous))
        input_state = x

        # Hidden layers
        x = self.activation(self.hidden_layers[0](x))
        hidden_state = x

        # Output layer
        x = self.output_layer(x)  # no activation nor recurrent/skip connection for the last one

        # The last four slots are never written by the network. They are zero
        # filled instead of left uninitialised, so that garbage values
        # (possibly NaN/inf) cannot leak through the masked recurrent weights.
        unused = x.new_zeros(batch_shape + (self.n_neurons - 2 * 4,))
        self.hidden_states = torch.cat([input_state, hidden_state, unused], dim=-1)
        return x

    def reset_hidden_states(self, hidden_states=None, batch_shape=None):
        """Re-initialise the stored hidden state.

        Args:
            hidden_states: Optional explicit state of shape
                ``(..., n_neurons)`` to use instead of a random one.
            batch_shape: Leading shape of the random state. Defaults to
                ``(batch_size,)`` when ``batch_size > 0`` and to ``()``
                otherwise.

        Raises:
            ValueError: If ``hidden_states`` is given and its last dimension
                is not ``n_neurons``.
        """
        if hidden_states is not None:
            hidden_states = torch.as_tensor(hidden_states)
            if hidden_states.dim() == 0 or hidden_states.shape[-1] != self.n_neurons:
                raise ValueError(
                    f'hidden_states must have a last dimension of size {self.n_neurons}'
                )
            self.hidden_states = hidden_states
            return

        if batch_shape is None:
            batch_shape = (self.batch_size,) if self.batch_size > 0 else ()

        # One random vector shared by every batch entry (same hidden states
        # for all batches), copied so that each entry owns its memory.
        single_state = nn.init.normal_(torch.empty(self.n_neurons), std=1)
        full_shape = tuple(batch_shape) + (self.n_neurons,)
        self.hidden_states = single_state.expand(full_shape).clone()
# Use a name other than `nn`: rebinding `nn` would shadow the `torch.nn`
# module alias, so any later `nn.Linear(...)` lookup (e.g. while building a
# second BrainRNN) would fail.
model = BrainRNN()
model(torch.zeros(8,4)) # works well
model(torch.zeros(4)) # unbatched input; the original post reports a shape issue at next_hidden_states[...,[0,1,2,3]] = x
其中有 3 层,每层 4 个节点,隐藏层和输入层之间有循环连接,以及一些修剪过的连接。
目标是:在 `nn = BrainRNN(...)` 的情况下,既能计算 `nn(torch.zeros((B,4)))`,也能计算 `nn(torch.zeros(4))`。
理想情况下,我想重现经典 nn.Modules 的行为,但我真的不知道如何在保存状态时做到这一点......
下面是一个简单的 RNN,它接受形状为 `(sequence_length, n_features)` 或 `(batch_size, sequence_length, n_features)` 的数据。它会逐步遍历整个序列,并返回每个时间步的输出和隐藏状态(同时将它们保存为可以访问的属性)。这是您想要的功能吗?其中没有做剪枝,但您可以像原始代码中那样自行添加。
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
class SimpleRNN(nn.Module):
    """Minimal Elman-style RNN that steps through a whole sequence.

    Accepts inputs of shape ``(sequence_length, n_features)`` or
    ``(batch_size, sequence_length, n_features)`` and returns the output and
    the hidden state of every time step. Both are also stored on the module
    as ``self.outputs`` and ``self.hidden_states`` after each call.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the hidden state.
        output_size: Size of the per-step output.
        activation: Name of a function in ``torch.nn.functional`` applied to
            both the hidden state and the output.
        batch_first: Must be ``True``; only batch-first layouts are supported.
    """

    def __init__(self, input_size, hidden_size=4, output_size=2, activation='tanh', batch_first=True):
        super().__init__()

        # Only support batch_first=True (as per OP's test data)
        assert batch_first, 'This model assumes batch_first=True for simplicity'

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation_fn = getattr(torch.nn.functional, activation)

        self.Wxh = nn.Linear(self.input_size, self.hidden_size)
        self.Whh = nn.Linear(self.hidden_size, self.hidden_size)
        self.Why = nn.Linear(self.hidden_size, output_size)

    def forward(self, x):
        """Run the RNN over every time step of ``x``.

        Args:
            x: Tensor of shape ``(sequence_length, n_features)`` or
                ``(batch_size, sequence_length, n_features)``.

        Returns:
            A tuple ``(outputs, hidden_states)`` of shapes
            ``(batch_size, sequence_length, output_size)`` and
            ``(batch_size, sequence_length, hidden_size)``; the batch
            dimension is dropped when ``x`` was given without one.

        Raises:
            ValueError: If ``x`` is neither 2D nor 3D.
        """
        x_ndim_orig = x.ndim

        # If it's 2D, assume that means (sequence_length, n_features)
        # and prepend a batch dimension. unsqueeze returns a view and x is
        # never modified in place, so no defensive copy is needed.
        if x.ndim == 2:
            print('X.ndim is 2 | Assuming X.shape is (sequence_length, n_features)')
            x = x.unsqueeze(dim=0)
        elif x.ndim == 3:
            print('X.ndim is 3 | Assuming X.shape is (batch_size, sequence_length, n_features)')
        else:
            raise ValueError(f'Expected a 2D or 3D input, got a {x.ndim}D tensor')

        batch_size, sequence_len, n_features = x.shape
        assert self.input_size == n_features, f'Expected input features size of {self.input_size}'

        # Record the hidden state and y at each step for input x
        hidden_states = []
        outputs = []

        # Initialise hidden_state to 0 on the same device and with the same
        # dtype as the input, then step through the sequence recurrently.
        hidden_state = x.new_zeros(batch_size, self.hidden_size)

        for frame in x.unbind(dim=1):  # frame: (batch, n_features) for this timestep
            hidden_state = self.activation_fn(
                self.Wxh(frame) + self.Whh(hidden_state)
            )
            output = self.activation_fn(self.Why(hidden_state))

            # Record the hidden state and y for this frame
            hidden_states.append(hidden_state)
            outputs.append(output)

        # Stack into (batch_size, sequence_length, output_size/hidden_size)
        # Available as attributes
        if outputs:
            self.outputs = torch.stack(outputs, dim=1)
            self.hidden_states = torch.stack(hidden_states, dim=1)
        else:
            # torch.stack rejects an empty list: return empty sequences instead
            self.outputs = x.new_zeros(batch_size, 0, self.Why.out_features)
            self.hidden_states = x.new_zeros(batch_size, 0, self.hidden_size)

        # Optionally drop the batch dim that we added
        if x_ndim_orig == 2:
            self.outputs, self.hidden_states = self.outputs[0], self.hidden_states[0]

        return self.outputs, self.hidden_states
测试形状:
# Run a freshly built model on an unbatched and on a batched input and print
# the shape of the returned hidden states:
#   (sequence_length=12, n_features=4)
#       -> (sequence_length=12, hidden_size)
#   (batch_size=32, sequence_length=12, n_features=4)
#       -> (batch_size=32, sequence_length=12, hidden_size)
for input_shape in ((12, 4), (32, 12, 4)):
    x = torch.rand(*input_shape)
    outputs, hidden_states = SimpleRNN(input_size=4)(x)
    print(hidden_states.shape)
X.ndim is 2 | Assuming X.shape is (sequence_length, n_features)
torch.Size([12, 4])
X.ndim is 3 | Assuming X.shape is (batch_size, sequence_length, n_features)
torch.Size([32, 12, 4])
这个 RNN 未经测试,仅用于说明如何在类内部实现循环(recurrence)并保存隐藏状态。