所有最小生成树实现

问题描述 投票:0回答:5

我一直在寻找一个实现(我正在使用networkx库。)它将找到无向加权图的所有最小生成树(MST)。

我只能找到 Kruskal 算法和 Prim 算法的实现,这两个算法都只会返回一个 MST。

我看过解决这个问题的论文(例如代表所有最小生成树及其在计数和生成中的应用),但在尝试思考如何将其转换为代码时我的头脑往往会爆炸。

事实上我还没有找到任何语言的实现!

python algorithm language-agnostic graph-theory minimum-spanning-tree
5个回答
10
投票

我不知道这是否是the解决方案,但它是a解决方案(我想说,这是暴力的图形版本):

  1. 使用 kruskal 或 prim 算法求出图的 MST。这应该是 O(E log V)。
  2. 生成所有生成树。这可以在
    O(Elog(V) + V + n) for n = number of spanning trees
    中完成,据我从 2 分钟的谷歌了解,可能可以改进。
  3. 通过树的权重等于 MST 的权重来过滤步骤 #2 中生成的列表。对于 n 作为步骤 #2 中生成的树的数量,这应该是 O(n)。

注意:请懒惰地执行此操作!生成所有可能的树然后过滤结果将占用 O(V^2) 内存,并且多项式空间要求是邪恶的 - 生成一棵树,检查它的权重,如果它是 MST 将其添加到结果列表中,如果不是 - 丢弃它.
总体时间复杂度:

O(Elog(V) + V + n) for G(V,E) with n spanning trees


0
投票

Ronald Rivest 在 Python 中有一个很好的实现,mst.py


0
投票

您可以在 Sorensen 和 Janssens (2005).

的作品中找到一个想法。

这个想法是按升序生成 ST,一旦获得更大的 ST 值就停止枚举。


0
投票

这是一个简短的 Python 实现,基本上是 Kruskal 的递归变体。使用找到的第一个 MST 的权重来限制此后搜索空间的大小。绝对仍然是指数复杂度,但比生成每个生成树更好。还包括一些测试代码。

[注意:这只是我自己的实验,目的是为了好玩,并可能从其他人那里获得对问题的进一步思考的灵感,并不是试图具体实施此处提供的其他答案中建议的任何解决方案]

# Disjoint set find (and collapse) 
def find(nd, djset):
    uv = nd
    while djset[uv] >= 0: uv = djset[uv]
    if djset[nd] >= 0: djset[nd] = uv
    return uv

# Disjoint set union (does not modify djset)
def union(nd1, nd2, djset):
    unionset = djset.copy()
    if unionset[nd2] < unionset[nd1]:
        nd1, nd2 = nd2, nd1

    unionset[nd1] += unionset[nd2]
    unionset[nd2] = nd1
    return unionset

# Bitmask convenience methods; uses bitmasks
# internally to represent MST edge combinations
def setbit(j, mask): return mask | (1 << j)
def isbitset(j, mask): return (mask >> j) & 1
def masktoedges(mask, sedges):
    return [sedges[i] for i in range(len(sedges)) 
            if isbitset(i, mask)]

# Upper-bound count of viable MST edge combination, i.e.
# count of edge subsequences of length: NEDGES, w/sum: WEIGHT
def count_subsequences(sedges, weight, nedges):
#{
    def count(i, target, length, cache):
        tkey = (i, target, length)
        if tkey in cache: return cache[tkey]
        if i == len(sedges) or target < sedges[i][2]: return 0
            
        cache[tkey] = (count(i+1, target, length, cache) +
            count(i+1, target - sedges[i][2], length - 1, cache) + 
            (1 if sedges[i][2] == target and length == 1 else 0))
        
        return cache[tkey]
    
    return count(0, weight, nedges, {})
#}

# Arg: n is number of nodes in graph [0, n-1]
# Arg: sedges is list of graph edges sorted by weight
# Return: list of MSTs, where each MST is a list of edges
def find_all_msts(n, sedges):
#{
    # Recursive variant of kruskal to find all MSTs
    def buildmsts(i, weight, mask, nedges, djset):
    #{
        nonlocal maxweight, msts
        if nedges == (n-1):
            msts.append(mask)
            if maxweight == float('inf'):
                print(f"MST weight: {weight}, MST edges: {n-1}, Total graph edges: {len(sedges)}")
                print(f"Upper bound numb viable MST edge combinations: {count_subsequences(sedges, weight, n-1)}\n")
                maxweight = weight
                
            return
        
        if i < len(sedges):
        #{
            u,v,wt = sedges[i]
            if weight + wt*((n-1) - nedges) <= maxweight:
            #{
                # Left recursive branch - include edge if valid
                nd1, nd2 = find(u, djset), find(v, djset)
                if nd1 != nd2: buildmsts(i+1, weight + wt,
                    setbit(i, mask), nedges+1, union(nd1, nd2, djset))
            
                # Right recursive branch - always skips edge
                buildmsts(i+1, weight, mask, nedges, djset)
            #}
        #}
    #}
        
    maxweight, msts = float('inf'), []
    djset = {i: -1 for i in range(n)}    
    buildmsts(0, 0, 0, 0, djset)    
    return [masktoedges(mask, sedges) for mask in msts]
#}

import time, numpy

def run_test_case(low=10, high=21):
    rng = numpy.random.default_rng()
    n = rng.integers(low, high)
    nedges = rng.integers(n-1, n*(n-1)//2)

    edges = set()
    while len(edges) < nedges: 
        u,v = sorted(rng.choice(range(n), size=2, replace=False))
        edges.add((u,v))

    weights = sorted(rng.integers(1, 2*n, size=nedges))
    sedges = [[u,v,wt] for (u,v), wt in zip(edges, weights)]
    print(f"Numb nodes: {n}\nSorted edges: {sedges}\n")
    
    for i, mst in enumerate(find_all_msts(n, sedges)):
        if i == 0: print("MSTs:")
        print((i+1), ":", mst)

if __name__ == "__main__":
    initial = time.time()
    run_test_case(20, 35)
    print(f"\nRun time: {time.time() - initial}s")

0
投票

我最近发布了 R 代码,它可以完成此线程中的要求:

https://github.com/emmanuelparadis/allMST

这项工作正在进行中。

© www.soinside.com 2019 - 2024. All rights reserved.