当
f : X -> Y
的每个元素在 Y
中至少有一个原像时,函数 X
是满射的。当X = {0,...,m-1}
和Y = {0,...,n-1}
是两个有限集时,f
对应于数字m
的< n
元组,并且当每个数字< n
至少出现一次时,它是满射的。 (当我们要求每个数字都恰好出现一次时,我们就有了n=m
并且我们有排列。)
我想知道一种有效的算法来计算所有满射元组的集合,对于两个给定的数字
n
和 m
如上所述。这些元组的number可以通过包含-排除非常有效地计算(例如参见here),但我认为这在这里没有用(因为我们首先计算所有元组,然后删除非满射的逐步进行,我假设所有元组的计算将花费更长的时间。)。另一种方法如下:
考虑例如元组
(1,6,4,2,1,6,0,2,5,1,3,2,3)
其中每个数字< 7 appears at least once. Look at the largest number and erase it:
(1,*,4,2,1,*,0,2,5,1,3,2,3)
它出现在索引 1 和 5 中,因此这对应于集合
{1,5}
,即索引的子集。
其余的对应元组
(1,4,2,1,0,2,5,1,3,2,3)
每个数都有这样的性质 < 6 appears at least once.
我们看到,数字
m
的满射< n
元组对应于(T,a)
对,其中T
是{0,...,m-1}
的非空子集,而a
是满射(m-k)
元组数字 < n-1
,其中 T
有 k
个元素。
这导致了以下递归实现(用 Python 编写):
import itertools
def surjective_tuples(m: int, n: int) -> set[tuple]:
"""Set of all m-tuples of numbers < n where every number < n appears at least once.
Arguments:
m: length of the tuple
n: number of distinct values
"""
if n == 0:
return set() if m > 0 else {()}
if n > m:
return set()
result = set()
for k in range(1, m + 1):
smaller_tuples = surjective_tuples(m - k, n - 1)
subsets = itertools.combinations(range(m), k)
for subset in subsets:
for smaller_tuple in smaller_tuples:
my_tuple = []
count = 0
for i in range(m):
if i in subset:
my_tuple.append(n - 1)
count += 1
else:
my_tuple.append(smaller_tuple[i - count])
result.add(tuple(my_tuple))
return result
我注意到,当输入数字很大时,这非常慢。我怀疑有更快的算法。
我已经使用
sympy
s multiset_permutations
成功提高了百分之几的性能:
from itertools import combinations_with_replacement
from sympy.utilities.iterables import multiset_permutations
def get_combs(s, n):
for c in combinations_with_replacement(range(1, s), n):
if sum(c) == s:
yield c
def surjective_tuples_new(s, n):
for c in get_combs(s, n):
for p in multiset_permutations(c):
out = []
for i, n in enumerate(p):
out.extend(i for _ in range(n))
yield from multiset_permutations(out)
基准:
from timeit import timeit
assert sorted(surjective_tuples_new(10, 8)) == list(
map(list, sorted(surjective_tuples(10, 8)))
)
t1 = timeit("surjective_tuples(10, 8)", number=1, globals=globals())
t2 = timeit("list(surjective_tuples_new(10, 8))", number=1, globals=globals())
print(t1)
print(t2)
在我的机器上打印(AMD 5700x,Python 3.11):
27.863450561184436
22.92276939912699