基于 2 个标准查找最短路径的算法

问题描述 投票:0回答:1

我们从节点 0 开始,需要使用尽可能少的步骤到达节点 n-1。同时每一步都会影响我们的温度,有些步骤增加 1 度,有些则减少 1 度。

输入的格式是这样的 ->

arr = [4, 3, [[0, 1, -1], [1, 3, -1], [1, 2, 1]]]

arr[0]
4 是 n -> 所有可能节点的数量,其中第一个节点是 0,最后一个节点是 3,在这种情况下我们希望从节点 0 到节点 3 (n-1)

arr[1]
3 是 m -> 所有可用步骤的数量(下面列出)

arr[2]
[[0, 1, -1], [1, 3, -1], [1, 2, 1]] -> arr 长度为 m,其中每个元素是一个包含元素 u, v, c 的列表有关步骤的信息 -> u!=v (u(node) 是步骤的开始,v(node) 是结束),c 是 1 或 -1 -> 这告诉我们步骤是否减去温度或添加它。

目标是使用尽可能少的步骤并尽可能使温度最接近 0。较短的路径比温度更重要,因此,如果我们在 path_length=1 温度=20 和 path_length=2 温度=0 之间进行选择,我们选择路径长度较小的那个 -> path_length=1

如果没有有效路径,我们就打印出来

ajajaj

输入/输出示例:

输入:

arr = [4, 3, [[0, 1, -1], [1, 3, -1], [1, 2, 1]]]

输出:

(-2, 2, [0, 1])
,其中 -2 是温度,2 是路径长度,列表是使用的路径 -> 我们使用
[0, 1, -1]
[1, 3, -1]
来到达 n-1

输入:

[4, 4, [[0, 1, 1], [0, 2, -1], [1, 3, -1], [2, 3, 1]]]

输出:

(0, 2, [1, 3])

输入:

[3, 1, [[0, 1, 1]]]

输出:

ajajaj

输入:

[5, 5, [[0, 1, 1], [1, 2, -1], [1, 2, 1], [2, 3, -1], [3, 4, -1]]]

输出:

[(0, 4, [0, 2, 3, 4])]

这是我的解决方案,它通过了测试用例和 70% 的大输入用例:

from heapq import heappush, heappop
from typing import Union


def main(entry: list) -> Union[str, tuple]:
    results = []
    # n == num of possible nodes -> if n=4 then possible nodes are 0, 1, 2, 3
    # m == num of all paths -> yet again they start at 0, so if m=4, the first path is 0 and the last one is 3
    n, m = entry.pop(0), entry.pop(0)

    # loop trough the paths and store them in a graph
    graph = [[] for _ in range(n)]
    paths = entry.pop()
    for i in range(m):
        # u == start of the path, v == end of the path, c == type of the path
        # u != v, c in {'ohniva', 'ledova'} -> {+1, -1}
        u, v, c = paths[i]
        graph[u].append((v, c))

    pq = [(0, 0, [])]
    visited = set()
    max_path = float("inf")
    while pq:
        temp, node, path = heappop(pq)
        # NOTE this probably isnt correct
        if len(path) > max_path:
            continue
        if node == n - 1:
            if len(path) < max_path:
                max_path = len(path)
            result = (temp, len(path), path)
            results.append(result)
            visited = set()

        if node not in visited:
            visited.add(node)
            for end, c in graph[node]:
                new_temp = temp + c  # +/- 1
                new_node = None
                # NOTE this probably isnt correct
                for i, p in enumerate(paths):
                    if p == [node, end, c]:
                        new_node = i
                        break

                new_path = path + [new_node]
                heappush(pq, (new_temp, end, new_path))

    if n - 1 not in visited:
        return "ajajaj"

    if len(results) > 1:
        return sorted(results, key=lambda x: (x[1], abs(x[0])))[0]
    else:
        return results[0]

算法是否有任何错误,或者整个方法是错误的?在过去的 4 个小时里我一直在试图弄清楚......

编辑: 未通过的案例示例:

输入:

[93, 289, [[62, 31, -1], [45, 27, 1], [11, 7, 1], [80, 74, 1], [15, 82, 1], [56, 12, 1], [49, 85, 1], [61, 21, -1], [90, 35, 1], [13, 68, 1], [7, 83, -1], [65, 68, -1], [44, 74, -1], [48, 59, 1], [39, 45, 1], [1, 82, 1], [32, 62, -1], [72, 82, 1], [27, 23, -1], [27, 73, 1], [69, 35, -1], [24, 77, 1], [8, 66, 1], [68, 8, -1], [14, 61, 1], [80, 76, 1], [82, 8, 1], [76, 61, -1], [48, 53, -1], [90, 33, 1], [11, 86, 1], [52, 42, -1], [46, 36, 1], [26, 69, 1], [46, 64, -1], [0, 14, -1], [31, 60, 1], [88, 11, -1], [28, 60, -1], [73, 78, 1], [52, 2, 1], [23, 82, 1], [63, 92, 1], [21, 84, -1], [80, 7, 1], [91, 49, -1], [62, 65, -1], [92, 16, -1], [13, 59, -1], [14, 40, 1], [58, 86, 1], [6, 60, 1], [21, 59, 1], [68, 12, 1], [92, 75, -1], [83, 36, -1], [90, 60, -1], [2, 84, -1], [22, 50, 1], [72, 21, -1], [47, 3, 1], [51, 9, -1], [67, 77, -1], [92, 10, -1], [80, 20, 1], [55, 5, -1], [46, 64, -1], [22, 15, 1], [87, 24, -1], [80, 71, 1], [61, 39, 1], [59, 83, 1], [60, 63, 1], [10, 12, 1], [70, 75, -1], [33, 27, -1], [75, 14, 1], [52, 4, -1], [45, 61, 1], [59, 55, -1], [30, 37, 1], [10, 38, 1], [56, 4, -1], [51, 39, -1], [35, 3, 1], [49, 37, -1], [40, 6, -1], [47, 90, -1], [20, 68, -1], [74, 38, 1], [88, 18, -1], [25, 0, 1], [51, 73, -1], [75, 91, 1], [14, 75, -1], [86, 73, -1], [52, 21, -1], [44, 89, 1], [68, 80, -1], [82, 37, 1], [59, 78, 1], [48, 43, -1], [47, 88, -1], [77, 60, -1], [32, 22, 1], [55, 6, 1], [49, 77, -1], [27, 48, 1], [46, 31, -1], [65, 57, 1], [83, 11, 1], [68, 84, 1], [29, 27, 1], [87, 59, 1], [75, 41, -1], [46, 44, -1], [67, 29, -1], [75, 55, -1], [10, 19, 1], [52, 46, 1], [12, 20, -1], [0, 4, 1], [39, 27, -1], [12, 28, -1], [1, 61, 1], [79, 34, -1], [45, 79, -1], [13, 86, 1], [20, 74, -1], [35, 60, -1], [10, 89, -1], [70, 44, 1], [14, 3, -1], [81, 7, 1], [3, 78, -1], [52, 6, -1], [62, 73, -1], [0, 34, 1], [2, 45, -1], [50, 25, 1], [73, 63, -1], [70, 92, -1], [64, 80, 1], [66, 53, -1], [35, 7, 1], [0, 84, 1], [85, 14, 1], [2, 42, 1], [26, 80, -1], [18, 24, -1], [31, 86, 1], [78, 45, -1], [21, 66, -1], [61, 57, -1], [46, 49, -1], [19, 82, -1], [55, 30, -1], [1, 6, -1], [29, 33, 1], [44, 45, 1], [66, 91, 1], [42, 58, 1], [56, 26, -1], [36, 48, 1], [32, 41, 1], [12, 90, -1], [92, 24, -1], [76, 47, 1], [47, 25, -1], [90, 36, -1], [22, 37, -1], [70, 57, 1], [51, 31, 1], [32, 13, -1], [39, 10, -1], [13, 36, 1], [67, 50, 1], [13, 24, 1], [11, 12, 1], [26, 51, -1], [54, 47, -1], [19, 43, 1], [76, 88, 1], [40, 39, -1], [75, 91, -1], [31, 92, 1], [36, 13, -1], [51, 47, -1], [14, 1, 1], [92, 17, -1], [87, 79, -1], [16, 9, -1], [17, 84, -1], [69, 43, -1], [33, 5, 1], [23, 17, -1], [20, 49, 1], [0, 61, 1], [25, 9, 1], [77, 12, 1], [80, 44, -1], [52, 23, 1], [18, 1, 1], [75, 50, -1], [86, 92, -1], [52, 6, -1], [37, 51, 1], [20, 91, 1], [7, 85, -1], [76, 48, 1], [70, 11, 1], [78, 75, 1], [57, 16, -1], [62, 31, -1], [29, 3, -1], [34, 79, -1], [50, 71, 1], [6, 90, 1], [13, 77, -1], [62, 54, 1], [24, 38, -1], [49, 7, 1], [88, 82, -1], [73, 8, -1], [13, 17, 1], [78, 66, 1], [75, 6, -1], [9, 51, -1], [31, 58, -1], [76, 74, -1], [54, 34, -1], [85, 59, -1], [69, 68, -1], [33, 67, 1], [36, 55, -1], [92, 23, 1], [28, 65, 1], [48, 24, -1], [2, 71, -1], [53, 59, -1], [78, 61, 1], [82, 79, 1], [91, 58, 1], [82, 76, -1], [61, 70, -1], [92, 10, -1], [4, 26, -1], [76, 86, 1], [24, 20, 1], [41, 59, -1], [44, 46, -1], [64, 33, -1], [84, 14, 1], [54, 1, -1], [21, 82, 1], [77, 8, 1], [10, 40, 1], [74, 33, 1], [77, 15, -1], [57, 78, 1], [24, 26, -1], [36, 2, -1], [74, 87, -1], [83, 90, 1], [63, 49, -1], [12, 91, 1], [54, 36, -1], [72, 26, -1], [73, 36, 1], [35, 2, 1], [70, 72, 1], [73, 26, 1], [76, 23, -1], [59, 69, -1], [27, 5, -1], [87, 24, 1], [61, 84, -1], [77, 33, 1], [63, 68, -1], [87, 36, -1], [20, 77, -1], [31, 11, -1], [90, 63, -1], [51, 62, 1], [91, 77, 1], [13, 7, -1], [18, 55, -1], [75, 33, -1], [56, 74, -1]]]

我的输出:

(-2, 4, [35, 24, 244, 141])

正确输出:

(-1, 3, [35, 76, 54])
python algorithm data-structures shortest-path heap
1个回答
0
投票

如果是一个选项:

import networkx as nx

def sp_on_2c(arr):
    n, m, steps = arr
    G = nx.DiGraph()

    for idx, (u, v, c) in enumerate(steps):
        G.add_edge(u, v, weight=c, indices=idx)

    try:
        # this can be replace with shortest_path but test2 has two !
        sp = list(nx.all_shortest_paths(G, source=0, target=n-1))[-1]
        pg = nx.path_graph(sp).edges()

        return (
            sum(G[u][v]["weight"] for u,v in pg),
            len(sp) - 1,
            [G[u][v]["indices"] for u,v in pg],
        )
    except (nx.NodeNotFound, nx.NetworkXNoPath):
        return "ajajaj"

输出:

for idx, test in enumerate([test1, test2, test3, test4]):
    print(f"test{idx+1} >>", sp_on_2c(test))
    
# test1 >> (-2, 2, [0, 1])
# test2 >> (0, 2, [1, 3])
# test3 >> ajajaj
# test4 >> (0, 4, [0, 2, 3, 4])

图表:

使用的输入:

test1 = [4, 3, [[0, 1, -1], [1, 3, -1], [1, 2, 1]]]
test2 = [4, 4, [[0, 1, 1], [0, 2, -1], [1, 3, -1], [2, 3, 1]]]
test3 = [3, 1, [[0, 1, 1]]]
test4 = [5, 5, [[0, 1, 1], [1, 2, -1], [1, 2, 1], [2, 3, -1], [3, 4, -1]]]
© www.soinside.com 2019 - 2024. All rights reserved.