Pytorch如何获得两次失落函数的梯度

问题描述 投票:4回答:2

这是我正在尝试实现的:

像往常一样,我们根据F(X)计算损失。但我们也定义了“对抗性损失”,这是基于F(X + e)的损失。 e定义为dF(X)/dX乘以某个常数。损失和对抗性损失都会因总损失而被反向传播。

在tensorflow中,这部分(获得dF(X)/dX)可以编码如下:

  grad, = tf.gradients( loss, X )
  grad = tf.stop_gradient(grad)
  e = constant * grad

下面是我的pytorch代码:

class DocReaderModel(object):
    def __init__(self, embedding=None, state_dict=None):
        self.train_loss = AverageMeter()
        self.embedding = embedding
        self.network = DNetwork(opt, embedding)
        self.optimizer = optim.SGD(parameters)

    def adversarial_loss(self, batch, loss, embedding, y):
        self.optimizer.zero_grad()
        loss.backward(retain_graph=True)
        grad = embedding.grad
        grad.detach_()

        perturb = F.normalize(grad, p=2)* 0.5
        self.optimizer.zero_grad()
        adv_embedding = embedding + perturb
        network_temp = DNetwork(self.opt, adv_embedding) # This is how to get F(X)
        network_temp.training = False
        network_temp.cuda()
        start, end, _ = network_temp(batch) # This is how to get F(X)
        del network_temp # I even deleted this instance.
        return F.cross_entropy(start, y[0]) + F.cross_entropy(end, y[1])

    def update(self, batch):
        self.network.train()
        start, end, pred = self.network(batch)
        loss = F.cross_entropy(start, y[0]) + F.cross_entropy(end, y[1])
        loss_adv = self.adversarial_loss(batch, loss, self.network.lexicon_encoder.embedding.weight, y) 
        loss_total = loss + loss_adv 

        self.optimizer.zero_grad()
        loss_total.backward()
        self.optimizer.step()

我有几个问题:

1)我用grad.detach_()替换了tf.stop_gradient。它是否正确?

2)我得到了"RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time."所以我在retain_graph=True添加了loss.backward。那个特定的错误消失了。但是现在我在几个纪元(RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu:58)之后出现了内存错误。我怀疑我不必要地保留图表。

有人能让我知道pytorch的最佳实践吗?任何提示/甚至短评都将受到高度赞赏。

tensorflow pytorch
2个回答
1
投票

我认为你正在尝试实现生成对抗网络(GAN),但是从代码中,我不理解并且无法遵循你想要实现的目标,因为GAN有一些缺失的部分可以工作。我可以看到有一个鉴别器网络模块,DNetwork但缺少发电机网络模块。

如果猜测,当你说'损失功能两次'时,我认为你的意思是你有一个用于鉴别器网的损失功能和另一个用于发电机网的功能。如果是这种情况,让我分享一下如何实现基本的GAN模型。

举个例子,我们来看看这个Wasserstein GAN Jupyter notebook

我将跳过不太重要的部分并放大重要部分:

  1. 首先,导入PyTorch库并进行设置 # Set up batch size, image size, and size of noise vector: bs, sz, nz = 64, 64, 100 # nz is the size of the latent z vector for creating some random noise later
  2. 构建鉴别器模块 class DCGAN_D(nn.Module): def __init__(self): ... truncated, the usual neural nets stuffs, layers, etc ... def forward(self, input): ... truncated, the usual neural nets stuffs, layers, etc ...
  3. 构建生成器模块 class DCGAN_G(nn.Module): def __init__(self): ... truncated, the usual neural nets stuffs, layers, etc ... def forward(self, input): ... truncated, the usual neural nets stuffs, layers, etc ...
  4. 把它们放在一起 netG = DCGAN_G().cuda() netD = DCGAN_D().cuda()
  5. 需要告知优化器要优化哪些变量。模块自动跟踪其变量。 optimizerD = optim.RMSprop(netD.parameters(), lr = 1e-4) optimizerG = optim.RMSprop(netG.parameters(), lr = 1e-4)
  6. Discriminator的一个前进步骤和一个后退步骤 这里,网络可以在后向传递期间计算梯度,取决于此功能的输入。所以,在我的情况下,我有3种类型的损失;发生器损失,鉴别者实际图像丢失,鉴别者假图像丢失。对于3种不同的净通行证,我可以获得三次失步函数的梯度。 def step_D(input, init_grad): # input can be from generator's generated image data or input image from dataset err = netD(input) err.backward(init_grad) # backward pass net to calculate gradient return err # loss
  7. 控制可训练参数[重要] 模型中的可训练参数是需要梯度的参数。 def make_trainable(net, val): for p in net.parameters(): p.requires_grad = val # note, i.e, this is later set to False below in netG update in the train loop. 在TensorFlow中,这部分可以编码如下: grad = tf.gradients(loss, X) grad = tf.stop_gradient(grad) 所以,我认为这将回答你的第一个问题,“我用grad.detach_()替换了tf.stop_gradient。这是正确的吗?”
  8. 火车循环

你可以在这里看到这里如何调用3种不同的损失函数。

    def train(niter, first=True):

        for epoch in range(niter):
            # Make iterable from PyTorch DataLoader
            data_iter = iter(dataloader)
            i = 0

            while i < n:
                ###########################
                # (1) Update D network
                ###########################
                make_trainable(netD, True)

                # train the discriminator d_iters times
                d_iters = 100

                j = 0

                while j < d_iters and i < n:
                    j += 1
                    i += 1

                    # clamp parameters to a cube
                    for p in netD.parameters():
                        p.data.clamp_(-0.01, 0.01)

                    data = next(data_iter)

                    ##### train with real #####
                    real_cpu, _ = data
                    real_cpu = real_cpu.cuda()
                    real = Variable( data[0].cuda() )
                    netD.zero_grad()

                    # Real image discriminator loss
                    errD_real = step_D(real, one)

                    ##### train with fake #####
                    fake = netG(create_noise(real.size()[0]))
                    input.data.resize_(real.size()).copy_(fake.data)

                    # Fake image discriminator loss
                    errD_fake = step_D(input, mone)

                    # Discriminator loss
                    errD = errD_real - errD_fake
                    optimizerD.step()

                ###########################
                # (2) Update G network
                ###########################
                make_trainable(netD, False)
                netG.zero_grad()

                # Generator loss
                errG = step_D(netG(create_noise(bs)), one)
                optimizerG.step()

                print('[%d/%d][%d/%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
                    % (epoch, niter, i, n,
                    errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))

“我得到了”RuntimeError:第二次尝试向下浏览图表......“

PyTorch有这种行为;为了减少GPU内存使用量,在.backward()调用期间,所有中间结果(如果你有保存的激活等)将在不再需要时被删除。因此,如果您尝试再次调用.backward(),则中间结果不存在,并且无法执行向后传递(并且您会看到错误)。

这取决于你想要做什么。您可以调用.backward(retain_graph=True)进行不会删除中间结果的向后传递,这样您就可以再次调用.backward()。除了最后一次向后调用之外的所有调用都应该有retain_graph=True选项。

有人能让我知道pytorch的最佳实践

正如您从上面的PyTorch代码以及PyTorch中正在尝试保留Pythonic的方式所做的那样,您可以了解PyTorch在那里的最佳实践。


0
投票

如果你想使用高阶导数(即导数的导数),请看一下create_graphbackward选项。

例如:

loss = get_loss()
loss.backward(create_graph=True)
loss_grad_penalty = loss + loss.grad
loss_grad_penalty.backward()
© www.soinside.com 2019 - 2024. All rights reserved.