如何使用第一个iter的输入来初始化模块中的变量?

问题描述 投票:0回答:1

我想使用与输入相关的东西来初始化模块中的属性,在第一次 iter 之前,它在 init()

中初始化为零
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.test_var = torch.zeros((1,3))

    def forward(self, imgs):
  
        print(self.test_var)
        if torch.equal( self.test_var, torch.zeros(self.test_var.shape)):
            print('zeros. ')
            var = torch.tensor([1,1,1])
            self.test_var = var + self.test_var

在第一个iter之后应该将test_var的值更改为[1,1,1],但我只发现输出是 enter image description here 这意味着test_var的改变是无效的,并且每次iter都会将其设置为零。 我很困惑。 我想 DataParallel 可能有问题,因为当我设置 CUDA_VISIEBLE_DEVICES=0 时,问题就可以解决。

deep-learning pytorch dataparallel
1个回答
0
投票

您使用哪段代码来调用

model.forward()
函数?
难道你在每个时期后都重新初始化模型吗?
另外,您提供的代码中未使用您的
imgs
参数。

当我初始化模型并像这样调用

model.forward()
...

model = Model()
epochs = 4
for epoch in range(epochs):
  print(f"Epoch {epoch+1}:")
  model.forward(None)
  print("--")

...我得到了您也期望出现的以下输出:在第一个纪元中,添加了这些输出,因为两个张量都只包含零。下面不添加它们,因为零值不等于一值。

Epoch 1:
tensor([[0., 0., 0.]])
zeros. 
--
Epoch 2:
tensor([[1., 1., 1.]])
--
Epoch 3:
tensor([[1., 1., 1.]])
--
Epoch 4:
tensor([[1., 1., 1.]])
--
© www.soinside.com 2019 - 2024. All rights reserved.