PyTorch:从一批图像中矢量化补丁选择

问题描述 投票:0回答:2

假设我有一批图像作为张量,例如:

images = torch.zeros(64, 3, 1024, 1024)

现在,我想从每个图像中选择一个补丁。所有补丁的大小相同,但批次中每个图像的起始位置不同。

size_x = 100
size_y = 100
start_x = torch.zeros(64)
start_y = torch.zeros(64)

我可以达到预期的结果:

result = []
for i in range(arr.shape[0]):
    result.append(arr[i, :, start_x[i]:start_x[i]+size_x, start_y[i]:start_y[i]+size_y])
result = torch.stack(result, dim=0)

问题是——是否有可能更快地完成同样的事情,而不需要循环?也许有某种形式的高级索引,或者 PyTorch 函数可以做到这一点?

python pytorch vectorization torchvision
2个回答
2
投票

您可以使用

torch.take
来摆脱 for 循环。但首先,应该使用此函数创建一个索引数组

def convert_inds(img_a,img_b,patch_a,patch_b,start_x,start_y):
    
    all_patches = np.zeros((len(start_x),3,patch_a,patch_b))
    
    patch_src = np.zeros((patch_a,patch_b))
    inds_src = np.arange(patch_b)
    patch_src[:] = inds_src
    for ind,info in enumerate(zip(start_x,start_y)):
        
        x,y = info
        if x + patch_a + 1 > img_a: return False
        if y + patch_b + 1 > img_b: return False
        start_ind = img_b * x + y
        end_ind = img_b * (x + patch_a -1) + y
        col_src = np.linspace(start_ind,end_ind,patch_b)[:,None]
        all_patches[ind,:] = patch_src + col_src
        
    return all_patches.astype(np.int)

如您所见,此函数本质上为您要切片的每个补丁创建索引。有了这个功能,问题就可以轻松解决了

size_x = 100
size_y = 100
start_x = torch.zeros(64)
start_y = torch.zeros(64)

images = torch.zeros(64, 3, 1024, 1024)
selected_inds = convert_inds(1024,1024,100,100,start_x,start_y)
selected_inds = torch.tensor(selected_inds)
res = torch.take(images,selected_inds)

更新

OP的观察是正确的,上面的方法并不比天真的方法更快。为了避免每次都建立索引,这里有另一种基于

unfold

的解决方案

首先,构建所有可能的补丁的张量

# create all possible patches
all_patches = images.unfold(2,size_x,1).unfold(3,size_y,1)

然后,从

all_patches

中切出所需的补丁
img_ind = torch.arange(images.shape[0])
selected_patches = all_patches[img_ind,:,start_x,start_y,:,:]

0
投票

我通过将最后两个(空间)维度展平为一个维度来解决这个问题,然后使用

gather()
。这看起来比使用
unfold()
更快、更高效。

假设

images
是(单通道)图像的 BS x W x H 数组,
posx
posy
是大小为 BS 的向量,表示每个图像所需的 3x3 块的中心。最后,假设中心距离边缘足够远,您无需担心填充。

idcs_flat = ( ( posx[:, None].repeat((1,9)) + torch.tensor([-1, 0, 1, -1, 0, 1, -1, 0, 1]) ) * S 
                    + posy[:, None].repeat((1,9)) + torch.tensor([-1, -1, -1, 0, 0, 0, 1, 1, 1]) )
patches = torch.gather(images.reshape(BS , -1), 1, idcs_flat)  # Shape: BS x 9 

检查尺寸是否符合您的方向。此外,生成的补丁也是大小为 9 的线性,因此您需要将它们重塑为 3x3。

© www.soinside.com 2019 - 2024. All rights reserved.