I've tried calling `.detach()` on some of the output tensors, but that doesn't seem to work. For reference, I'm running into this error while training a GAN. Here is the code:
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# Assumed import: chamfer_distance as in PyTorch3D (it returns a tuple,
# which matches the `[0]` indexing below). Constants, USuRPNet, Generator,
# and Discriminator are defined elsewhere in my project.
from pytorch3d.loss import chamfer_distance


def main():
    # Encoder In:  BxDxN
    # Encoder Out: BxNxK
    # Decoder In:  BxNxK
    # Decoder Out: BxDxN
    const = Constants()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Autoencoder training setup
    AE = USuRPNet(point_per_cloud=43)
    # train_loader = DataLoader(PC_Data("ModelNet10"), batch_size=128)
    print("The autoencoder will be running on", device, "device")
    AE.to(device)
    sample = torch.randn(2000, 43, 3)
    AE_train_loader = DataLoader(sample, batch_size=const.batch_size)
    AE_loss = chamfer_distance
    AE_optimizer = torch.optim.Adam(AE.parameters(), lr=0.001, weight_decay=0.0001)
    for epoch in range(const.num_epochs):
        run_loss = 0.0
        for i, pc in enumerate(AE_train_loader):
            # Train the autoencoder
            pc = pc.cuda()
            AE_optimizer.zero_grad()
            recon, latent_space = AE(pc)
            loss = AE_loss(recon, pc)[0]
            loss.backward()
            AE_optimizer.step()
            run_loss += loss.item()
        print(f'epoch: {epoch+1} ~ loss: {run_loss}')
        run_loss = 0.0

    # Generative Adversarial Network training setup
    gen = Generator()
    disc = Discriminator()
    print("The GAN will be running on", device, "device")
    gen.to(device)
    disc.to(device)
    _, latent_space = AE(AE_train_loader.dataset.cuda())
    GAN_train_loader = DataLoader(latent_space, batch_size=const.batch_size)
    GAN_loss = nn.BCELoss()
    disc_optimizer = torch.optim.Adam(disc.parameters(), lr=0.001)
    gen_optimizer = torch.optim.Adam(gen.parameters(), lr=0.001)
    for epoch in range(const.num_epochs):  # <- the error is raised in this loop
        run_loss = 0.0
        for i, batch in enumerate(GAN_train_loader):
            # Train the discriminator
            real_sample_labels = torch.ones((const.batch_size, 1)).cuda()
            real_samples = batch
            gen_sample_labels = torch.zeros((const.batch_size, 1)).cuda()
            gen_samples = gen(torch.randn((const.batch_size, 32)).cuda())
            all_samples = torch.cat((real_samples, gen_samples))
            all_samples_labels = torch.cat((real_sample_labels, gen_sample_labels))
            disc.zero_grad()
            disc_out = disc(all_samples.cuda())
            disc_loss = GAN_loss(disc_out, all_samples_labels)
            disc_loss.requires_grad = True
            disc_loss.backward()
            disc_optimizer.step()
            # Train the generator
            rand_vecs = torch.rand((const.batch_size, 32))
            gen.zero_grad()
            gen_samples = gen(rand_vecs.cuda())
            out_disc_gen = disc(gen_samples.cuda())
            gen_loss = GAN_loss(out_disc_gen, real_sample_labels)
            gen_loss.backward()
            gen_optimizer.step()
            run_loss += gen_loss + disc_loss
        print(f'epoch: {epoch+1} ~ loss: {run_loss}')
        run_loss = 0.0


if __name__ == "__main__":
    main()
```
The error comes from the second big loop, the one that trains the GAN. The answers to some other questions on this site were either to detach the graph or to specify `retain_graph=True` (but I've heard that increases the runtime).
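If I understand those answers correctly, the failure mode is backpropagating through the same graph a second time. Here is a minimal standalone sketch of what I think is happening, and where the suggested `.detach()` would cut the graph. The two `Linear` layers are hypothetical stand-ins, not my real models:

```python
import torch

# Toy version of the pattern: an upstream model produces features once,
# and a downstream model is trained on them in a loop.
upstream = torch.nn.Linear(8, 4)
downstream = torch.nn.Linear(4, 1)
opt = torch.optim.Adam(downstream.parameters(), lr=1e-3)

features = upstream(torch.randn(16, 8))  # still attached to upstream's graph
# features = features.detach()           # <- the suggested fix: cut the graph here

for step in range(2):
    opt.zero_grad()
    loss = downstream(features).pow(2).mean()
    loss.backward()  # without detach(), step 2 backprops through upstream's
    opt.step()       #    already-freed graph and raises the runtime error
```

In my code, `features` would correspond to the `latent_space` returned by `AE(...)`, which stays attached to the autoencoder's graph when I build `GAN_train_loader` from it; that's the tensor I tried calling `.detach()` on.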