我正在尝试构建一个能够处理动态维度大小的 PyTorch 模块。
import torch
from random import randint
from torch.nn import Linear, BatchNorm1d, ReLU, Dropout, Sequential
# Problem dimensions.
batch_size = 2
embed_size = 128

# Encoder that lifts each scalar secondary feature to an embed_size vector.
fc_1 = Sequential(
    Sequential(
        Linear(in_features=1, out_features=64),
        BatchNorm1d(num_features=64),
        ReLU(),
        Dropout(p=0.1),
    ),
    Linear(in_features=64, out_features=embed_size),
    Linear(in_features=embed_size, out_features=embed_size),
)

# Secondary features: four column vectors, each torch.Size([2, 1]),
# drawn in the order p, q, r, s.
p, q, r, s = (torch.randn(batch_size).unsqueeze(1) for _ in range(4))
# Stacked along the row axis -> torch.Size([8, 1]).
secondary = torch.cat((p, q, r, s), dim=0)

# Random dimension size, drawn from the inclusive range [2, 400] (e.g. 239).
x = randint(2, 400)

# Primary features.
a = torch.rand(batch_size, embed_size)  # torch.Size([2, 128])
b = fc_1(secondary)  # torch.Size([8, 128])
c = torch.rand(x, embed_size)  # torch.Size([x, 128]), e.g. [239, 128]
如何将 a、b 和 c 中的所有信息折叠到变量 y 中,使其大小为 (batch_size, embed_size)?
我正在尝试进行回归分析,因此在折叠过程中不丢失任何信息非常重要。显然,直接使用 torch.cat 是行不通的。任何使用可学习层来进行折叠的方法都可以。
考虑到各张量的维度不同,尤其是 c 的大小是动态的,简单的拼接在这里行不通。这种情况下的常见做法是使用“注意力机制”,它可以处理不同大小的输入,并以可学习的方式聚合信息。注意力机制可以为输入的不同部分赋予不同的权重,从而使模型能够专注于输入中信息更丰富的部分。
以下代码是一个简单注意力机制的实现。其思路是为 a、b 和 c 中的每个向量计算注意力分数,然后使用这些分数对这些张量进行加权求和。
import torch
import torch.nn as nn
import torch.nn.functional as F
class AttentionAggregator(nn.Module):
    """Pool a variable number of embedding rows into one vector via attention.

    A learned query vector scores every row of the concatenated inputs. The
    softmax of those scores weights the rows, and the weighted rows are summed.

    Args:
        embed_size: Size of the last dimension of every input tensor.
    """

    def __init__(self, embed_size: int) -> None:
        super().__init__()
        self.embed_size = embed_size
        # Learned query that every row's key is compared against.
        self.query = nn.Parameter(torch.randn(embed_size))
        # Projects each input row to its key before scoring.
        self.key = nn.Linear(embed_size, embed_size)

    def forward(
        self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
    ) -> torch.Tensor:
        """Aggregate ``a``, ``b`` and ``c`` into a single weighted sum.

        Args:
            a: Tensor whose last dimension is ``embed_size``.
            b: Tensor whose last dimension is ``embed_size``.
            c: Tensor whose last dimension is ``embed_size``; its size along
                dim 0 may differ from the others.

        Returns:
            The attention-weighted sum over dim 0 of the concatenated inputs.
            For 2-D inputs of shape ``(rows, embed_size)`` this is a 1-D
            tensor of shape ``(embed_size,)`` -- one pooled vector for all
            rows, not one per batch element.
        """
        # Stack every row from all inputs so they compete in one softmax.
        combined = torch.cat([a, b, c], dim=0)

        # One key per row.
        keys = self.key(combined)

        # Scaled dot product of each key with the learned query.
        attention_scores = torch.matmul(keys, self.query) / (self.embed_size ** 0.5)
        # Normalise across rows; the trailing axis lets the weights broadcast
        # over the embedding dimension.
        attention_weights = F.softmax(attention_scores, dim=0).unsqueeze(-1)

        # Weight the original rows (not the keys) and sum them.
        weighted = combined * attention_weights
        aggregated = weighted.sum(dim=0)
        return aggregated
# Embedding width, matching the example setup.
embed_size = 128

# Build the aggregator, then pool a, b and c (already defined/computed in the
# example). NOTE(review): AttentionAggregator.forward sums over dim 0, so for
# 2-D inputs y has shape (embed_size,), not (batch_size, embed_size).
attention_aggregator = AttentionAggregator(embed_size=embed_size)
y = attention_aggregator(a, b, c)