LeNet5 and Self-Organizing Map - RuntimeError: Trying to backward through the graph a second time - PyTorch

I am training a LeNet-5 CNN jointly with a Self-Organizing Map (SOM) on MNIST data. The training code (abridged for brevity) is:

import numpy as np
import torch

# SOM (flattened) weights-
# m = 40, n = 40, dim = 84 (LeNet's output shape/dim)
m, n, dim = 40, 40, 84
centroids = torch.randn(m * n, dim, device = device, dtype = torch.float32)

locs = [np.array([i, j]) for i in range(m) for j in range(n)]
locations = torch.LongTensor(np.asarray(locs)).to(device)
del locs

def get_bmu_distance_squares(bmu_loc):
    bmu_distance_squares = torch.sum(
        input = torch.square(locations.float() - bmu_loc),
        dim = 1
    )
    return bmu_distance_squares

distance_mat = torch.stack([get_bmu_distance_squares(loc) for loc in locations])

centroids = centroids.to(device)

num_epochs = 50
qe_train = list()

step = 1


for epoch in range(1, num_epochs + 1):
    qe_epoch = 0.0
    for x, y in train_loader:
        x = x.to(device)
        z = model(x)

        # SOM training code:

        batch_size = len(z)

        # Compute distances from batch to (all SOM units) centroids-
        dists = torch.cdist(x1 = z, x2 = centroids, p = p_norm)

        # Find closest (BMU) and retrieve the gaussian correlation matrix
        # for each point in the batch
        # bmu_loc is BS, num points-
        mindist, bmu_index = torch.min(dists, -1)
        # print(f"quantization error = {mindist.mean():.4f}")

        bmu_loc = locations[bmu_index]


        # Compute the SOM weight update:

        # Update LR
        # It is a matrix of shape (BS, centroids) or, (BS, mxn) and tells
        # for each input how much it will affect each (SOM unit) centroid-
        bmu_distance_squares = distance_mat[bmu_index]

        # Get current lr and neighbourhood radius for current step-
        decay_val = scheduler(it = step, tot =  int(len(train_loader) * num_epochs))
        curr_alpha = (alpha * decay_val).to(device)
        curr_sigma = (sigma * decay_val).to(device)

        # Compute Gaussian neighbourhood function-
        neighborhood_func = torch.exp(torch.neg(torch.div(bmu_distance_squares, ((2 * torch.square(curr_sigma)) + 1e-5))))

        expanded_z = z.unsqueeze(dim = 1).expand(-1, grid_size, -1)
        expanded_weights = centroids.unsqueeze(0).expand((batch_size, -1, -1))

        delta = expanded_z - expanded_weights
        lr_multiplier = curr_alpha * neighborhood_func

        delta.mul_(lr_multiplier.reshape(*lr_multiplier.size(), 1).expand_as(delta))
        delta = torch.mean(delta, dim = 0)
        new_weights = torch.add(centroids, delta)
        centroids = new_weights

        # return bmu_loc, torch.mean(mindist)

        # Compute quantization error loss-
        qe_loss = torch.mean(mindist)
        qe_epoch += qe_loss.item()

        # Empty accumulated gradients-
        optimizer.zero_grad()

        # Perform backprop-
        qe_loss.backward()

        # Update model trainable params-
        optimizer.step()
         
        step += 1


    qe_train.append(qe_epoch / len(train_loader))
    print(f"\nepoch = {epoch}, QE = {qe_epoch / len(train_loader):.4f}"
        f" & SOM wts L2-norm = {torch.norm(input = centroids, p = 2).item():.4f}"
    )
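For reference, the SOM step in the loop above can be read off the code as the following update. The notation is introduced here for illustration: B is the batch size, z_b the model output for sample b, w_k the centroid of SOM unit k, ℓ_k its grid location, p = p_norm, and s(t) the value returned by scheduler at step t (the scheduler itself is not shown in the question).

$$
c(b) = \arg\min_k \lVert z_b - w_k \rVert_p, \qquad
h_{bk} = \exp\!\left(-\frac{\lVert \ell_{c(b)} - \ell_k \rVert_2^2}{2\sigma_t^2 + 10^{-5}}\right)
$$

$$
w_k \leftarrow w_k + \frac{\alpha_t}{B} \sum_{b=1}^{B} h_{bk}\,(z_b - w_k), \qquad
\alpha_t = \alpha\, s(t), \quad \sigma_t = \sigma\, s(t)
$$

$$
\mathcal{L}_{\mathrm{QE}} = \frac{1}{B} \sum_{b=1}^{B} \min_k \lVert z_b - w_k \rVert_p
$$

The centroid update depends on z_b, so unless it is computed without autograd tracking, the new centroids carry the computation graph of the current forward pass through model into the next iteration.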

When I try to run this code, I get the error:

Line 252: qe_loss.backward()

Traceback (most recent call last):
  File "c:\some_dir\som_lenet5.py", line 252, in <module>
    qe_loss.backward()
  File "c:\pytorch_venv\pytorch_cuda\lib\site-packages\torch\_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "c:\pytorch_venv\pytorch_cuda\lib\site-packages\torch\autograd\__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Tags: python, pytorch, conv-neural-network

Answer

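The problem is that centroids keeps its gradient history across iterations. This means that as early as the second iteration, your computation of dists involves centroids as updated at the end of the first iteration, and that update was computed from z, the output of model on the first batch. When you backpropagate through this tensor, the backward pass propagates back into the first iteration, whose saved intermediate values were already freed by the first call to qe_loss.backward(); hence the error. One way to prevent the gradient of centroids at iteration n from propagating through iteration n-1 all the way back to iteration 1 is to detach centroids before updating its value.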
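A minimal sketch of the failure and of the detach fix, assuming toy stand-ins that are not from the question: a hypothetical nn.Linear(4, 3) in place of LeNet-5, 5 SOM units instead of 40 × 40, random inputs instead of MNIST, and a simplified centroid update without the neighbourhood function. The only difference between the two runs is whether the new centroids are detached (centroids = new_weights.detach()) before the next iteration uses them:

```python
import torch

def run_joint_training(detach_update, num_steps=3, seed=0):
    """Toy version of the question's loop: the model is trained on the SOM
    quantization error while the SOM centroids are updated from the model outputs."""
    torch.manual_seed(seed)
    model = torch.nn.Linear(4, 3)            # hypothetical stand-in for LeNet-5
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    centroids = torch.randn(5, 3)            # 5 SOM units; plain tensor, no requires_grad
    alpha = 0.1                              # hypothetical fixed SOM learning rate

    for _ in range(num_steps):
        x = torch.randn(8, 4)
        z = model(x)

        # Quantization error: distance from each output to its best-matching unit.
        dists = torch.cdist(z, centroids, p=2)
        mindist, _ = torch.min(dists, dim=-1)

        # Simplified SOM update (no neighbourhood term): pull every unit toward the batch mean.
        new_weights = centroids + alpha * (z.mean(dim=0) - centroids)
        # Without .detach(), new_weights keeps a reference to this step's forward graph.
        centroids = new_weights.detach() if detach_update else new_weights

        qe_loss = mindist.mean()
        optimizer.zero_grad()
        qe_loss.backward()   # frees this step's saved tensors (retain_graph defaults to False)
        optimizer.step()

    return centroids

# With the fix, training runs and the centroids carry no autograd history.
fixed = run_joint_training(detach_update=True)
print(fixed.requires_grad)  # False

# Without it, the second backward() reaches the first step's already-freed graph.
try:
    run_joint_training(detach_update=False)
except RuntimeError as err:
    print("RuntimeError:", str(err).split(".")[0])
```

Detaching is not the only option: computing the SOM update inside a torch.no_grad() block, or using z.detach() when forming the update, has the same effect. In each case the model still receives gradients through dists in the current iteration, but nothing flows back into earlier batches. Passing retain_graph=True, as the error message suggests, is unlikely to help here: the retained graph would grow with every batch, and it references model weights that optimizer.step() modifies in place.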
