我有一批形状为(N, C, H, W)
的图像,其中N是图像数,C - 通道数,H,W - 高度和宽度。
每个图像都有2个通道,其中一些像素值为[-1 , -1]
。
如何在不使用for循环的情况下在批处理中找到这些像素的位置,因为它非常慢。
使用numpy.where
:
# creating test data
test = np.zeros((5,2,3,3))
test[3,:,2,1] = [-1.,-1.]
value = -np.ones((1.,2.,1.,1.)) # this is the value you are looking for
np.where(test == value)
# this returns: (array([3, 3], dtype=int64),
# array([0, 1], dtype=int64),
# array([2, 2], dtype=int64),
# array([1, 1], dtype=int64))
编辑:要获得相应的掩码,请不要使用where
:
test == value
你可以使用numpy.where
。一个简单的例子:
x = np.random.randn(4,2,10,10)
x[0,1,2,3] = 1
x[0,1,4,5] = 1
np.where(x==1)
(array([0,0],dtype = int64),array([1,1],dtype = int64),array([2,4],dtype = int64),array([3,5],dtype = Int64的))