高效计算满射函数集

问题描述 投票:0回答:1

f : X -> Y
的每个元素在
Y
中至少有一个原像时,函数
X
是满射的。当
X = {0,...,m-1}
Y = {0,...,n-1}
是两个有限集时,
f
对应于数字
m
< n
元组,并且当每个数字
< n
至少出现一次时,它是满射的。 (当我们要求每个数字都恰好出现一次时,我们就有了
n=m
并且我们有排列。)

我想知道一种有效的算法来计算所有满射元组的集合,对于两个给定的数字

n
m
如上所述。这些元组的number可以通过包含-排除非常有效地计算(例如参见here),但我认为这在这里没有用(因为我们首先计算所有元组,然后删除非满射的逐步进行,我假设所有元组的计算将花费更长的时间。)。另一种方法如下:

考虑例如元组

(1,6,4,2,1,6,0,2,5,1,3,2,3)

其中每个数字< 7 appears at least once. Look at the largest number and erase it:

(1,*,4,2,1,*,0,2,5,1,3,2,3)

它出现在索引 1 和 5 中,因此这对应于集合

{1,5}
,即索引的子集。 其余的对应元组

(1,4,2,1,0,2,5,1,3,2,3)

每个数都有这样的性质 < 6 appears at least once.

我们看到,数字

m
的满射
< n
元组对应于
(T,a)
对,其中
T
{0,...,m-1}
的非空子集,而
a
是满射
(m-k)
元组数字
< n-1
,其中
T
k
个元素。

这导致了以下递归实现(用 Python 编写):

import itertools


def surjective_tuples(m: int, n: int) -> set[tuple]:
    """Set of all m-tuples of numbers < n where every number < n appears at least once.

    Arguments:
        m: length of the tuple
        n: number of distinct values
    """
    if n == 0:
        return set() if m > 0 else {()}
    if n > m:
        return set()
    result = set()
    for k in range(1, m + 1):
        smaller_tuples = surjective_tuples(m - k, n - 1)
        subsets = itertools.combinations(range(m), k)
        for subset in subsets:
            for smaller_tuple in smaller_tuples:
                my_tuple = []
                count = 0
                for i in range(m):
                    if i in subset:
                        my_tuple.append(n - 1)
                        count += 1
                    else:
                        my_tuple.append(smaller_tuple[i - count])
                result.add(tuple(my_tuple))
    return result

我注意到,当输入数字很大时,这非常慢。我怀疑有更快的算法。

python algorithm performance tuples combinatorics
1个回答
0
投票

我已经使用

sympy
s
multiset_permutations
成功提高了百分之几的性能:

from itertools import combinations_with_replacement

from sympy.utilities.iterables import multiset_permutations


def get_combs(s, n):
    for c in combinations_with_replacement(range(1, s), n):
        if sum(c) == s:
            yield c

def surjective_tuples_new(s, n):
    for c in get_combs(s, n):
        for p in multiset_permutations(c):
            out = []

            for i, n in enumerate(p):
                out.extend(i for _ in range(n))

            yield from multiset_permutations(out)

基准:

from timeit import timeit

assert sorted(surjective_tuples_new(10, 8)) == list(
    map(list, sorted(surjective_tuples(10, 8)))
)

t1 = timeit("surjective_tuples(10, 8)", number=1, globals=globals())
t2 = timeit("list(surjective_tuples_new(10, 8))", number=1, globals=globals())

print(t1)
print(t2)

在我的机器上打印(AMD 5700x,Python 3.11):

27.863450561184436
22.92276939912699
© www.soinside.com 2019 - 2024. All rights reserved.