我正在尝试在Pytorch中实现UNet architecture。当我使用print(model)
打印模型时,我得到了正确的体系结构:
但是当我尝试使用(或其他任何输入尺寸)打印摘要时:
from torchsummary import summary
summary(model, input_size=(13, 572, 572))
我收到一个错误:
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 70 and 71 in dimension 2 at /Users/distiller/project/conda/conda-bld/pytorch_1579022061893/work/aten/src/TH/generic/THTensor.cpp:612
但是,如果我将input_size设置为input_size=(3, 224, 224))
,则效果很好(就像它对这个人here一样有效)。我很困惑。
有人可以帮我怎么了吗?
Edit:我使用了here中的模型架构。