我有一个2D数组,其中每一行都有一个标签存储在一个单独的数组中(不一定是唯一的)。对于每个标签,我想从具有此标签的2D数组中提取行。我想要的基本工作示例是:
import numpy as np
data=np.array([[1,2],[3,5],[7,10], [20,32],[0,0]])
label=np.array([1,1,1,0,1])
#very simple approach
label_values=np.unique(label)
res=[]
for la in label_values:
data_of_this_label_val=data[label==la]
res+=[data_of_this_label_val]
print(res)
结果(res)可以具有任何格式,只要它易于访问即可。在上面的例子中,它将是
[array([[20, 32]]), array([[ 1, 2],
[ 3, 5],
[ 7, 10],
[ 0, 0]])]
请注意,我可以轻松地将列表中的每个元素与label_values
中的一个唯一标签相关联(即通过索引)。
虽然这有效,但使用for循环可能需要花费很多时间,特别是如果我的标签向量很大。这可以更加优雅地加速或编码吗?
你可以argsort
标签(这是我认为unique
在幕后做的)。
如果你的标签是小的nonnegatvie整数,如例子中你可以得到它便宜一点,请参阅https://stackoverflow.com/a/53002966/7207392。
>>> import numpy as np
>>>
>>> data=np.array([[1,2],[3,5],[7,10], [20,32],[0,0]])
>>> label=np.array([1,1,1,0,1])
>>>
>>> idx = label.argsort()
# use kind='mergesort' if you require a stable sort, i.e. one that
# preserves the order of equal labels
>>> ls = label[idx]
>>> split = 1 + np.where(ls[1:] != ls[:-1])[0]
>>> np.split(data[idx], split)
[array([[20, 32]]), array([[ 1, 2],
[ 3, 5],
[ 7, 10],
[ 0, 0]])]
不幸的是,groupby
没有内置的numpy
功能,尽管你可以编写替代品。但是,使用pandas
可以更简洁地解决您的问题,如果您可以使用的话:
import pandas as pd
res = pd.DataFrame(data).groupby(label).apply(lambda x: x.values).tolist()
# or, if performance is important, the following will be faster on large arrays,
# but less readable IMO:
res = [data[i] for i in pd.DataFrame(data).groupby(label).groups.values()]
[array([[20, 32]]), array([[ 1, 2],
[ 3, 5],
[ 7, 10],
[ 0, 0]])]