在Pytorch中通过时间截断反向传播(BPTT)

问题描述 投票:2回答:1

在pytorch中,我通过以下方式启动反向传播(通过时间)来训练RNN / GRU / LSTM网络:

loss.backward()

当序列很长时,我想通过时间进行截断反向传播,而不是使用整个序列的正常反向传播时间。

但我在Pytorch API中找不到任何参数或函数来设置截断的BPTT。我错过了吗?我应该在Pytorch自己编码吗?

pytorch backpropagation truncated
1个回答
0
投票

这是一个例子:

for t in range(T):
   y = lstm(y)
   if T-t == k:
      out.detach()
out.backward()

因此,在此示例中,k是用于控制要展开的时间步长的参数。

© www.soinside.com 2019 - 2024. All rights reserved.