为了简单起见,我有一个由 N 个输入数据组成的序列,例如单词,并且有一个 RNN 单元。我想计算循环内 K 个单词的滑动窗口上的中继反向传播时间 (BPTT):
optimizer.zero_grad()
h = torch.zeros(hidden_size)
for i in range(N):
out, h = rnn_cell.forward(data[i], h)
if i > K:
loss += compute_loss(out, target)
loss.backward()
optimizer.step()
但显然它会计算之前所有步骤的梯度。我也尝试过这种方法:
h = torch.zeros(hidden_size)
for i in range(N):
optimizer.zero_grad()
out, h = rnn_cell.forward(data[i], h.detach())
loss += compute_loss(out, target)
loss.backward(retain_graph=True)
optimizer.step()
但它只会计算最后一步的梯度。我还尝试仅维护
deque(maxlen=K)
中的 K 个步骤的先前隐藏状态,因为我认为当从列表中丢弃对 h
状态的引用时,它也会从图中删除:
optimizer.zero_grad()
h = torch.zeros(hidden_size)
last_h = deque(maxlen=10)
for i in range(N):
last_h.append(h)
out, h = rnn_cell.forward(data[i], h)
if i > K:
optimizer.zero_grad()
loss += compute_loss(out, target)
loss.backward(retain_graph=True)
optimizer.step()
但我怀疑这里的任何方法是否能按我的预期发挥作用。作为一个非常幼稚的解决方法,我可以做到这一点:
h = torch.zeros(hidden_size)
optimizer.zero_grad()
for i in range(0, N, K):
h = h.detach()
optimizer.zero_grad()
for j in range(i, min(i + K, N)):
out, h = rnn_cell.forward(data[j], h)
loss += compute_loss(out, target)
loss.backward()
但是每一步需要计算K次。最终我也可以每 K 步分离
h
但这样梯度会不准确:
h = torch.zeros(hidden_size)
optimizer.zero_grad()
for i in range(0, N, K):
out, h = rnn_cell.forward(data[j], h)
if i % K == 0 and i > 0:
optimizer.zero_grad()
h = h.detach()
loss += compute_loss(out, target)
loss.backward()
optimizer.step()
如果您知道如何更好地实现这种滑动渐变窗口,我将非常高兴您的帮助。
您使用
RNNCell
而不是 RNN
有什么具体原因吗?另外,您应该使用 rnn_cell(data[i], h)
而不是 rnn_cell.forward(data[i], h)
。除非您特别需要为每个时间步骤添加自定义内容,否则RNN
将使您的批处理和使用多层变得更轻松。
无论如何:
通常设置 BPTT 值是在数据处理级别完成的。 RNN 接受大小为
(bs, sl, d_in)
的张量(我使用的是批量优先格式,但这同样适用于序列长度优先格式)。 “BPTT”只是在输入中指定 sl
最大值的一种奇特方式。
假设您的序列总长度为
N
并且想要使用 BPTT 值 K
。您可以选择块之间的重叠值O
。例如 O=1
表示块 n+1
是从块 n
移出的一个标记。如果 O=K
,则不存在重叠。您可以将整个数据集预处理为大小为 K
的块,并具有所需的重叠 O
。
然后在训练时,您将处理长度为
K
的完整序列,计算损失,然后反向传播。如果您想知道跟踪块之间的隐藏状态,答案是您不知道。这是使用 BPTT 时为了计算效率而做出的权衡。每个块都以一个新的隐藏状态开始 - 每个块对其之前存在的任何状态都是盲目的。
如果您担心隐藏状态的问题,您可以查看 Truncated BPTT。使用 Truncated BPTT,您首先运行一系列不带梯度跟踪的
K1
来构建隐藏状态,然后运行一系列带梯度跟踪的 K2
和来自 K1
的隐藏状态。然后,您可以通过 K2
进行更新和反向传播。