I'm trying to figure out how a PyTorch LSTM takes its input. I've read the documentation, but I'd like someone more experienced to confirm or correct what I've gathered so far.
First, let's establish notation based on the documentation:

input_size – the number of expected features in the input x

where x is x_t, the input at timestep t, if I understand correctly. Below, N is the batch size and L is the sequence length.
I'll give two examples.

Example 1:
Suppose I have 3 instances/sequences per batch, so N = 3, and each instance/sequence is represented as [X, Y], where X and Y are scalars, so L = 2, with X and Y corresponding to the first and second timesteps respectively.

So with 1 layer and a hidden size of 4, the correct setup would be this:
batch_tensor = torch.tensor([
# The first sequence [1, 2] of length 2 where 1 is the first timestep and 2 is the second timestep
[[1], [2]],
# The second sequence [4, 5] of length 2 where 4 is the first timestep and 5 is the second timestep
[[4], [5]],
# The third sequence [7, 8] of length 2 where 7 is the first timestep and 8 is the second timestep
[[7], [8]]
], dtype=torch.float32)
print(batch_tensor.shape)
# Outputs -> torch.Size([3, 2, 1])
# input_size should be 1 as each timestep has dimensionality of 1
lstm = nn.LSTM(input_size=1, hidden_size=4, num_layers=1, batch_first=True)
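As a sanity check on my understanding (my own sketch, not from the docs), feeding this batch through the layer should give an output of shape (N, L, hidden_size), with h_n and c_n of shape (num_layers, N, hidden_size):

```python
import torch
import torch.nn as nn

# Same batch as above: (N=3, L=2, input_size=1)
batch_tensor = torch.tensor([[[1.], [2.]], [[4.], [5.]], [[7.], [8.]]])

lstm = nn.LSTM(input_size=1, hidden_size=4, num_layers=1, batch_first=True)
output, (h_n, c_n) = lstm(batch_tensor)

print(output.shape)  # torch.Size([3, 2, 4])  -> (N, L, hidden_size)
print(h_n.shape)     # torch.Size([1, 3, 4])  -> (num_layers, N, hidden_size)
print(c_n.shape)     # torch.Size([1, 3, 4])
```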
Example 2:

This time I have 3 instances/sequences per batch, so N = 3, and each instance/sequence is represented as [X, Y], where X and Y are now vectors, so L = 2, with X and Y again corresponding to the first and second timesteps respectively.

So with 1 layer and a hidden size of 4, the correct setup would be this:
batch_tensor = torch.tensor([
# The first sequence [[1, 1.5], [2, 2.5]] of length 2 where [1, 1.5] is the first timestep and [2, 2.5] is the second timestep
[[1, 1.5], [2, 2.5]],
# The second sequence [[4, 4.5], [5, 5.5]] of length 2 where [4, 4.5] is the first timestep and [5, 5.5] is the second timestep
[[4, 4.5], [5, 5.5]],
# The third sequence [[7, 7.5], [8, 8.5]] of length 2 where [7, 7.5] is the first timestep and [8, 8.5] is the second timestep
[[7, 7.5], [8, 8.5]]
], dtype=torch.float32)
print(batch_tensor.shape)
# Outputs -> torch.Size([3, 2, 2])
# input_size should be 2 as each timestep has dimensionality of 2
lstm = nn.LSTM(input_size=2, hidden_size=4, num_layers=1, batch_first=True)
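Running this batch through the layer (again my own sketch) also shows a property I found helpful: for a single-layer, unidirectional LSTM, h_n is just the output at the last timestep of each sequence:

```python
import torch
import torch.nn as nn

# Same batch as above: (N=3, L=2, input_size=2)
batch_tensor = torch.tensor([[[1, 1.5], [2, 2.5]],
                             [[4, 4.5], [5, 5.5]],
                             [[7, 7.5], [8, 8.5]]], dtype=torch.float32)

lstm = nn.LSTM(input_size=2, hidden_size=4, num_layers=1, batch_first=True)
output, (h_n, c_n) = lstm(batch_tensor)

print(output.shape)  # torch.Size([3, 2, 4])
# h_n[-1] (last layer) matches the hidden state at the final timestep:
print(torch.allclose(h_n[-1], output[:, -1, :]))  # True
```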
Answer:
The nn.LSTM module accepts input of size (bs, sl, n) or (sl, bs, n), depending on the batch_first argument. The LSTM input is expected to be a full sequence. This is in sharp contrast to the nn.LSTMCell module, which takes one timestep at a time.
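To make the contrast concrete, here is a minimal sketch of the same kind of computation done with nn.LSTMCell, looping over timesteps manually and carrying (h, c) forward by hand (sizes chosen to match the examples above):

```python
import torch
import torch.nn as nn

batch = torch.randn(3, 2, 1)  # (N=3, L=2, input_size=1)

# LSTMCell has no time dimension: it processes one timestep per call.
cell = nn.LSTMCell(input_size=1, hidden_size=4)

h = torch.zeros(3, 4)  # initial hidden state, (N, hidden_size)
c = torch.zeros(3, 4)  # initial cell state
outputs = []
for t in range(batch.size(1)):
    h, c = cell(batch[:, t, :], (h, c))
    outputs.append(h)

output = torch.stack(outputs, dim=1)  # (N, L, hidden_size)
print(output.shape)  # torch.Size([3, 2, 4])
```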
import torch
import torch.nn as nn

# Batch of 2 sequences, each 4 token ids long -> (N=2, L=4)
inputs = torch.tensor([
    [1, 2, 3, 4],
    [7, 8, 0, 0]
])
d_embedding = 32
d_hidden = d_embedding * 2  # 64
embedding = nn.Embedding(num_embeddings=12, embedding_dim=d_embedding)
lstm = nn.LSTM(input_size=d_embedding, hidden_size=d_hidden, num_layers=3, batch_first=True)
x_embedded = embedding(inputs)  # (2, 4, 32) = (N, L, input_size)
# output: (2, 4, 64) = (N, L, hidden_size)
# hidden_state, cell_state: (3, 2, 64) = (num_layers, N, hidden_size)
output, (hidden_state, cell_state) = lstm(x_embedded)