我正在尝试训练一个 WGAN,用 LSTM 分别作为判别器(critic)和生成器,在 MNIST 数据集上生成图像。不幸的是,我不断遇到如下错误消息:
NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented. Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example:
with torch.backends.cudnn.flags(enabled=False):
    output = model(inputs)
我很难理解这条错误消息,因为我认为自己并没有执行二次反向传播(double backwards)操作。
您能帮我了解此错误消息的来源以及如何解决它吗?
以下是我的实现的相关部分:
判别器(Critic)
class LSTM_Critic(nn.Module):
    """LSTM-based WGAN critic that scores an image read as a sequence of rows.

    The input is reshaped to ``(batch, IMG_SIZE, IMG_SIZE)`` and fed to a
    batch-first LSTM; the output at the last time step is mapped to a single
    score per sample by a linear layer.

    NOTE: when this module runs on the cuDNN backend, differentiating through
    its backward pass (as a gradient penalty does) raises
    ``NotImplementedError``; the forward pass used for that purpose must be
    executed inside ``torch.backends.cudnn.flags(enabled=False)``.

    Args:
        input_size: Accepted for interface compatibility but not used; the
            LSTM input width comes from the module-level ``IMG_SIZE``.
            NOTE(review): confirm whether the two are meant to be equal.
        hidden_size: Number of features in the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        num_classes: Accepted for interface compatibility but not used.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(IMG_SIZE, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x, labels):
        """Return one critic score per sample, shape ``(batch, 1)``.

        Args:
            x: Batch of images; any shape that reshapes to
                ``(batch, IMG_SIZE, IMG_SIZE)`` with the batch on dim 0.
            labels: Unused; kept so the critic and generator share a call
                signature.
        """
        # Use the batch size of the tensor actually received rather than the
        # global BATCH_SIZE, so a smaller final batch does not break reshape.
        x = x.reshape(x.size(0), IMG_SIZE, IMG_SIZE)

        # With no (h0, c0) given, nn.LSTM starts from zero states that match
        # the input's device and dtype -- the same values the explicit,
        # detached zero tensors provided before.
        # out: (batch_size, seq_length, hidden_size)
        out, _ = self.lstm(x)

        # Only the last time step feeds the scoring layer.
        return self.fc(out[:, -1, :])
初始化
# Build both networks on the target device, generator first.
gen = LSTM_Generator(200, 100, num_layers, num_classes).to(device)
critic = LSTM_Critic(input_size, hidden_size, num_layers, num_classes).to(device)

# Apply the custom weight initialisation in the same order the networks
# were created.
initialize_weights(gen)
initialize_weights(critic)

# Put both networks in training mode.
gen.train()
critic.train()

# One Adam optimizer per network, sharing the learning rate and betas.
opt_gen = optim.Adam(
    gen.parameters(),
    lr=LEARNING_RATE,
    betas=(0.0, 0.9),
)
opt_critic = optim.Adam(
    critic.parameters(),
    lr=LEARNING_RATE,
    betas=(0.0, 0.9),
)
训练
for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, labels) in enumerate(tqdm(loader)):
        real, labels = real.to(device), labels.to(device)
        cur_batch_size = real.shape[0]

        # Critic update: maximise E[critic(real)] - E[critic(fake)], which is
        # done by minimising its negative plus the weighted gradient penalty.
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise, labels)

            critic_real = critic(real, labels).reshape(-1)
            critic_fake = critic(fake, labels).reshape(-1)
            gp = gradient_penalty(critic, labels, real, fake, device=device)

            score_gap = critic_real.mean() - critic_fake.mean()
            loss_critic = -score_gap + LAMBDA_GP * gp

            critic.zero_grad()
            # The graph is retained after this backward pass, as in the
            # original code.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()
梯度惩罚
def gradient_penalty(critic, labels, real, fake, device="cpu"):
    """Return the WGAN-GP gradient penalty for one batch.

    Scores random per-sample interpolations between ``real`` and ``fake`` and
    penalises the squared deviation of the critic's input-gradient L2 norm
    from 1.

    Args:
        critic: Callable ``critic(images, labels)`` returning the scores.
        labels: Passed through to ``critic`` unchanged.
        real: Batch of real images with shape ``(N, C, H, W)``.
        fake: Batch of generated images, broadcast-compatible with ``real``.
        device: Device the interpolation coefficients are moved to.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)``, built with
        ``create_graph=True`` so it can be back-propagated through.

    Raises:
        ValueError: If ``real`` is not 4-dimensional.
    """
    # Unpacking four dimensions keeps the original 4-D requirement on `real`.
    batch_size, _, _, _ = real.shape

    # One mixing coefficient per sample; broadcasting spreads it over
    # (C, H, W), so no explicit repeat is needed.
    alpha = torch.rand((batch_size, 1, 1, 1)).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # autograd.grad needs an input that tracks gradients; this is not the
    # case when both `real` and `fake` arrive detached.
    if not interpolated_images.requires_grad:
        interpolated_images.requires_grad_(True)

    # The penalty is later back-propagated through the gradient computed
    # below, i.e. a double backward through the critic. cuDNN RNNs do not
    # implement that, so this one forward pass runs with cuDNN disabled.
    with torch.backends.cudnn.flags(enabled=False):
        mixed_scores = critic(interpolated_images, labels)

    # Gradient of the scores with respect to the interpolated images.
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Per-sample L2 norm over all non-batch dimensions.
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)
代码在 `loss_critic.backward(retain_graph=True)` 这一行失败。
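据我所知(AFAIK),由于 cuDNN 的限制,WGAN-GP 无法直接与使用 cuDNN 的 RNN 一起使用:梯度惩罚项是用 `create_graph=True` 求出的梯度构造的,对它再调用 `backward()` 就相当于对判别器做二次反向传播(double backwards),而 cuDNN RNN 不支持这一操作。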
一种解决方法是编写 JIT 融合的 RNN,请参阅 https://github.com/pytorch/pytorch/issues/5261#issuecomment-687330144
另一个解决方案是按照错误提示,在前向传播时临时禁用 cuDNN,但这样 CPU 占用会很高,而且速度非常慢:
# Disable cuDNN only around the forward pass whose graph will be
# differentiated twice; code outside the `with` block is unaffected.
with torch.backends.cudnn.flags(enabled=False):
    output = model(inputs)