Strange bug in a Python generator

Problem description  Votes: 1  Answers: 2

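I have an implementation of Knuth's algorithm ("Dancing Links") that behaves very strangely. I found a workaround, but it feels like magic. The script below tests the code on the N queens problem. The bug occurs in the first function, solve. The limit parameter is supposed to cap the number of solutions generated, and the default value 0 means "generate all solutions".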

#Test generalized exact cover for n queens problem 

def solve(Cols,  Rows, SecondaryIDs=set(), limit = 0):
    for soln in solver(Cols, Rows, SecondaryIDs):
        print('solve:', limit, soln)
        yield soln
        limit -= 1
        if limit == 0: return

def solver(Cols, Rows, SecondaryIDs, solution=[]):
    live=[col for col in Cols if col not in SecondaryIDs] 
    if not live:
        yield solution
    else:
        col = min(live, key = lambda col: len(Cols[col]))        
        for row in list(Cols[col]):                                          
            solution.append(row)                                            
            columns = select(Cols, Rows, row)                        
            for soln in solver(Cols, Rows, SecondaryIDs, solution):
                yield soln
            deselect(Cols, Rows, row, columns)
            solution.pop()

def select(Cols, Rows, row):
    columns = []
    for col in Rows[row]:
        for rrow in Cols[col]:
            for ccol in Rows[rrow]:
                if ccol != col:
                    Cols[ccol].remove(rrow)
        columns.append(Cols.pop(col))  
    return columns

def deselect(Cols, Rows, row, columns): 
    for col in reversed(Rows[row]):
        Cols[col] = columns.pop()
        for rrow in Cols[col]:
            for ccol in Rows[rrow]:
                if ccol != col:
                    Cols[ccol].add(rrow)

n = 5

# From Dancing Links paper
solutionCounts = {4:2, 5:10, 6:4, 7:40, 8:92, 9:352, 10:724}

def makeRows(n):
    # There is one row for each cell.    
    rows = dict()
    for rank in range(n):
        for file in range(n):
            rows["R%dF%d"%(rank,file)] = ["R%d"%rank, "F%d"%file, "S%d"%(rank+file), "D%d"%(rank-file)]
    return rows

def makePrimary(n):
    # One primary column for each rank and file
    prim = dict()
    for rank in range(n):
        prim["R%d"%rank] = {"R%dF%d"%(rank,file) for file in range(n)}
    for file in range(n):
        prim["F%d"%file] = {"R%dF%d"%(rank,file) for rank in range(n)}
    return prim

def makeSecondary(n):
    # One secondary column for each diagonal
    second = dict()
    for s in range(2*n-1):
        second["S%d"%s] = {"R%dF%d"%(r, s-r) for r in range(max(0,s-n+1), min(s+1,n))}
    for d in range(-n+1, n):
        second["D%d"%(-d)]={"R%dF%d"%(r, r+d) for r in range(max(0,-d),min(n-d, n))}
    return second

rows = makeRows(n)
primary = makePrimary(n)
secondary = makeSecondary(n)
primary.update(secondary)
secondary = secondary.keys()
#for soln in solve(primary, rows, secondary, 15):
    #print(soln)
solutions = [s for s in solve(primary, rows, secondary)]
try:
    assert len(solutions) == solutionCounts[n]
except AssertionError:
    print("actual %d expected %d"%(len(solutions), solutionCounts[n]))
for soln in solutions:print(soln)

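The code is set up to generate the first 6 solutions of the 5 queens problem, and that works correctly. (See the call

solutions = [s for s in solve(primary, rows, secondary, 6)]

on line 80.) There are actually 10 solutions, and if I ask for 10 solutions, I get them. If I drop the limit, so that the call is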

solutions = [s for s in solve(primary, rows, secondary)]

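the main program prints ten empty lists [] as the solutions, even though the code in solve prints the real solutions. The same thing happens if I set the limit to 15.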

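The problem seems to arise when I convert the generator to a list on line 80. If I restore the commented-out lines at lines 78 and 79 and comment out everything from line 80 on, the program works as I expect. But I don't understand this; I often build a list of the objects returned by a generator in this way.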

Another, even stranger thing is that if I change line 13 to read

yield list(solution)

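then the code at line 80 works correctly in all cases. I can't remember how I stumbled on this when I originally wrote the code. I was looking at it today and changed yield list(solution) to yield solution, and that is when the bug became apparent. I can't make sense of this; solution is already a list. In fact, I have tried adding the line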

assert solution == list(solution)

just before line 13, and it never raises an AssertionError.

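I am completely lost. I tried to write a small script that reproduces this behavior, but I couldn't. Do you understand what is going on, and (harder) can you explain it to me?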

python python-3.x
2 Answers
3 votes
yield solution

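The problem is that you are yielding a list that you subsequently add items to and remove items from. By the time the caller inspects the list, it has changed. You need to yield a frozen copy of the solution so that the result at each yield is preserved. Any of these will do: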

yield list(solution)
yield solution[:]
yield tuple(solution)
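
This also explains why the `assert solution == list(solution)` check in the question never fails: `==` compares contents, and a copy is equal to the original at the moment it is made. What differs is object identity. A minimal sketch of the distinction (the names `path`, `alias`, and `snapshot` are illustrative and do not appear in the question):

```python
path = ['a', 'b']

alias = path            # same list object, like `yield solution`
snapshot = list(path)   # new list with the same items, like `yield list(solution)`

assert alias == snapshot        # equal contents at this moment
assert alias is path            # the alias is the very same object
assert snapshot is not path     # the copy is a distinct object

# Mutate the original the way the backtracking search does.
path.pop()
path.pop()

print(alias)     # [] : the alias follows the mutation
print(snapshot)  # ['a', 'b'] : the snapshot keeps its items
```

All three forms above make shallow copies. That is enough here because the items in each solution are strings, which are immutable.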

4 votes

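A prediction before looking at the code: yield list(solution) yields a shallow copy of solution. yield solution yields the solution list itself, so when you mutate that list afterwards, you run into trouble.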


Looks like I was right. :-) A shorter version:

def weird(solution):
    for i in range(len(solution)):
        yield solution
        solution.pop()

This gives:

In [8]: result = list(weird(['a','b','c']))

In [9]: result
Out[9]: [[], [], []]

because

In [10]: [id(x) for x in result]
Out[10]: [140436644005128, 140436644005128, 140436644005128]

But if we yield list(solution) instead, we get

In [15]: list(less_weird(['a','b','c']))
Out[15]: [['a', 'b', 'c'], ['a', 'b'], ['a']]
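
The answer does not show the definition of `less_weird`. Presumably it is `weird` with the yielded value copied; a sketch under that assumption:

```python
def less_weird(solution):
    # Same loop as weird(), but each yield hands out a fresh copy,
    # so later pop() calls cannot change what the caller already holds.
    for i in range(len(solution)):
        yield list(solution)
        solution.pop()

result = list(less_weird(['a', 'b', 'c']))
print(result)                        # [['a', 'b', 'c'], ['a', 'b'], ['a']]

# Unlike weird(), every yielded list is a distinct object.
print(len({id(x) for x in result}))  # 3
```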

First, we see a mutable default argument, which is a bad idea, but it is not actually the cause of the bug you are seeing:

def solver(Cols, Rows, SecondaryIDs, solution=[]):
    live=[col for col in Cols if col not in SecondaryIDs] 
    if not live:
        yield solution

Here you yield solution ^..

    else:
        col = min(live, key = lambda col: len(Cols[col]))
        for row in list(Cols[col]):
            solution.append(row)
            columns = select(Cols, Rows, row)
            for soln in solver(Cols, Rows, SecondaryIDs, solution):
                yield soln
            deselect(Cols, Rows, row, columns)
            solution.pop()

Here you mutate the same list you yielded earlier.
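
The same mechanism accounts for the difference the question noticed between looping over the generator and building a list from it. A consumer that inspects each result before the generator resumes sees the list in its correct state, while a consumer that only stores the yielded objects ends up with many references to one list that is empty once the search finishes. The sketch below uses a small hypothetical stand-in for `solver`, named `bit_strings`, which is not part of the original code; it also uses `partial=None` in place of a mutable default argument.

```python
def bit_strings(n, partial=None):
    # Hypothetical stand-in for solver(): a backtracking generator that
    # reuses one list for the partial solution.
    if partial is None:      # avoid a mutable default argument
        partial = []
    if len(partial) == n:
        yield partial        # yields the shared list itself (the bug)
    else:
        for bit in (0, 1):
            partial.append(bit)
            yield from bit_strings(n, partial)
            partial.pop()

# Consuming lazily: each result is copied before the generator resumes.
seen = [tuple(s) for s in bit_strings(2)]
print(seen)        # [(0, 0), (0, 1), (1, 0), (1, 1)]

# Storing the raw objects: every entry is the same, now empty, list.
collected = list(bit_strings(2))
print(collected)   # [[], [], [], []]
```

Changing `yield partial` to `yield list(partial)` makes both consumers agree, at the cost of one copy per solution.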
