如何在 PyTorch 中计算 RNN 单元的截断时间反向传播 (BPTT)

问题描述 投票:0回答:1

为了简单起见,我有一个由 N 个输入数据组成的序列,例如单词,并且有一个 RNN 单元。我想计算循环内 K 个单词的滑动窗口上的中继反向传播时间 (BPTT):

optimizer.zero_grad()
h = torch.zeros(hidden_size)
for i in range(N):
    out, h = rnn_cell.forward(data[i], h)
    if i > K:
        loss += compute_loss(out, target)

loss.backward()
optimizer.step()

但显然它会计算之前所有步骤的梯度。我也尝试过这种方法:

h = torch.zeros(hidden_size)
for i in range(N):
    optimizer.zero_grad()
    out, h = rnn_cell.forward(data[i], h.detach())
    loss += compute_loss(out, target)
    loss.backward(retain_graph=True)
    optimizer.step()

但它只会计算最后一步的梯度。我还尝试仅维护

deque(maxlen=K)
中的 K 个步骤的先前隐藏状态,因为我认为当从列表中丢弃对
h
状态的引用时,它也会从图中删除:

optimizer.zero_grad()
h = torch.zeros(hidden_size)
last_h = deque(maxlen=10)

for i in range(N):
    last_h.append(h)
    out, h = rnn_cell.forward(data[i], h)
    if i > K:
        optimizer.zero_grad()
        loss += compute_loss(out, target)
        loss.backward(retain_graph=True)
        optimizer.step()

但我怀疑这里的任何方法是否能按我的预期发挥作用。作为一个非常幼稚的解决方法,我可以做到这一点:

h = torch.zeros(hidden_size)
optimizer.zero_grad()

for i in range(0, N, K):
    h = h.detach()

    optimizer.zero_grad()
    for j in range(i, min(i + K, N)):
        out, h = rnn_cell.forward(data[j], h)

    loss += compute_loss(out, target)
    loss.backward()

但是每一步需要计算K次。最终我也可以每 K 步分离

h
但这样梯度会不准确:

h = torch.zeros(hidden_size)
optimizer.zero_grad()

for i in range(0, N, K):
    out, h = rnn_cell.forward(data[j], h)
    if i % K == 0 and i > 0:
        optimizer.zero_grad()
        h = h.detach()
        loss += compute_loss(out, target)
        loss.backward()
        optimizer.step()

如果您知道如何更好地实现这种滑动渐变窗口,我将非常高兴您的帮助。

python pytorch backpropagation back-propagation-through-time
1个回答
0
投票

您使用

RNNCell
而不是
RNN
有什么具体原因吗?另外,您应该使用
rnn_cell(data[i], h)
而不是
rnn_cell.forward(data[i], h)
。除非您特别需要为每个时间步骤添加自定义内容,否则
RNN
将使您的批处理和使用多层变得更轻松。

无论如何:

通常设置 BPTT 值是在数据处理级别完成的。 RNN 接受大小为

(bs, sl, d_in)
的张量(我使用的是批量优先格式,但这同样适用于序列长度优先格式)。 “BPTT”只是在输入中指定
sl
最大值的一种奇特方式。

假设您的序列总长度为

N
并且想要使用 BPTT 值
K
。您可以选择块之间的重叠值
O
。例如
O=1
表示块
n+1
是从块
n
移出的一个标记。如果
O=K
,则不存在重叠。您可以将整个数据集预处理为大小为
K
的块,并具有所需的重叠
O

然后在训练时,您将处理长度为

K
的完整序列,计算损失,然后反向传播。如果您想知道跟踪块之间的隐藏状态,答案是您不知道。这是使用 BPTT 时为了计算效率而做出的权衡。每个块都以一个新的隐藏状态开始 - 每个块对其之前存在的任何状态都是盲目的。

如果您担心隐藏状态的问题,您可以查看 Truncated BPTT。使用 Truncated BPTT,您首先运行一系列不带梯度跟踪的

K1
来构建隐藏状态,然后运行一系列带梯度跟踪的
K2
和来自
K1
的隐藏状态。然后,您可以通过
K2
进行更新和反向传播。

© www.soinside.com 2019 - 2024. All rights reserved.