PyTorch 中的多层双向 LSTM/GRU 合并模式

问题描述 投票:0回答:1

我正在尝试将我的代码从 Keras 复制到 PyTorch 中,以比较多层双向 LSTM/GRU 模型在 CPU 和 GPU 上的性能。我想研究不同的合并模式,例如“concat”(这是 PyTorch 中的默认模式)、sum、mul、average。合并模式定义了向前和向后方向的输出如何传递到下一层。

在 Keras 中,这只是多层双向 LSTM/GRU 模型的合并模式的参数更改,PyTorch 中是否也存在类似的东西?一种选择是在每一层之后手动进行合并模式操作并传递到下一层,但我想研究性能,所以我想知道是否还有其他有效的方法。

谢谢,

python pytorch lstm bidirectional
1个回答
0
投票

据我所知,没有比在 PyTorch 中自己实现更有效的方法了,即不存在简单的参数选项。

正如你所说,标准模式是tensorflow的

'concat'
。如果我们想验证这一点,我们可以按如下方式进行测试:

import torch
from torch import nn

# Create the LSTMs
lstm = nn.LSTM(2, 2, batch_first=True)
bilstm = nn.LSTM(2, 2, batch_first=True, bidirectional=True)

# Copy forward weights
bilstm.weight_ih_l0 = lstm.weight_ih_l0
bilstm.weight_hh_l0 = lstm.weight_hh_l0
bilstm.bias_ih_l0 = lstm.bias_ih_l0
bilstm.bias_hh_l0 = lstm.bias_hh_l0

# Execute on random example
x = torch.randn(1, 3, 2)
output1, (h_n1, c_n1) = lstm(x)
output2, (h_n2, c_n2) = bilstm(x)

# Assert equality of the forward loops
assert torch.allclose(output1, output2[:, :, :2])  # Output is the same
assert torch.allclose(h_n1, h_n2[0])  # Hidden state is the same
assert torch.allclose(c_n1, c_n2[0])  # Cell state is the same

以下为其他三种合并模式的示例,供以后参考:

初始化

bilstm = nn.LSTM(2, 2, batch_first=True, bidirectional=True)
x = torch.randn(1, 3, 2)
output, (h_n, c_n) = bilstm(x)

总和 (
'sum'
)

# Merge Mode: 'sum'

# Simple version
output_sum = output[:, :, :2] + output[:, :, 2:]
assert output_sum.shape == (1, 3, 2)

# Faster version
output_sum2 = torch.sum(output.view(x.size(0), x.size(1), 2, -1), dim=2)
assert torch.allclose(output_sum, output_sum2)

在我的机器上,“更快的版本”需要大约。时间是简单版本的一半。

乘法 (
'mul'
)

# Merge Mode: 'mul'
output_mul = output[:, :, :2] * output[:, :, 2:]
assert output_mul.shape == (1, 3, 2)

平均 (
'ave'
)

# Merge Mode: 'ave'

# Simple version
output_ave = (output[:, :, :2] + output[:, :, 2:]) / 2
assert output_ave.shape == (1, 3, 2)

# Faster version
output_ave2 = torch.mean(output.view(x.size(0), x.size(1), 2, -1), dim=2)
assert torch.allclose(output_ave, output_ave2)

同样,更快的版本大约需要。我的设备上简单版本的时间为 50%。

我希望这可以帮助人们将来找到这一点。 :)

© www.soinside.com 2019 - 2024. All rights reserved.