AssertionError: query should be unbatched 2D or batched 3D tensor but received 4-D query tensor: Pytorch

问题描述 投票:0回答:0

我有这个CCT编码器

class CctEncoder(nn.Module):
def __init__(self, in_channels, cct_block_params, num_layers):
    super().__init__()
    self.conv = nn.Conv2d(in_channels, cct_block_params[0][0], kernel_size=3, padding=1)
    self.blocks = nn.ModuleList()
    for i in range(num_layers):
        in_channels, out_channels, num_heads, mlp_ratio = cct_block_params[i]
        block = CctBlock(in_channels, out_channels, num_heads, mlp_ratio)
        self.blocks.append(block)

def forward(self, x):
    x = self.conv(x)
    for block in self.blocks:
        x = block(x)
    return x

CCTBlock 是这样的:

class CctBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_heads, mlp_ratio=4.0):
    super().__init__()
    self.norm1 = nn.LayerNorm(in_channels)
    self.attn = nn.MultiheadAttention(in_channels, num_heads)
    self.norm2 = nn.LayerNorm(in_channels)
    self.mlp = nn.Sequential(
        nn.Linear(in_channels, int(in_channels * mlp_ratio)),
        nn.GELU(),
        nn.Linear(int(in_channels * mlp_ratio), out_channels),
    )

def forward(self, x):
    x_norm = self.norm1(x)
    attn_output, _ = self.attn(x_norm, x_norm, x_norm)
    x = x + attn_output
    x_norm = self.norm2(x)
    mlp_output = self.mlp(x_norm)
    x = x + mlp_output
    return x

现在我也有一个CNN解码器了

class CnnDecoder(nn.Module):
def __init__(self, in_channels, num_blocks, out_channels):
    super().__init__()
    self.blocks = nn.ModuleList()
    for i in range(num_blocks):
        self.blocks.append(nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1))
        in_channels //= 2
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

def forward(self, x):
    for block in self.blocks:
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = block(x)
        x = F.relu(x)
    x = self.conv(x)
    return x

我想用下面的代码来执行这整个事情。在这里,CCT 将作为主要编码器层,CNN 将作为解码器层。

import torch
import torch.nn as nn
import torch.nn.functional as F

class InpaintingModel(nn.Module):
    def __init__(self, cct_block_params=((576, 128, 8, 2.0),) * 5, num_blocks=5):
        super().__init__()
        self.encoder = CctEncoder(3, cct_block_params, num_layers=len(cct_block_params))
        self.grid_generator = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(1024, 2048, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(2048, 1024, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )
        self.decoder = CnnDecoder(1024, num_blocks, out_channels=3)
        self.mask_conv = nn.Conv2d(3, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, mask):
        encoded_x = self.encoder(x)
        batch_size, channels, height, width = encoded_x.size()
        mask = F.interpolate(mask, size=(height, width), mode='bilinear', align_corners=False)
        mask = self.sigmoid(self.mask_conv(mask))
        masked_encoded_x = encoded_x * mask
        grid = self.grid_generator(mask.unsqueeze(1))
        grid = grid.expand(batch_size, -1, -1, -1)
        deformed_masked_encoded_x = F.grid_sample(masked_encoded_x, grid, mode='bilinear', align_corners=False)
        decoded_x = self.decoder(deformed_masked_encoded_x)
        return decoded_x

运行此模型时出现以下错误。

AssertionError: query should be unbatched 2D or batched 3D tensor but received 4-D query tensor

回溯

完整代码

遇到这种情况我该怎么办?我在这里遗漏了什么吗?

python deep-learning pytorch tensor
© www.soinside.com 2019 - 2024. All rights reserved.