假设我在batches
和channels
上有一组坐标numpy数组。
每个数组都有:
lim
项,范围在0到0:2
之间3
上的坐标关联的第四索引上的值。graph
,以使coords
中包含的每个唯一坐标都表示为该索引处所有坐标的总和?即:
coords = np.random.randint(0,lim,(batches, channels,N, 4))
graph = foo(coords)
graph.shape = (batches, channels, lim, lim, lim)
在一种简单的情况下,可以执行以下操作:
def foo(coords): graph = np.zeros(batches, channels, lim, lim, lim) for b_i in range(batches): for c_i in range(channels): for n_i in range(N): elem = coords[b_i, c_i, n_i] if elem[3] > 0: graph[b_i, c_i, elem[0], elem[1], elem[2]]+=elem[3] return graph
但是,我正在寻找一种涉及广播的解决方案,因此我可以将该技术移植到PyTorch,在该广播中,速度是强制性的。
mask = np.argwhere(coords[:,:,:,3]>0)
graph = np.zeros((batches, channels, lim, lim, lim))
idx = coords[tuple(mask.T)]
graph[tuple(np.hstack((mask[:,0:2], idx[:,:3])).T)] += idx[:,3]