Why does OOM occur at model init()?

Problem description · Votes: 1 · Answers: 1

The line tr.nn.Linear(hw_flat * num_filters*8, num_fc) in my model causes an OOM error at model initialization. Commenting it out makes the memory problem go away.

import torch as tr
from layers import Conv2dSame, Flatten

class Discriminator(tr.nn.Module):
    def __init__(self, cfg):
        super(Discriminator, self).__init__()
        num_filters = 64
        hw_flat = int(cfg.hr_resolution[0] / 2**4)**2
        num_fc = 1024

        self.model = tr.nn.Sequential(
            # Channels in, channels out, filter size, stride, padding
            Conv2dSame(cfg.num_channels, num_filters, 3),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters, num_filters, 3, 2),
            tr.nn.BatchNorm2d(num_filters),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters, num_filters*2, 3),
            tr.nn.BatchNorm2d(num_filters*2),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters*2, num_filters*2, 3, 2),
            tr.nn.BatchNorm2d(num_filters*2),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters*2, num_filters*4, 3),
            tr.nn.BatchNorm2d(num_filters*4),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters*4, num_filters*4, 3, 2),
            tr.nn.BatchNorm2d(num_filters*4),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters*4, num_filters*8, 3),
            tr.nn.BatchNorm2d(num_filters*8),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters*8, num_filters*8, 3, 2),
            tr.nn.BatchNorm2d(num_filters*8),
            tr.nn.LeakyReLU(),
            Flatten(),
            tr.nn.Linear(hw_flat * num_filters*8, num_fc),
            tr.nn.LeakyReLU(),
            tr.nn.Linear(num_fc, 1),
            tr.nn.Sigmoid()
        )
        self.model.apply(self.init_weights)

    def forward(self, x_in):
        x_out = self.model(x_in)
        return x_out

    def init_weights(self, layer):
        if type(layer) in [tr.nn.Conv2d, tr.nn.Linear]:
            tr.nn.init.xavier_uniform_(layer.weight)

This is strange, because hw_flat = 96 * 96 = 9216 and num_filters * 8 = 512, so hw_flat * num_filters * 8 = 4718592, which is the number of parameters in that layer. I have confirmed this calculation, since changing the layer to tr.nn.Linear(4718592, num_fc) produces the same output.

This doesn't make sense to me, because with dtype = float32 the expected size would be 32 * 4718592 = 150,994,944 bytes. That's roughly 150 MB.

The error message is:

Traceback (most recent call last):
  File "main.py", line 116, in <module>
    main()
  File "main.py", line 112, in main
    srgan = SRGAN(cfg)
  File "main.py", line 25, in __init__
    self.discriminator = Discriminator(cfg).to(device)
  File "/home/jpatts/Documents/ECE/ECE471-SRGAN/models.py", line 87, in __init__
    tr.nn.Linear(hw_flat * num_filters*8, num_fc),
  File "/home/jpatts/.local/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 51, in __init__
    self.weight = Parameter(torch.Tensor(out_features, in_features))
RuntimeError: $ Torch: not enough memory: you tried to allocate 18GB. Buy new RAM! at /pytorch/aten/src/TH/THGeneral.cpp:201

I am also running with a batch size of only 1 (this does not affect the error). The overall input shape to the network is (1, 3, 1536, 1536), and the shape after the flatten layer is (1, 4718592).
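For reference, the flattened feature count can be verified with a few lines of arithmetic. This is just a sketch based on the shapes stated above; the 2**4 factor assumes the four stride-2 convolutions in the model.

# Quick check of the flattened feature count for a (1, 3, 1536, 1536) input
# passed through four stride-2 convolutions ending with num_filters*8 channels.
hr_side = 1536
num_filters = 64

side_after_downsampling = hr_side // 2**4   # 1536 / 16 = 96
hw_flat = side_after_downsampling ** 2      # 96 * 96 = 9216
channels_out = num_filters * 8              # 512

in_features = hw_flat * channels_out        # 9216 * 512 = 4,718,592
print(in_features)                          # 4718592, matching shape (1, 4718592)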

Why is this happening?

python-3.x pytorch
1 Answer (score: 1)

Your linear layer is really quite large: it does in fact need at least 18 GB of memory. (Your estimate is off for two reasons: (1) a float32 occupies 4 bytes of memory, not 32, and (2) you did not multiply by the output size.)
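As a quick sanity check, the corrected arithmetic (a sketch using the numbers from the question) reproduces the figure in the error message:

# Corrected weight-memory estimate for nn.Linear(4718592, 1024)
in_features = 9216 * 512        # hw_flat * num_filters * 8 = 4,718,592
out_features = 1024             # num_fc
bytes_per_float32 = 4           # float32 is 4 bytes, not 32

weight_bytes = in_features * out_features * bytes_per_float32
print(weight_bytes / 2**30)     # ~18.0 GiB, matching "you tried to allocate 18GB"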

From the PyTorch documentation FAQs:

Don't use linear layers that are too large. A linear layer nn.Linear(m, n) uses O(n*m) memory: that is to say, the memory requirements of the weights scales quadratically with the number of features. It is very easy to blow through your memory this way (and remember that you will need at least twice the size of the weights, since you also need to store the gradients).
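Not part of the original answer, but one common way to keep the first fully connected layer of a discriminator like this manageable is to pool the spatial map down before flattening, for example with PyTorch's built-in AdaptiveAvgPool2d and Flatten (available in recent versions). The layer sizes below are purely illustrative:

import torch as tr

# Illustrative sketch (not from the original answer): shrink the spatial map
# before flattening so the first Linear layer stays small. Pooling 96x96 down
# to 6x6 reduces in_features from 9216*512 to 6*6*512 = 18,432.
head = tr.nn.Sequential(
    tr.nn.AdaptiveAvgPool2d((6, 6)),   # (N, 512, 96, 96) -> (N, 512, 6, 6)
    tr.nn.Flatten(),                   # (N, 512, 6, 6)   -> (N, 18432)
    tr.nn.Linear(512 * 6 * 6, 1024),   # ~75 MB of float32 weights instead of ~18 GiB
    tr.nn.LeakyReLU(),
    tr.nn.Linear(1024, 1),
    tr.nn.Sigmoid(),
)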
