Pytorch 暹罗网络实现?

问题描述 投票:0回答:1

我尝试根据下面的代码实现孪生网络进行图像分类任务:

class SiameseNetwork(nn.Module):

    def __init__(self):

        super(SiameseNetwork, self).__init__()
        # Setting up the Sequential of CNN Layers

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 96, kernel_size=11,stride=1),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(5,alpha=0.0001,beta=0.75,k=2),
            nn.MaxPool2d(3, stride=2),

            nn.Conv2d(96, 256, kernel_size=5,stride=1,padding=2),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(5,alpha=0.0001,beta=0.75,k=2),
            nn.MaxPool2d(3, stride=2),
            nn.Dropout2d(p=0.3),

            nn.Conv2d(256,384 , kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(384,256 , kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2),
            nn.Dropout2d(p=0.3),
        )
        # Defining the fully connected layers
        self.fc = nn.Sequential(
            nn.Linear(30976, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.5),
            
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            
            nn.Linear(128,2))
        
    def forward_once(self, x):
        # Forward pass 
        output = self.cnn(x)
        output = output.view(output.size()[0], -1)
        output = self.fc(output)
        return output

    def forward(self, input1, input2):
        # forward pass of input 1
        output1 = self.forward_once(input1)
        # forward pass of input 2
        output2 = self.forward_once(input2)
        return output1, output2

大部分内容我都懂,但是什么意思

output = output.view(output.size()[0], -1)
做什么?.

当我使用 resnet 或 vgg 等不同网络更改

self.cnn
时,我真的需要它吗?

pytorch siamese-network
1个回答
0
投票

它重塑

output
张量,使其具有相同的批量大小,但每个条目都是一个展平的向量。

output.size()
以元组形式返回张量形状。
output.size()[0]
选择该元组的第一个条目,通常是批量大小。
output.view()
返回一个具有相同内容但排列不同的张量,而不创建副本。
output.view(output.size()[0], -1') means that the shape of that tensor(view) matches the batch size in the first dimension, and in the second dimension 
-1`表示自动选择维度以匹配向量大小。

例如,假设

output
有 8 个元素,每个元素都是 50x40x7 张量。该张量的形状为 8x500x400x300。
output.view(output.size()[0], -1)
的结果将具有形状 8x14000。

通常,这是在全连接层之前完成的,因为全连接层需要一个平面向量作为输入,每个批处理元素一个。因此,您必须对任何不输出平面向量的网络执行此操作。 Resnet 和 VGG 是分类网络,因此它们的输出是平面向量,因此不需要此操作。

© www.soinside.com 2019 - 2024. All rights reserved.