PyTorch 中的平方距离计算 - 避免 for 循环

问题描述 投票:0回答:1

我有一个代码,其中大小为 (20, 20) 的 2D 网格在展平 (400) 时需要根据 2D 网格上的所有其他索引计算距离。目前,我正在使用 for 循环来存储它。

# best locations of indices existing on 2D grid-
best_loc.shape
# torch.Size([1024, 2])

# Specify 2D grid size-
m = 20
n = 20

locs = [np.array([i, j]) for i in range(m) for j in range(n)]
locations = torch.LongTensor(np.array(locs))

locations.shape
# torch.Size([400, 2])

def get_distance_squares(best_loc):
    '''
    Compute squared distances between 'best_loc' and 'locations'
    '''
    best_loc = best_loc.unsqueeze(0).expand_as(locations).float()
    best_distance_squares = torch.sum(torch.pow(locations.float() - best_loc, 2), 1)
    return best_distance_squares
     
bmu_distance_squares = list()

for loc in bmu_loc:
    bmu_distance_squares.append(get_distance_squares(loc))
best_distance_squares = torch.stack(best_distance_squares)

best_distance_squares.shape
# torch.Size([1024, 400])

如何避免 for 循环来获取“bmu_distance_squares”(成对平方距离矩阵)?

python-3.x numpy pytorch
1个回答
0
投票

您可以简单地使用

torch.cdist

bmu_distance_squares = torch.cdist(best_loc, locations) ** 2
© www.soinside.com 2019 - 2024. All rights reserved.