两幅图像特征之间的每像素余弦相似度

问题描述 投票:0回答:1

我们有两个 3D 矩阵 HxWxC,其中 H、W 是 2D 图像的尺寸,C 是每像素特征。我们想要计算第一幅图像的每个像素特征与第二幅图像的每个像素特征的arg-最大余弦相似度,或者等效地,我们需要一个HxW数组来存储第二幅图像具有最大余弦相似度的像素坐标第一张图像中的每个像素。

对于足够小的 H、W 和 C,可以轻松且非常有效地计算

torch.nn.CosineSimilarity() 
向量化后,您可以获得一个 HxWxHxW 矩阵,然后计算最后两行的 argmax。

如果 H 和 W 更大(具体来说,我们有一个 800x800 的图像)并且特征维度 C 是 1500,则先前的解决方案内存效率不高。即使尝试对每一列进行向量化并仅使用一个 for 循环也需要足够的内存。

那么,问题是,是否有一种有效的方法在时间和内存上计算 GPU 中两个图像的最大每像素余弦相似度的位置?

谢谢!

python image pytorch vectorization cosine-similarity
1个回答
0
投票

为什么我建议首先压平你的原始图像,得到一个(800*800, 1500)向量

v

如果这适合内存,我会先求标量积,然后求 v 的一部分与其自身的 argmax,除了使用的 v 块之外,不会消耗额外的内存(尽管它甚至可能是对标量积中使用的 v 的引用,所以不会使用更多内存)。那么您只需将 argmax 索引存储在 (800*800) 形状向量中。

v: torch.Tensor = ...
v = torch.flatten(v)

chunk_size: int = 50 # arbitrary chunk size

argmax_flat_indices = []
for cursor in range(0, len(v), chunk_size):
    # Here compute the cosine sim between v and (v[cursor: cursor + chunk_size)
    # Do also the argmax along dim 1
    chunk_argmax_indices
    argmax_flat_indices.append(chunk_argmax_indices)
    

argmax_flat_indices = torch.cat(argmax_flat_indices)

#Then compute back original indices
...

这适合你吗?

© www.soinside.com 2019 - 2024. All rights reserved.