我的代码如下。
import numpy as np
import torch
import torch.nn as nn
import cupy as cp
from torchviz import make_dot
from torchinfo import summary
from torchsummary import summary as summary_
def get_filter_torch(*args, **kwargs):
    """Build a TraversabilityFilter, move it to CUDA and switch it to eval mode.

    All positional and keyword arguments are forwarded unchanged to
    ``TraversabilityFilter.__init__``.
    """

    class TraversabilityFilter(nn.Module):
        """Three parallel dilated 3x3 convolutions merged by a 1x1 convolution."""

        def __init__(self, w1, w2, w3, w_out, device="cuda", use_bias=False):
            """Create the four convolutions and install the supplied kernels.

            w1, w2 and w3 become the weights of the branches with dilation 1, 2
            and 3; w_out becomes the weight of the 1x1 output convolution. All
            four are numpy arrays (converted with ``torch.from_numpy``).
            NOTE: ``device`` is accepted but never used in this body.
            """
            super(TraversabilityFilter, self).__init__()
            # Each branch maps 1 input channel to 4 channels, without padding.
            self.conv1 = nn.Conv2d(1, 4, 3, dilation=1, padding=0, bias=use_bias)
            self.conv2 = nn.Conv2d(1, 4, 3, dilation=2, padding=0, bias=use_bias)
            self.conv3 = nn.Conv2d(1, 4, 3, dilation=3, padding=0, bias=use_bias)
            # 12 = 3 branches x 4 channels, reduced to a single output channel.
            self.conv_out = nn.Conv2d(12, 1, 1, bias=use_bias)
            # Set weights.
            self.conv1.weight = nn.Parameter(torch.from_numpy(w1).float())
            self.conv2.weight = nn.Parameter(torch.from_numpy(w2).float())
            self.conv3.weight = nn.Parameter(torch.from_numpy(w3).float())
            self.conv_out.weight = nn.Parameter(torch.from_numpy(w_out).float())

        # NOTE: this overrides nn.Module.__call__ itself; no forward() is defined.
        def __call__(self, elevation_cupy):
            """Run the filter on a cupy array and return exp(-response).

            The input is reshaped using shape[0] and shape[1] as height and
            width, so only a 2-D (H, W) array is handled as intended; with a
            leading batch axis those indices pick up the wrong dimensions.
            """
            # Convert cupy tensor to pytorch.
            elevation_cupy = elevation_cupy.astype(cp.float32)
            elevation = torch.as_tensor(elevation_cupy, device=self.conv1.weight.device)
            print("input: ",elevation.shape)
            with torch.no_grad():
                out1 = self.conv1(elevation.view(-1, 1, elevation.shape[0], elevation.shape[1]))
                out2 = self.conv2(elevation.view(-1, 1, elevation.shape[0], elevation.shape[1]))
                out3 = self.conv3(elevation.view(-1, 1, elevation.shape[0], elevation.shape[1]))
                # Crop the two larger outputs so all branches share out3's
                # spatial size (each side of the input shrinks by 6 in total).
                out1 = out1[:, :, 2:-2, 2:-2]
                out2 = out2[:, :, 1:-1, 1:-1]
                # Stack the branches along the channel axis: 3 x 4 = 12 channels.
                out = torch.cat((out1, out2, out3), dim=1)
                out = self.conv_out(out.abs())
                out = torch.exp(-out)
                print("output: ",out.shape)
            return out

    traversability_filter = TraversabilityFilter(*args, **kwargs).cuda().eval()
    return traversability_filter
# Placeholder kernels drawn at random; substitute the real weight values here.
# Every array is laid out as (out_channels, in_channels, kernel_height, kernel_width).
dilated_kernel_shape = (4, 1, 3, 3)
w1, w2, w3 = (np.random.randn(*dilated_kernel_shape) for _ in range(3))
w_out = np.random.randn(1, 12, 1, 1)
model = get_filter_torch(w1, w2, w3, w_out)

# Run the filter once on a random square cupy grid.
cell_n = 200
x = cp.random.randn(cell_n, cell_n, dtype=cp.float32)
output = model(x)
print(model)

# Summarise the model with torchinfo first, then with torchsummary.
input_size = (cell_n, cell_n)
summary(model)
summary_(model, input_size)
当我运行代码时,结果如下。
input: torch.Size([200, 200])
output: torch.Size([1, 1, 194, 194])
TraversabilityFilter(
(conv1): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), bias=False)
(conv2): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2), bias=False)
(conv3): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), dilation=(3, 3), bias=False)
(conv_out): Conv2d(12, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
=================================================================
Layer (type:depth-idx) Param #
=================================================================
TraversabilityFilter --
├─Conv2d: 1-1 36
├─Conv2d: 1-2 36
├─Conv2d: 1-3 36
├─Conv2d: 1-4 12
=================================================================
Total params: 120
Trainable params: 120
Non-trainable params: 0
=================================================================
input: torch.Size([2, 200, 200])
Traceback (most recent call last):
File "/home/Documents/relay/temp/visual.py", line 72, in <module>
summary_(model, input_size)
File "/home/Documents/2levation_ws/venv/myenv/lib/python3.8/site-packages/torchsummary/torchsummary.py", line 72, in summary
model(*x)
File "/home/Documents/relay/temp/visual.py", line 34, in __call__
out1 = self.conv1(elevation.view(-1, 1, elevation.shape[0], elevation.shape[1]))
File "/home/Documents/2levation_ws/venv/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
result = forward_call(*input, **kwargs)
File "/home/Documents/2levation_ws/venv/myenv/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home/Documents/2levation_ws/venv/myenv/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Calculated padded input size per channel: (2 x 200). Kernel size: (3 x 3). Kernel size can't be greater than actual input size
我认为出错的原因是:传给 summary 函数的输入大小与实际调用模型时使用的输入大小不匹配。实际使用时的输入大小是 (200, 200);但即使改为指定 input_size=(1, 200, 200),仍然会出现同样的错误。如何解决这个问题?我使用的 PyTorch 及相关库的版本如下所示。
torch==1.13.1+cu116
torchaudio==0.13.1+cu116
torchinfo==1.8.0
torchsummary==1.5.1
torchvision==0.14.1+cu116
torchviz==0.0.2
我使用 python 3.8.10。
Torchsummary 的工作原理是在一个样本输入上运行模型,并观察各中间结果的形状。为此,它会创建一个形状为 `(2, *input_size)` 的随机张量。也就是说,它会在您给出的形状前面添加一个大小为 2 的维度,因此输入形状 `(200, 200)` 会变为 `(2, 200, 200)`。
另请注意,虽然您的模型需要 `cupy.ndarray` 作为输入,但 Torchsummary 传入的是 `torch.Tensor`。这种做法本身完全合理,因为按照惯例,继承自 `nn.Module` 的类应该接受 `torch.Tensor`,而不是其他类型。
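最后,`elevation.view(-1, 1, elevation.shape[0], elevation.shape[1])` 假设 `elevation.shape[0]` 和 `elevation.shape[1]` 就是您的 `input_size`,即 `(200, 200)`。这是一个非常强的假设:一旦前面多出一个批次维度(例如 `(2, 200, 200)`),它就不再成立。更稳妥的做法是假设最后两个维度才是 `(200, 200)`,即改用 `elevation.shape[-2]` 和 `elevation.shape[-1]`。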
下面是可以正常运行的代码版本,已根据上述几点做了修改。
import numpy as np
import torch
import torch.nn as nn
import cupy as cp
from torchviz import make_dot
from torchinfo import summary
from torchsummary import summary as summary_
def get_filter_torch(*args, **kwargs):
    """Build the traversability filter and return it ready for inference.

    Every positional and keyword argument is forwarded unchanged to
    ``TraversabilityFilter.__init__``.

    Returns:
        A ``TraversabilityFilter`` in eval mode whose parameters live on the
        requested ``device`` ("cuda" unless overridden).
    """

    class TraversabilityFilter(nn.Module):
        """Three parallel dilated 3x3 convolutions merged by a 1x1 convolution.

        Each branch maps the single-channel input to 4 channels, using
        dilation 1, 2 and 3 respectively and no padding. The branch outputs
        are cropped to a common size, concatenated, passed through ``abs``,
        reduced to one channel by the 1x1 convolution and mapped through
        ``exp(-x)``.
        """

        def __init__(self, w1, w2, w3, w_out, device="cuda", use_bias=False):
            """Create the convolutions and install the supplied kernels.

            Args:
                w1: numpy array of shape (4, 1, 3, 3), kernel of the dilation-1 branch.
                w2: numpy array of shape (4, 1, 3, 3), kernel of the dilation-2 branch.
                w3: numpy array of shape (4, 1, 3, 3), kernel of the dilation-3 branch.
                w_out: numpy array of shape (1, 12, 1, 1), kernel of the output convolution.
                device: device the parameters are moved to.
                use_bias: whether the convolutions carry a bias term (left at
                    ``nn.Conv2d``'s own initialisation when enabled).
            """
            super().__init__()
            self.conv1 = nn.Conv2d(1, 4, 3, dilation=1, padding=0, bias=use_bias)
            self.conv2 = nn.Conv2d(1, 4, 3, dilation=2, padding=0, bias=use_bias)
            self.conv3 = nn.Conv2d(1, 4, 3, dilation=3, padding=0, bias=use_bias)
            # 12 = 3 branches x 4 channels.
            self.conv_out = nn.Conv2d(12, 1, 1, bias=use_bias)
            # Set weights.
            self.conv1.weight = nn.Parameter(torch.from_numpy(w1).float())
            self.conv2.weight = nn.Parameter(torch.from_numpy(w2).float())
            self.conv3.weight = nn.Parameter(torch.from_numpy(w3).float())
            self.conv_out.weight = nn.Parameter(torch.from_numpy(w_out).float())
            # Honour the requested device instead of always forcing CUDA.
            self.to(device)

        # Implemented as forward() so that nn.Module.__call__ keeps running
        # the hook machinery; model(x) still works exactly as before.
        def forward(self, elevation_cupy):
            """Apply the filter to an elevation grid.

            Args:
                elevation_cupy: ``cupy.ndarray`` or ``torch.Tensor`` whose last
                    two axes are the grid; any leading axes are folded into
                    the batch axis.

            Returns:
                Tensor of shape (N, 1, H - 6, W - 6) for an input whose last
                two axes are (H, W).

            Raises:
                TypeError: if the input is neither a cupy array nor a tensor.
            """
            weight = self.conv1.weight
            if isinstance(elevation_cupy, torch.Tensor):
                # Align device and dtype with the weights so CPU or
                # non-float32 tensors do not fail inside conv2d.
                elevation = elevation_cupy.to(device=weight.device, dtype=weight.dtype)
            elif isinstance(elevation_cupy, cp.ndarray):
                # Convert cupy tensor to pytorch.
                elevation = torch.as_tensor(
                    elevation_cupy.astype(cp.float32), device=weight.device
                )
            else:
                raise TypeError(
                    "expected a cupy.ndarray or torch.Tensor, got "
                    f"{type(elevation_cupy).__name__}"
                )
            print("input: ", elevation.shape)
            with torch.no_grad():
                # Use the last two axes as height/width; reshape (unlike view)
                # also accepts non-contiguous input.
                grid = elevation.reshape(-1, 1, elevation.shape[-2], elevation.shape[-1])
                out1 = self.conv1(grid)
                out2 = self.conv2(grid)
                out3 = self.conv3(grid)
                # Crop the two larger outputs to out3's spatial size.
                out1 = out1[:, :, 2:-2, 2:-2]
                out2 = out2[:, :, 1:-1, 1:-1]
                out = torch.cat((out1, out2, out3), dim=1)
                out = self.conv_out(out.abs())
                out = torch.exp(-out)
                print("output: ", out.shape)
            return out

    traversability_filter = TraversabilityFilter(*args, **kwargs).eval()
    return traversability_filter
# Random stand-in weights (replace with the real values). Each shape reads
# (out_channels, in_channels, kernel_height, kernel_width).
weight_shapes = [(4, 1, 3, 3)] * 3 + [(1, 12, 1, 1)]
w1, w2, w3, w_out = [np.random.randn(*shape) for shape in weight_shapes]
model = get_filter_torch(w1, w2, w3, w_out)

# One pass over a random square cupy grid.
cell_n = 200
x = cp.random.randn(cell_n, cell_n, dtype=cp.float32)
output = model(x)
print(model)

# torchinfo summary first, then torchsummary with the grid size.
input_size = (cell_n, cell_n)
summary(model)
summary_(model, input_size)