获取与标签对应的行,用于许多标签

问题描述 投票:2回答:2

我有一个2D数组,其中每一行都有一个标签存储在一个单独的数组中(不一定是唯一的)。对于每个标签,我想从具有此标签的2D数组中提取行。我想要的基本工作示例是:

import numpy as np

data=np.array([[1,2],[3,5],[7,10], [20,32],[0,0]])
label=np.array([1,1,1,0,1])

#very simple approach
label_values=np.unique(label)
res=[]
for la in label_values:
    data_of_this_label_val=data[label==la]
    res+=[data_of_this_label_val]
print(res)

结果(res)可以具有任何格式,只要它易于访问即可。在上面的例子中,它将是

[array([[20, 32]]), array([[ 1,  2],
   [ 3,  5],
   [ 7, 10],
   [ 0,  0]])]

请注意,我可以轻松地将列表中的每个元素与label_values中的一个唯一标签相关联(即通过索引)。

虽然这有效,但使用for循环可能需要花费很多时间,特别是如果我的标签向量很大。这可以更加优雅地加速或编码吗?

python arrays python-3.x sorting numpy
2个回答
3
投票

你可以argsort标签(这是我认为unique在幕后做的)。

如果你的标签是小的nonnegatvie整数,如例子中你可以得到它便宜一点,请参阅https://stackoverflow.com/a/53002966/7207392

>>> import numpy as np
>>> 
>>> data=np.array([[1,2],[3,5],[7,10], [20,32],[0,0]])
>>> label=np.array([1,1,1,0,1])
>>> 
>>> idx = label.argsort()
# use kind='mergesort' if you require a stable sort, i.e. one that
# preserves the order of equal labels
>>> ls = label[idx]
>>> split = 1 + np.where(ls[1:] != ls[:-1])[0]
>>> np.split(data[idx], split)
[array([[20, 32]]), array([[ 1,  2],
       [ 3,  5],
       [ 7, 10],
       [ 0,  0]])]

2
投票

不幸的是,groupby没有内置的numpy功能,尽管你可以编写替代品。但是,使用pandas可以更简洁地解决您的问题,如果您可以使用的话:

import pandas as pd

res = pd.DataFrame(data).groupby(label).apply(lambda x: x.values).tolist()
# or, if performance is important, the following will be faster on large arrays, 
# but less readable IMO:
res = [data[i] for i in pd.DataFrame(data).groupby(label).groups.values()]

[array([[20, 32]]), array([[ 1,  2],
       [ 3,  5],
       [ 7, 10],
       [ 0,  0]])]
© www.soinside.com 2019 - 2024. All rights reserved.