在pytorch中,我通过以下方式启动反向传播(通过时间)来训练RNN / GRU / LSTM网络:
loss.backward()
当序列很长时,我想通过时间进行截断反向传播,而不是使用整个序列的正常反向传播时间。
但我在Pytorch API中找不到任何参数或函数来设置截断的BPTT。我错过了吗?我应该在Pytorch自己编码吗?
这是一个例子:
for t in range(T):
y = lstm(y)
if T-t == k:
out.detach()
out.backward()
因此,在此示例中,k
是用于控制要展开的时间步长的参数。