I'm new to ML, and I'm trying to build an encoder-decoder model that generates Emmet code (a kind of shorthand for HTML) from screenshots. I built a dataset of screenshots paired with their corresponding Emmet code. I use a SwinTransformer to extract image features, which gives my encoder an input of shape (32, 512), i.e. (batch_size, sequence_length). However, I've learned that a transformer encoder expects input of shape (batch_size, sequence_length, embedding_dim). Did I do something wrong in the feature-extraction step, or can the transformer encoder be modified to accept my input? Please help me understand this, thanks a lot! My code looks like this:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from build_dataset import EmmetDataset
from swin_transformer_pytorch import SwinTransformer
from transformer_encoder import TransformerEncoder
STModel = SwinTransformer(
    hidden_dim=96,
    layers=(2, 2, 6, 2),
    heads=(3, 6, 12, 24),
    channels=3,
    num_classes=512,
    head_dim=32,
    window_size=4,
    downscaling_factors=(4, 2, 2, 2),
    relative_pos_embedding=True
)
encoder = TransformerEncoder(d_model=512, num_heads=8, num_layers=6)
train_dataset = EmmetDataset('train')
val_dataset = EmmetDataset('val')
test_dataset = EmmetDataset('test')
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
num_epochs = 10
for epoch in range(num_epochs):
    for i, (screenshot_tensor, serialized_code_tensor) in enumerate(train_dataloader):
        print(screenshot_tensor.shape)       # [32, 3, 768, 768]
        print(serialized_code_tensor.shape)  # [32, 512]
        # SwinTransformer to extract features
        features = STModel(screenshot_tensor)
        print(features.shape)                # [32, 512]
        # Encoder-Decoder
        encoder_output = encoder(features)   # the encoder expects an input of (batch_size, sequence_length, embeddings), but I only have (batch_size, sequence_length)
        print(encoder_output)
        # ... ...
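To make the shape mismatch concrete, here is a minimal sketch using PyTorch's stock `nn.TransformerEncoder` (assuming the `TransformerEncoder` in the question has a similar interface). The (32, 512) tensor from the Swin head is one pooled vector per image, not a sequence; the simplest workaround is to treat each vector as a length-1 sequence via `unsqueeze(1)`:

```python
import torch
import torch.nn as nn

# Stock PyTorch encoder with d_model=512, 8 heads, 6 layers,
# mirroring the hyperparameters from the question.
layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=6)

features = torch.randn(32, 512)  # stand-in for the SwinTransformer output
# features is (batch, embed): one pooled vector per image, no sequence axis.
seq = features.unsqueeze(1)      # (32, 1, 512): a length-1 "sequence" per image
out = encoder(seq)
print(out.shape)                 # torch.Size([32, 1, 512])
```

Note that a length-1 sequence gives the encoder's self-attention nothing to attend over, so the cleaner fix is usually to take a real token sequence from the vision backbone, i.e. the per-patch feature map before Swin's final pooling/classification head (shape roughly (batch, num_patches, channels)). Whether that is exposed directly depends on the `swin_transformer_pytorch` library's internals, so you may need to call the stages yourself and skip the head.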