最大和子序列O(nlogn)

问题描述 投票:0回答:1

给出N个正整数的数组A。找到给定数组的严格增加的子序列的最大和之和。Check it out例如:-对于数组[1 101 2 3 100 4 5],答案应为106。

我编写了在某些测试用例上失败的代码。谁能帮忙。

for _ in range(int(input())):
    n   = int(input())
    lst = list(map(int,input().split()))

    dp     = [0 for i in range(n+1)]
    total  = [0 for i in range(n)]
    parent = [-1 for i in range(n)]
    dp[0]  = None

    length = 0
    ans    = -float('inf')

    for i in range(n):
        low  = 1
        high = length
        ele  = lst[i]

        while(low<=high):
            mid = (low+high)//2
            if(ele<=lst[dp[mid]]):
                high = mid-1
            else:
                low = mid+1

        pos = low
        parent[i] = dp[pos-1]

        if(parent[i]==None):
            total[i] = lst[i]
        else:
            total[i] = total[parent[i]]+lst[i]

        dp[pos] = i
        if(pos>length):
            length = pos
    # print(total)
    print(max(total))
python python-3.x algorithm binary-search
1个回答
0
投票

您的问题是,尽管列表元素可能有另一个(较短的)子序列,且其总和较高,但是您总是选择适合元素的最长严格增加的子序列。

例如,如果输入为[100,1,2,200],则程序失败,因为到最后一个元素的最长严格递增的子序列为[1,2,200],因此您的程序输出203。显然,正确的答案会是300。

您可以通过仅在当前子序列的总和高于前一个之一时更新dp数组来解决此问题:

for _ in range(int(input())):
    n   = int(input())
    lst = list(map(int,input().split()))

    dp     = [0 for i in range(n+1)]
    total  = [0 for i in range(n)]
    parent = [-1 for i in range(n)]
    dp[0]  = None

    length = 0
    ans    = -float('inf')

    for i in range(n):
        low  = 1
        high = length
        ele  = lst[i]

        while(low<=high):
            mid = (low+high)//2
            if(ele<=lst[dp[mid]]):
                high = mid-1
            else:
                low = mid+1

        pos = low
        parent[i] = dp[pos-1]

        if(parent[i]==None):
            total[i] = lst[i]
        else:
            total[i] = total[parent[i]]+lst[i]

        ############change#############
        if total[i] > total[dp[pos]]:
            dp[pos] = i
        ###############################

        if(pos>length):
            length = pos
    print(max(total)) 

© www.soinside.com 2019 - 2024. All rights reserved.