我正在尝试实现一个自定义的 Conv2d 函数:它的行为必须与 nn.Conv2d 一致,只是把 nn.Conv2d 内部使用的乘法和加法分别替换为 mymult(num1, num2) 和 myadd(num1, num2)。
根据论坛上两个非常有用的帖子([1]、[2])的思路,我可以先用 unfold 把输入展开,然后再做矩阵乘法。我的问题是:下面代码中的 @ 运算,应该如何改写成使用 mymult() 和 myadd() 的循环?我认为这里的 @ 就是 matmul(矩阵乘法),不过我不太确定。
基本上,我想用 Python 循环来实现 res = torch.matmul(kernels_flat, imageunfold),这样就可以用 mymult(num1, num2) 代替 *,用 myadd(num1, num2) 代替 +。
def convcheck(mymult=None, myadd=None):
    """Reproduce a small ``nn.Conv2d`` forward pass with explicit Python loops.

    A fixed random input (seed 123) is unfolded into patch columns, the
    convolution weights are flattened into a matrix, and the batched matrix
    product ``kernels_flat @ imageunfold`` is recomputed element by element
    so that the scalar multiply and add can be swapped for custom operations.

    Args:
        mymult: Optional two-argument callable used in place of ``*``.
            Defaults to ordinary multiplication.
        myadd: Optional two-argument callable used in place of ``+``.
            Defaults to ordinary addition.

    Returns:
        Tensor of shape ``(batch_size, out_channels, out_h, out_w)`` holding
        the loop-computed convolution output.
    """
    # Local import keeps this block self-contained; the rest of the file's
    # import section is not visible from here.
    import operator

    if mymult is None:
        mymult = operator.mul
    if myadd is None:
        myadd = operator.add

    torch.manual_seed(123)
    batch_size = 2
    channels = 2
    h, w = 2, 2
    image = torch.randn(batch_size, channels, h, w)  # input image
    out_channels = 3
    kh, kw = 1, 1  # kernel size
    dh, dw = 1, 1  # stride
    pad = 0

    # Standard convolution output-size formula, computed per axis with
    # integer arithmetic so non-square inputs/kernels are handled too.
    out_h = (h + 2 * pad - kh) // dh + 1
    out_w = (w + 2 * pad - kw) // dw + 1

    # Use the same (kh, kw) / (dh, dw) pairs for both the reference layer and
    # the unfold call so the two can never disagree.
    conv = nn.Conv2d(
        in_channels=channels,
        out_channels=out_channels,
        kernel_size=(kh, kw),
        padding=pad,
        stride=(dh, dw),
        bias=False,
    )
    filt = conv.weight.detach()

    imageunfold = F.unfold(image, kernel_size=(kh, kw), padding=pad, stride=(dh, dw))
    print("Unfolded image","\n",imageunfold,"\n",imageunfold.shape)
    kernels_flat = filt.view(out_channels, -1)
    print("Kernel Flat=","\n",kernels_flat,"\n",kernels_flat.shape)

    # Built-in batched matmul, kept as the printed reference for the loops.
    res = kernels_flat @ imageunfold
    print(res,"\n",res.shape)

    n_batch = imageunfold.size(0)
    n_inner = imageunfold.size(1)  # channels * kh * kw
    n_positions = imageunfold.size(2)  # out_h * out_w

    # The accumulator must exist before the loops; the original code used
    # ``result`` without ever creating it.
    result = torch.zeros(n_batch, out_channels, n_positions, dtype=imageunfold.dtype)

    for b in range(n_batch):
        # rows of the flattened kernel matrix
        for i in range(out_channels):
            # columns of the unfolded image (output positions)
            for j in range(n_positions):
                acc = result[b, i, j]
                # shared inner dimension of the matrix product
                for k in range(n_inner):
                    acc = myadd(acc, mymult(kernels_flat[i, k], imageunfold[b, k, j]))
                result[b, i, j] = acc

    return result.view(n_batch, out_channels, out_h, out_w)