我有这个CCT编码器
class CctEncoder(nn.Module):
def __init__(self, in_channels, cct_block_params, num_layers):
super().__init__()
self.conv = nn.Conv2d(in_channels, cct_block_params[0][0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList()
for i in range(num_layers):
in_channels, out_channels, num_heads, mlp_ratio = cct_block_params[i]
block = CctBlock(in_channels, out_channels, num_heads, mlp_ratio)
self.blocks.append(block)
def forward(self, x):
x = self.conv(x)
for block in self.blocks:
x = block(x)
return x
CCTBlock 是这样的:
class CctBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_heads, mlp_ratio=4.0):
super().__init__()
self.norm1 = nn.LayerNorm(in_channels)
self.attn = nn.MultiheadAttention(in_channels, num_heads)
self.norm2 = nn.LayerNorm(in_channels)
self.mlp = nn.Sequential(
nn.Linear(in_channels, int(in_channels * mlp_ratio)),
nn.GELU(),
nn.Linear(int(in_channels * mlp_ratio), out_channels),
)
def forward(self, x):
x_norm = self.norm1(x)
attn_output, _ = self.attn(x_norm, x_norm, x_norm)
x = x + attn_output
x_norm = self.norm2(x)
mlp_output = self.mlp(x_norm)
x = x + mlp_output
return x
现在我也有一个CNN解码器了
class CnnDecoder(nn.Module):
def __init__(self, in_channels, num_blocks, out_channels):
super().__init__()
self.blocks = nn.ModuleList()
for i in range(num_blocks):
self.blocks.append(nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1))
in_channels //= 2
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, x):
for block in self.blocks:
x = F.interpolate(x, scale_factor=2, mode="nearest")
x = block(x)
x = F.relu(x)
x = self.conv(x)
return x
我想用下面的代码来执行这整个事情。在这里,CCT 将作为主要编码器层,CNN 将作为解码器层。
import torch
import torch.nn as nn
import torch.nn.functional as F
class InpaintingModel(nn.Module):
def __init__(self, cct_block_params=((576, 128, 8, 2.0),) * 5, num_blocks=5):
super().__init__()
self.encoder = CctEncoder(3, cct_block_params, num_layers=len(cct_block_params))
self.grid_generator = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(1024, 2048, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(2048, 1024, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
)
self.decoder = CnnDecoder(1024, num_blocks, out_channels=3)
self.mask_conv = nn.Conv2d(3, 1, kernel_size=1)
self.sigmoid = nn.Sigmoid()
def forward(self, x, mask):
encoded_x = self.encoder(x)
batch_size, channels, height, width = encoded_x.size()
mask = F.interpolate(mask, size=(height, width), mode='bilinear', align_corners=False)
mask = self.sigmoid(self.mask_conv(mask))
masked_encoded_x = encoded_x * mask
grid = self.grid_generator(mask.unsqueeze(1))
grid = grid.expand(batch_size, -1, -1, -1)
deformed_masked_encoded_x = F.grid_sample(masked_encoded_x, grid, mode='bilinear', align_corners=False)
decoded_x = self.decoder(deformed_masked_encoded_x)
return decoded_x
运行此模型时出现以下错误。
AssertionError: query should be unbatched 2D or batched 3D tensor but received 4-D query tensor
遇到这种情况我该怎么办?我在这里遗漏了什么吗?