最近我一直在尝试重现https://github.com/zexuanqiu/CIBHash的结果,但是每次评估后都会遇到损失爆炸。我正在使用官方网站上的 cifar-10 数据集。
举个例子,
python main.py cifar16 --train --dataset cifar10 --encode_length 16 --cuda
运行代码(默认设置validate_frequency=20
),代码会在epoch=20
之后评估自身的性能并继续训练,epoch=21
出现损失爆炸问题。python main.py cifar16 --train --dataset cifar10 --encode_length 16 --cuda --validate_frequency=3
运行代码,设置validate_frequency=3
,损失爆炸发生在epoch=4
,稳定。这是它的
run_training_session
功能:
def run_training_session(self, run_num, logger):
self.train()
# Scramble hyperparameters if number of runs is greater than 1.
if self.hparams.num_runs > 1:
logger.log('RANDOM RUN: %d/%d' % (run_num, self.hparams.num_runs))
for hparam, values in self.get_hparams_grid().items():
assert hasattr(self.hparams, hparam)
self.hparams.__dict__[hparam] = random.choice(values)
random.seed(self.hparams.seed)
torch.manual_seed(self.hparams.seed)
self.define_parameters()
# if encode_length is 16, then al least 80 epochs!
if self.hparams.encode_length == 16:
self.hparams.epochs = max(80, self.hparams.epochs)
logger.log('hparams: %s' % self.flag_hparams())
device = torch.device('cuda' if self.hparams.cuda else 'cpu')
self.to(device)
optimizer = self.configure_optimizers()
train_loader, val_loader, _, database_loader = self.data.get_loaders(
self.hparams.batch_size, self.hparams.num_workers,
shuffle_train=True, get_test=False)
best_val_perf = float('-inf')
best_state_dict = None
bad_epochs = 0
try:
for epoch in range(1, self.hparams.epochs + 1):
forward_sum = {}
num_steps = 0
for batch_num, batch in enumerate(train_loader):
optimizer.zero_grad()
imgi, imgj, _ = batch
imgi = imgi.to(device)
imgj = imgj.to(device)
forward = self.forward(imgi, imgj, device)
for key in forward:
if key in forward_sum:
forward_sum[key] += forward[key]
else:
forward_sum[key] = forward[key]
num_steps += 1
if math.isnan(forward_sum['loss']):
logger.log('Stopping epoch because loss is NaN')
break
forward['loss'].backward()
optimizer.step()
if math.isnan(forward_sum['loss']):
logger.log('Stopping training session because loss is NaN')
break
logger.log('End of epoch {:3d}'.format(epoch), False)
logger.log(' '.join([' | {:s} {:8.4f}'.format(
key, forward_sum[key] / num_steps)
for key in forward_sum]), True)
if epoch % self.hparams.validate_frequency == 0:
print('evaluating...')
val_perf = self.evaluate(database_loader, val_loader, self.data.topK, device)
logger.log(' | val perf {:8.4f}'.format(val_perf), False)
if val_perf > best_val_perf:
best_val_perf = val_perf
bad_epochs = 0
logger.log('\t\t*Best model so far, deep copying*')
best_state_dict = deepcopy(self.state_dict())
else:
bad_epochs += 1
logger.log('\t\tBad epoch %d' % bad_epochs)
if bad_epochs > self.hparams.num_bad_epochs:
break
except KeyboardInterrupt:
logger.log('-' * 89)
logger.log('Exiting from training early')
return best_state_dict, best_val_perf
这是CIBHash模型的
forward
函数:
def forward(self, imgi, imgj, device):
imgi = self.vgg.features(imgi)
imgi = imgi.view(imgi.size(0), -1)
imgi = self.vgg.classifier(imgi)
prob_i = torch.sigmoid(self.encoder(imgi))
z_i = hash_layer(prob_i - torch.empty_like(prob_i).uniform_().to(prob_i.device))
imgj = self.vgg.features(imgj)
imgj = imgj.view(imgj.size(0), -1)
imgj = self.vgg.classifier(imgj)
prob_j = torch.sigmoid(self.encoder(imgj))
z_j = hash_layer(prob_j - torch.empty_like(prob_j).uniform_().to(prob_j.device))
kl_loss = (self.compute_kl(prob_i, prob_j) + self.compute_kl(prob_j, prob_i)) / 2
contra_loss = self.criterion(z_i, z_j, device)
loss = contra_loss + self.hparams.weight * kl_loss
return {'loss': loss, 'contra_loss': contra_loss, 'kl_loss': kl_loss}
我尝试按照
https://github.com/zexuanqiu/CIBHash/issues/6中的说明替换为
z_i
和z_j
,但是,它未能防止NaN问题。
我尝试过
gradient_gripping
方法但没有用。
根据作者的回复,他们在训练模型时并没有遇到任何 NaN 问题。 (https://github.com/zexuanqiu/CIBHash/issues/7)
我希望代码能够完成训练而不会出现 NaN 问题。谁能好心告诉我是什么因素可能导致这个问题?或者是否有任何潜在的解决方案来解决 NaN 损失问题?
原来这个问题是由于 GPU 内存不足以及之前某些 CUDA 版本中的某种未知 bug 导致的?
我已经尝试过:
以下是我收集或尝试过的一些关于不同 GPU 上性能差异的信息:
我的机器:
2080Ti机:
Colab环境:
A40 GPU 环境:
总之,我怀疑旧版本的 CUDA 中没有正确处理 MemoryError 问题,如 https://github.com/ultralytics/ultralytics/issues/5294 中所报告的那样: CUDA 错误(例如“内存不足”)可能会出现也会导致 NaN 结果。
我怀疑旧版本的CUDA可能缺乏适当的错误处理机制来处理由于内存不足而导致的NaN,但是,我对此问题的证据太少了。如果有人知道这个问题的更具体细节,请联系我!