我有一个嵌入张量,我想将其减少为更少数量的嵌入。我在批处理环境中工作。张量形状为 B, F, D,其中 B 是批次数,F 是嵌入数,D 是维度。我想学习简化为 B、F-n、D。
例如
import torch
B = 10
F = 20
F_desired = 17
D = 64
x = torch.randn(B, F, D)
# torch.Size([50, 20, 64])
reduction = torch.?
y = reduction(x)
print(y.shape)
# torch.Size([50, 20, 64])
我认为 1x1 卷积在这里有意义,但不确定如何确认它确实达到了我的预期?所以很想知道这是否是正确的方法/是否有更好的方法
reduction = torch.nn.Conv1d(
in_channels=F,
out_channels=F_desired,
kernel_size=1,
)
内核大小为 1 的 1d 卷积可以实现此目的:
B = 10
F = 20
F_desired = 17
D = 64
x = torch.randn(B, F, D)
reduction1 = nn.Conv1d(F, F_desired, 1)
x1 = reduction1(x)
print(x1.shape)
> torch.Size([10, 17, 64])
你也可以做一个线性层,只要你排列轴:
reduction2 = nn.Linear(F, F_desired)
x2 = reduction2(x.permute(0,2,1)).permute(0,2,1)
print(x2.shape)
> torch.Size([10, 17, 64])
请注意,如果您的卷积核大小为
1
,这些实际上是等效的操作
reduction2.weight.data = reduction1.weight.squeeze().data
reduction2.bias.data = reduction1.bias.data
x2 = reduction2(x.permute(0,2,1)).permute(0,2,1)
print(torch.allclose(x1,x2, atol=1e-6))
> True