pytorch中有两个以上输入参数时如何使用forward()方法

问题描述 投票:1回答:1

有人可以告诉我forward()方法中多个参数背后的概念吗?通常,forward()方法的实现有两个参数

  1. 自身
  2. 输入

如果前进方法比这些参数更多,PyTorch将如何使用前进方法。

让我们考虑以下代码库:https://github.com/bamps53/kaggle-autonomous-driving2019/blob/master/models/centernet.py在这里在线236作者使用了带有两个其他参数的正向方法:

  1. 中心
  2. return_embeddings

我找不到任何一篇文章可以回答我关于第254(return_embeddings:)行和第257(if centers is not None:)行将执行何种条件的查询。据我所知,该方法由nn模块内部调用。有人可以为此点灯吗?

python deep-learning neural-network pytorch tensor
1个回答
0
投票

您设置的转发功能。这意味着您可以根据需要添加更多参数。例如,您可以添加输入,如下所示

def forward(self, input1, input2,input3):
    x = self.layer1(input1)
    y = self.layer2(input2)
    z = self.layer3(input3)

    net = torch.cat((x,y,z),1)

    return net

您必须在馈送网络时控制参数。不能使用超过一个参数的方式来馈送图层。因此,您需要从输入中一个接一个地提取特征,并与它们的torch.cat((x,y),1)(维数为1)串联。

© www.soinside.com 2019 - 2024. All rights reserved.