How to resolve the "_cudnn_rnn_backward" derivative error for a WGAN-LSTM with gradient penalty


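I am trying to train a WGAN that uses LSTMs as both the critic and the generator to generate images from the MNIST dataset. Unfortunately, I keep running into this error message: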

NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented. Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example: 
with torch.backends.cudnn.flags(enabled=False):
   output = model(inputs)

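I have trouble understanding this error message, because I don't think I am performing a double-backward operation. Could you help me understand where this error comes from and how to resolve it?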

Here are the relevant parts of my implementation:

Critic

class LSTM_Critic(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTM_Critic, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(IMG_SIZE, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
    
    def forward(self, x, labels):
        # Initial hidden and cell states (zeros; they need no gradient)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)

        # Treat each image as a sequence of IMG_SIZE rows with IMG_SIZE features each;
        # use the actual batch size so a smaller final batch does not break the reshape
        x = x.reshape(x.size(0), IMG_SIZE, IMG_SIZE)
        out, hidden = self.lstm(x, (h0, c0))  # out: (batch_size, seq_length, hidden_size)

        # Feed the output of the last time step into the fully connected layer
        
        out = self.fc(out[:, -1, :])
        return out

Initialization

gen = LSTM_Generator(200, 100, num_layers, num_classes).to(device)
critic = LSTM_Critic(input_size, hidden_size, num_layers, num_classes).to(device)

initialize_weights(gen)
initialize_weights(critic)

# initialize optimizers
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

gen.train()
critic.train()

Training

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, labels) in enumerate(tqdm(loader)):
        real = real.to(device)
        cur_batch_size = real.shape[0]
        labels = labels.to(device)

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # equivalent to minimizing the negative of that
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise, labels)
            critic_real = critic(real, labels).reshape(-1)
            critic_fake = critic(fake, labels).reshape(-1)

            gp = gradient_penalty(critic, labels, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

Gradient penalty

def gradient_penalty(critic, labels, real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, labels)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten each sample's gradient and penalize its L2 norm for deviating from 1
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

The code fails at this line:

loss_critic.backward(retain_graph=True)
Answer

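AFAIK, WGAN-GP cannot be used with cuDNN RNNs because of a cuDNN limitation. The gradient penalty calls torch.autograd.grad with create_graph=True, and loss_critic.backward() then backpropagates through that gradient. That is the double backward the error message refers to, and the cuDNN RNN kernels do not support it.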
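To see where the double backward comes from, here is a minimal CPU-only sketch. The small feed-forward `critic` and all tensor sizes are stand-ins chosen for illustration, not the model from the question; the mechanics are the same for any critic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in critic; any differentiable module shows the same mechanics.
critic = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))

x = torch.randn(3, 4, requires_grad=True)
scores = critic(x)

# First backward pass: d(scores)/d(x). create_graph=True records this pass
# in the autograd graph so that it can itself be differentiated.
(grad,) = torch.autograd.grad(
    outputs=scores,
    inputs=x,
    grad_outputs=torch.ones_like(scores),
    create_graph=True,
)

penalty = ((grad.norm(2, dim=1) - 1) ** 2).mean()

# Second backward pass: differentiates through the first one (double backward).
penalty.backward()

first_layer_grad = critic[0].weight.grad
print(grad.grad_fn is not None, first_layer_grad is not None)
```

With a cuDNN-backed LSTM in place of the stand-in critic, the second call is the one that would need the unimplemented derivative of `_cudnn_rnn_backward`.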

One workaround is to write a JIT-fused RNN. See https://github.com/pytorch/pytorch/issues/5261#issuecomment-687330144
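As a rough illustration of that first workaround, the sketch below writes a single-layer LSTM out of basic tensor operations, so autograd can differentiate it twice without touching the cuDNN RNN kernels. `PlainLSTM` is a hypothetical helper written for this example, not code taken from the linked comment, and the JIT compilation step mentioned above is left out.

```python
import torch
import torch.nn as nn


class PlainLSTM(nn.Module):
    """Single-layer LSTM built from basic ops (hypothetical helper, batch-first input)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # One linear map each for the input and the hidden state, producing all 4 gates.
        self.ih = nn.Linear(input_size, 4 * hidden_size)
        self.hh = nn.Linear(hidden_size, 4 * hidden_size)

    def forward(self, x):
        batch, seq_len, _ = x.shape
        h = x.new_zeros(batch, self.hidden_size)
        c = x.new_zeros(batch, self.hidden_size)
        outputs = []
        for t in range(seq_len):
            gates = self.ih(x[:, t, :]) + self.hh(h)
            i, f, g, o = gates.chunk(4, dim=1)
            c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
            h = torch.sigmoid(o) * torch.tanh(c)
            outputs.append(h)
        return torch.stack(outputs, dim=1), (h, c)


torch.manual_seed(0)
lstm = PlainLSTM(input_size=4, hidden_size=8)
head = nn.Linear(8, 1)

x = torch.randn(2, 5, 4, requires_grad=True)  # (batch, seq_len, features)
out, _ = lstm(x)
scores = head(out[:, -1, :])

# Gradient-penalty style computation: grad of scores w.r.t. input, then backward again.
(grad,) = torch.autograd.grad(scores, x, torch.ones_like(scores), create_graph=True)
penalty = ((grad.reshape(grad.shape[0], -1).norm(2, dim=1) - 1) ** 2).mean()
penalty.backward()  # double backward through the recurrent layer
```

The Python loop over time steps is slower than a fused kernel, which is presumably why the answer suggests JIT fusion on top of it.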

Another solution is to disable cuDNN during the forward pass, as the error message suggests, but it is CPU-heavy and very slow.

with torch.backends.cudnn.flags(enabled=False):
    output = model(inputs)
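Only the forward pass on the interpolated images is differentiated twice, so one way to limit the slowdown might be to disable cuDNN for that pass alone and leave the critic's passes over real and fake batches untouched. The sketch below assumes a critic that takes a single input; the function name `gradient_penalty_no_cudnn`, the dropped `labels` argument, the detached `fake`, and the small stand-in critic are all choices made for this example, not the question's code.

```python
import torch
import torch.nn as nn


def gradient_penalty_no_cudnn(critic, real, fake):
    """WGAN-GP penalty; `critic` is any callable mapping a batch to scores."""
    # One interpolation weight per sample, broadcast over the remaining dimensions.
    alpha = torch.rand(real.shape[0], *([1] * (real.dim() - 1)), device=real.device)
    # Assumption: the critic update does not need gradients into the generator,
    # so `fake` is detached and the interpolation becomes a leaf tensor.
    interpolated = (real * alpha + fake.detach() * (1 - alpha)).requires_grad_(True)

    # Disable cuDNN only for the forward pass that will be differentiated twice.
    with torch.backends.cudnn.flags(enabled=False):
        mixed_scores = critic(interpolated)

    (gradient,) = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
    )
    gradient = gradient.reshape(gradient.shape[0], -1)
    return ((gradient.norm(2, dim=1) - 1) ** 2).mean()


# Usage with a small stand-in critic on CPU (hypothetical sizes).
torch.manual_seed(0)
critic = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 16), nn.Tanh(), nn.Linear(16, 1))
real = torch.randn(4, 1, 28, 28)
fake = torch.randn(4, 1, 28, 28)
gp = gradient_penalty_no_cudnn(critic, real, fake)
gp.backward()
```

Whether this recovers enough speed depends on the model and hardware; it only reduces how much of the training step runs without cuDNN.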
