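I'm working with some fairly large datasets (500,000 data points, each with 30 variables) and want to find the most efficient way to filter them.
For compatibility with existing code, the data is structured as a dictionary of lists. It can't be converted (e.g. to a pandas DataFrame) and must be filtered in place.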
Working example:
data = {'Param0': ['x1', 'x2', 'x3', 'x4', 'x5', 'x6'],
        'Param1': ['A', 'A', 'A', 'B', 'B', 'C'],
        'Param2': [100, 200, 150, 80, 90, 50],
        'Param3': [20, 60, 40, 30, 30, 5]}
# Param0 keys to keep
keep = ['x2', 'x4']
filtered = {k: [x for i, x in enumerate(v) if data['Param0'][i] in keep] for k, v in data.items()}
The result, filtered, gives the desired output, but this is very slow at scale.
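Part of the cost is that "in keep" scans the whole keep list on every call, and the comprehension repeats that test for every element of every column (30 times per row at full scale). A minimal sketch, assuming the Param0 values are hashable, that changes only one thing: keep becomes a set, so each membership test is O(1) on average. The variable name keep_set is illustrative, not from the original code.

```python
# Same dict-of-lists layout as the working example above
data = {'Param0': ['x1', 'x2', 'x3', 'x4', 'x5', 'x6'],
        'Param1': ['A', 'A', 'A', 'B', 'B', 'C'],
        'Param2': [100, 200, 150, 80, 90, 50],
        'Param3': [20, 60, 40, 30, 30, 5]}
keep = ['x2', 'x4']

# Hypothetical helper variable: a set gives O(1) average lookups
# instead of a linear scan of the keep list
keep_set = set(keep)
filtered = {k: [x for i, x in enumerate(v) if data['Param0'][i] in keep_set]
            for k, v in data.items()}
print(filtered)
```

This still re-tests Param0 once per column; the answers below avoid that by computing the surviving indices only once.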
Is there a faster way to do this?
I would do:
keep_idx = [i for i, v in enumerate(data['Param0']) if v in keep]
filtered = {k: [v[i] for i in keep_idx] for k, v in data.items()}
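The question asks for filtering in place, while the snippet above builds a new dict. A minimal sketch of the same index-list idea that mutates the existing lists through slice assignment, so other references to those list objects see the filtered data. The function name filter_in_place is hypothetical, and the sketch assumes the key column is always named Param0.

```python
def filter_in_place(data, keep):
    """Hypothetical helper: keep only rows whose Param0 value is in keep."""
    keep_set = set(keep)  # O(1) average membership tests
    # Compute the surviving row indices once, from the key column only
    keep_idx = [i for i, v in enumerate(data['Param0']) if v in keep_set]
    for v in data.values():
        # Slice assignment replaces the contents of the existing list object
        # instead of rebinding the dict value to a new list
        v[:] = [v[i] for i in keep_idx]


data = {'Param0': ['x1', 'x2', 'x3', 'x4', 'x5', 'x6'],
        'Param1': ['A', 'A', 'A', 'B', 'B', 'C'],
        'Param2': [100, 200, 150, 80, 90, 50],
        'Param3': [20, 60, 40, 30, 30, 5]}
alias = data['Param2']  # another reference to the same list object
filter_in_place(data, ['x2', 'x4'])
print(data)
print(alias)  # the alias sees the filtered list too
```

Since keep_idx is computed before the loop, it is safe to overwrite data['Param0'] along with the other columns.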
Timings:
import numpy as np
from timeit import timeit
# Solution in question
def test_1(data, keep):
    return {
        k: [x for i, x in enumerate(v) if data['Param0'][i] in keep]
        for k, v in data.items()
    }
# First solution from @I'mahdi
def test_2(data, keep):
    keep_idx = [i for i, v in enumerate(data['Param0']) if v in keep]
    return {
        k: [val for i, val in enumerate(v) if i in keep_idx]
        for k, v in data.items()
    }
# Second solution from @I'mahdi
def test_3(data, keep):
    keep_idx = [i for i, v in enumerate(data['Param0']) if v in keep]
    return {k: list(np.asarray(v)[keep_idx]) for k, v in data.items()}
# Solution in this answer
def test_4(data, keep):
    keep_idx = [i for i, v in enumerate(data['Param0']) if v in keep]
    return {k: [v[i] for i in keep_idx] for k, v in data.items()}
data = {f"Param{i}": list(range(10_000)) for i in range(20)}
keep = list(range(0, 10_000, 100))
print(test_1(data, keep) == test_2(data, keep))
print(test_2(data, keep) == test_3(data, keep))
print(test_3(data, keep) == test_4(data, keep))
for i in range(1, 5):
    t = timeit(f"test_{i}(data, keep)", globals=globals(), number=10)
    print(f"Solution {i}: {t:.3f}")
The results (total seconds for 10 calls each) look like this:
Solution 1: 4.571
Solution 2: 4.220
Solution 3: 0.298
Solution 4: 0.219
It might be a better idea to create a look_up_idx first:
look_up_idx = [idx for idx, v in enumerate(data['Param0']) if v in keep]
filtered = {k: [val for idx, val in enumerate(v) if idx in look_up_idx] for k, v in data.items()}
print(filtered)
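In the comprehension above, "idx in look_up_idx" is a linear scan of a list, evaluated once per element of every column; that is consistent with Solution 2 in the timings being barely faster than the original. A minimal sketch, assuming the same example data, that keeps this structure but turns the index list into a set (look_up_set is an illustrative name):

```python
data = {'Param0': ['x1', 'x2', 'x3', 'x4', 'x5', 'x6'],
        'Param1': ['A', 'A', 'A', 'B', 'B', 'C'],
        'Param2': [100, 200, 150, 80, 90, 50],
        'Param3': [20, 60, 40, 30, 30, 5]}
keep = ['x2', 'x4']

look_up_idx = [idx for idx, v in enumerate(data['Param0']) if v in keep]
# A set of indices makes each "idx in ..." test O(1) on average
look_up_set = set(look_up_idx)
filtered = {k: [val for idx, val in enumerate(v) if idx in look_up_set]
            for k, v in data.items()}
print(filtered)
```

This still walks every element of every column; indexing directly with [v[i] for i in look_up_idx] only touches the kept positions, which is why Solution 4 is faster.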
Or use numpy as well:
import numpy as np
look_up_idx = [idx for idx, v in enumerate(data['Param0']) if v in keep]
filtered = {k: list(np.asarray(v)[look_up_idx]) for k, v in data.items()}
Output:
{'Param0': ['x2', 'x4'],
 'Param1': ['A', 'B'],
 'Param2': [200, 80],
 'Param3': [60, 30]}
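A further numpy variant, offered as a sketch rather than a measured improvement: np.isin builds a boolean mask over the key column in one vectorized call, np.flatnonzero turns it into integer positions, and .tolist() converts the selected values back to plain Python types. Note that list(np.asarray(v)[...]) instead yields NumPy scalars such as np.int64, which may matter for the existing code. Each column is still copied into an array, so whether this beats the pure-list version at 500,000 rows is untested here.

```python
import numpy as np

data = {'Param0': ['x1', 'x2', 'x3', 'x4', 'x5', 'x6'],
        'Param1': ['A', 'A', 'A', 'B', 'B', 'C'],
        'Param2': [100, 200, 150, 80, 90, 50],
        'Param3': [20, 60, 40, 30, 30, 5]}
keep = ['x2', 'x4']

# Boolean mask over the key column, computed once
mask = np.isin(np.asarray(data['Param0']), keep)
keep_idx = np.flatnonzero(mask)  # integer positions of rows to keep
# .tolist() turns NumPy scalars (np.int64, np.str_) back into int/str
filtered = {k: np.asarray(v)[keep_idx].tolist() for k, v in data.items()}
print(filtered)
```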