使用元组时 Python 中的关键错误。但是当我打印时,我会在按键中看到它

问题描述 投票:0回答:1

我正在尝试为 Loopy Belief Propagation 编写算法。我正在使用 Numpy 和 pGMpy。目标是首先初始化从节点到因子的消息。然后在每次迭代中,您将计算节点消息的因子,然后将节点的消息更新为因子。

对于从节点到因子的消息(M_v_to_f)和从因子到节点的消息(M_f_to_v),我使用元组作为键。 M_v_to_f 会有 M_v_to_f[('x2', )]。一次迭代后,更新 M_v_to_f。

然而,在第二次迭代时,我遇到了一个关键的错误问题。所以我打印出据称会引发密钥错误的密钥,并在 M_v_to_f 中打印密钥。问题是我看到了一个匹配项,但我不知道为什么 Python 没有响应它。这表明我实际上可以看到一个密钥。

如果有帮助,这里也是代码:

import numpy as np
import copy
from pgmpy.models import FactorGraph
from pgmpy.factors.discrete import DiscreteFactor
from pgmpy.factors import factor_product
from pgmpy.readwrite import BIFReader


def make_debug_graph():
    
    G = FactorGraph()
    G.add_nodes_from(['x1', 'x2', 'x3', 'x4'])
    
    # add factors 
    phi1 = DiscreteFactor(['x1', 'x2'], [2, 3], np.array([0.5, 0.7, 0.2,
                                                          0.5, 0.3, 0.8]))
    phi2 = DiscreteFactor(['x2', 'x3', 'x4'], [3, 2, 2], np.array([0.2, 0.25, 0.70, 0.30,
                                                                   0.4, 0.25, 0.15, 0.65,
                                                                   0.4, 0.50, 0.15, 0.05]))
    phi3 = DiscreteFactor(['x3'], [2], np.array([0.5, 
                                                 0.5]))
    phi4 = DiscreteFactor(['x4'], [2], np.array([0.4, 
                                                 0.6]))
    G.add_factors(phi1, phi2, phi3, phi4)
    
    G.add_nodes_from([phi1, phi2, phi3, phi4])
    G.add_edges_from([('x1', phi1), ('x2', phi1), ('x2', phi2), ('x3', phi2), ('x4', phi2), ('x3', phi3), ('x4', phi4)])
    
    return G
G = make_debug_graph()
def _custom_reshape(arr, shape_len, axis):
    shape = tuple([1 if i != axis else arr.shape[0] for i in range(shape_len)])
    return np.reshape(arr, shape)
# initialize M_v_to_f
M_v_to_f = {}
for var in G.get_variable_nodes():
    for factor in G.neighbors(var):
        key = (var, factor)
        print(key)
        print(M_v_to_f)
        M_v_to_f[key] = np.ones(G.get_cardinality(var))

for epoch in range(10):
    print(epoch)
    M_f_to_v = {}
    for factor in G.get_factor_nodes():
        num_axis = len(factor.values.shape)
        for j, to_node in enumerate(factor.scope()):
            incoming_msg = []
            for k, in_node in enumerate(factor.scope()):
                if j==k: continue
                key = (in_node, factor) 
# Error on here on the second iteration.
                incoming_msg.append(_custom_reshape(M_v_to_f[key], num_axis, k))
            outgoing = factor.values
            for msg in incoming_msg:
                print(msg.shape)
                outgoing *= msg
            sum_axis = list(range(num_axis))
            sum_axis.remove(j)
            outgoing = np.sum(outgoing, axis = tuple(sum_axis))
            outgoing /= np.sum(outgoing)
            key = (factor, to_node)
            M_f_to_v[key] = outgoing
    # update the M_v_to_f
    for var in G.get_variable_nodes():
        for j, factor in enumerate(G.neighbors(var)):
            incoming_msg = []
            for k, in_fact in enumerate(G.neighbors(var)):
                if j == k: continue
                key = (in_fact, var)
                incoming_msg.append(M_f_to_v[key])
            
            if incoming_msg:
                outgoing = incoming_msg[0]
                for msg in incoming_msg[1:]:
                    outgoing *= msg
                outgoing /= np.sum(outgoing)
                key = (var,factor)
                M_v_to_f[key] = outgoing
            

enter image description here

我尝试了不同的方式来使用键(事先定义元组......等等)。但是,我真的不知道如何解决这个问题。

关于打印语句,可以看出关键是:

('x2', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>)

M_v_to_f 是:

{('x2', <DiscreteFactor representing phi(x1:2, x2:3) at 0x7f94f90db190>): array([0.3625, 0.3625, 0.275 ]), **('x2', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>)**: array([0.33333333, 0.33333333, 0.33333333]), ('x3', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>): array([0.5, 0.5]), ('x3', <DiscreteFactor representing phi(x3:2) at 0x7f94f90db1f0>): array([0.5, 0.5]), ('x1', <DiscreteFactor representing phi(x1:2, x2:3) at 0x7f94f90db190>): array([1., 1.]), ('x4', <DiscreteFactor representing phi(x2:3, x3:2, x4:2) at 0x7f94f90db0d0>): array([0.4, 0.6]), ('x4', <DiscreteFactor representing phi(x4:2) at 0x7f94f90db130>): array([0.5, 0.5])}
python numpy keyerror pgmpy
1个回答
2
投票

你正在改变你的字典键:

outgoing = factor.values
for msg in incoming_msg:
    print(msg.shape)
    outgoing *= msg

这打破了字典查找。

© www.soinside.com 2019 - 2024. All rights reserved.