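I need to find all the distinct intersections between two partitions of the same set. For example, if we have the following two partitions of the same set
x = [[1, 2], [3, 4, 5], [6, 7, 8, 9, 10]]
y = [[1, 3, 6, 7], [2, 4, 5, 8, 9, 10]],
the required result is
[[1], [2], [3], [4, 5], [6, 7], [8, 9, 10]].
In detail, we take the Cartesian product of the subsets of x with the subsets of y and, for each pair in that product, sort the elements into new subsets according to whether or not they belong to the intersection of the two subsets in the pair.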
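Put differently, the target appears to be the set of all non-empty intersections between a subset of x and a subset of y. A minimal, unoptimized sketch of that reading, using plain Python sets (the name `pairwise_intersections` is only illustrative, and the sketch assumes both inputs really are partitions of the same set):

```python
def pairwise_intersections(p, q):
    """Return the non-empty intersections of every block of p with every block of q."""
    blocks = []
    for a in p:
        for b in q:
            # Elements shared by block a of p and block b of q.
            common = sorted(set(a) & set(b))
            if common:
                blocks.append(common)
    return blocks


x = [[1, 2], [3, 4, 5], [6, 7, 8, 9, 10]]
y = [[1, 3, 6, 7], [2, 4, 5, 8, 9, 10]]
print(pairwise_intersections(x, y))
# [[1], [2], [3], [4, 5], [6, 7], [8, 9, 10]]
```

The blocks come out in the order of the nested loops, so no final sort is needed for this example.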
I have found a not-so-elegant solution but, hey! It works (apparently) :)
import numpy as np

def decomposition(row1, row2):
    '''Decompose two sets into at most three sets: their intersection and the two non-intersecting parts'''
    d1 = np.setdiff1d(row1, row2)    # elements only in row1
    d2 = np.intersect1d(row1, row2)  # elements in both
    d3 = np.setdiff1d(row2, row1)    # elements only in row2
    ds = (d1, d2, d3)
    return [d for d in ds if d.size != 0]
def partitions_decomposition_product(aa, bb):
    '''Find all the different intersections, based on the Cartesian product between two partitions of the same set'''
    # Both partitions must cover the same elements (check the arguments, not the globals).
    assert np.array_equal(np.sort(np.concatenate(aa, axis=None)), np.sort(np.concatenate(bb, axis=None)))
    decomposition_list = []
    for a in aa:
        for b in bb:
            cc = decomposition(a, b)
            if not decomposition_list:
                decomposition_list = cc
            else:
                for c in cc:
                    if not any(np.array_equal(c, x) for x in decomposition_list):
                        trigger_last = True
                        for k, lt in enumerate(decomposition_list):
                            if (any(np.array_equal(c, x) for x in decomposition_list)) or (np.intersect1d(c, lt).size != 0):
                                deco_c_lt = decomposition(c, lt)
                                if not any(np.array_equal(decomposition_list[k], x) for x in deco_c_lt):
                                    # lt is split by c: replace it with its pieces
                                    del decomposition_list[k]
                                    decomposition_list = decomposition_list + decomposition(c, lt)
                                trigger_last = False
                            elif (k == len(decomposition_list)-1) and trigger_last:
                                # c overlaps nothing seen so far: keep it as a new subset
                                decomposition_list.append(c)
    decomposition_list.sort(key=lambda x: x[0])
    return decomposition_list
Applying this last function to the partitions above (given as lists of one-dimensional NumPy arrays)
aaa = [np.array([1, 2]), np.array([3, 4, 5]), np.array([6, 7, 8, 9, 10])]
bbb = [np.array([1, 3, 6, 7]), np.array([2, 4, 5, 8, 9, 10])]
p_d_p = partitions_decomposition_product(aaa, bbb)
we obtain the expected result
p_d_p
[array([1]),
array([2]),
array([3]),
array([4, 5]),
array([6, 7]),
array([8, 9, 10])]
Is there any way to optimize (part of) this code? Thanks in advance!
I'm not sure I understand you correctly, but this script produces the result given in your question:
x = [[1, 2], [3, 4, 5], [6, 7, 8, 9, 10]]
y = [[1, 3, 6, 7], [2, 4, 5, 8, 9, 10]]

out = []
for sublist1 in x:
    d = {}  # index of the sublist of y -> values of sublist1 found in it
    for val in sublist1:
        for i, sublist2 in enumerate(y):
            if val in sublist2:
                d.setdefault(i, []).append(val)
    out.extend(d.values())

print(out)
Prints:
[[1], [2], [3], [4, 5], [6, 7], [8, 9, 10]]
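If the partitions are large, the repeated `val in sublist2` list scans may become the bottleneck. A possible variant, sketched here under the assumption that every element is hashable and appears in exactly one sublist of each partition, first builds a lookup from element to the index of its sublist in y, then groups the elements of x by the pair (index in x, index in y). The names `refine` and `block_of` are only illustrative.

```python
from collections import defaultdict


def refine(p, q):
    """Group the elements of p by (block index in p, block index in q)."""
    # Map each element to the index of the block of q that contains it.
    block_of = {val: j for j, block in enumerate(q) for val in block}
    groups = defaultdict(list)
    for i, block in enumerate(p):
        for val in block:
            # Raises KeyError if val is missing from q, i.e. the inputs
            # are not partitions of the same set.
            groups[(i, block_of[val])].append(val)
    # Dicts preserve insertion order, so groups appear in first-seen order.
    return list(groups.values())


x = [[1, 2], [3, 4, 5], [6, 7, 8, 9, 10]]
y = [[1, 3, 6, 7], [2, 4, 5, 8, 9, 10]]
print(refine(x, y))
# [[1], [2], [3], [4, 5], [6, 7], [8, 9, 10]]
```

Each element is visited once per partition, so the work grows roughly linearly with the size of the set instead of with the number of sublist pairs.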