Code debugging: how to implement the KL divergence between Generalized Dirichlet distributions in Python?


I am implementing this paper (link to the paper). I want to implement the KL divergence between two Generalized Dirichlet distributions, given in Equation 11 of the paper (see the screenshot below).

Alpha_1 and Beta_1 are estimated by the decoder network, for example:

import torch

decoder_alpha = torch.tensor([[0.2, 0.3, 0.3, 0.4],[0.2, 0.3, 0.3, 0.4],[0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4]])

decoder_beta = torch.tensor([[0.3, 0.6, 0.4, 0.8],[0.2, 0.3, 0.3, 0.4],[0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4]])

Alpha_2 and Beta_2 are the parameters of the prior:

prior_alpha = torch.tensor([[0.1, 0.1, 0.4, 0.1],[0.8, 0.7, 0.1, 0.4],[0.2, 0.8, 0.9, 0.1], [0.1, 0.5, 0.2, 0.4], [0.1, 0.2, 0.1, 0.4], [0.2, 0.1, 0.3, 0.3]])

prior_beta = torch.tensor([[0.7, 0.6, 0.1, 0.2],[0.5, 0.8, 0.1, 0.2],[0.2, 0.8, 0.5, 0.4], [0.2, 0.6, 0.1, 0.4], [0.6, 0.8, 0.3, 0.2], [0.2, 0.6, 0.3, 0.9]])
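
A side note on these parameters: every entry of decoder_alpha, decoder_beta, prior_alpha and prior_beta must be strictly positive, because torch.lgamma and torch.digamma diverge as their arguments approach zero. If the decoder's last layer produces unconstrained values, a common way to enforce positivity is a softplus. A minimal sketch, assuming a hypothetical decoder_network that maps an input x to an unconstrained (batch, 2K) tensor (both names are placeholders, not part of the code above):

    import torch.nn.functional as F

    raw = decoder_network(x)                      # placeholder: unconstrained decoder output, shape (batch, 2K)
    raw_alpha, raw_beta = raw.chunk(2, dim=-1)    # split into the alpha and beta halves
    decoder_alpha = F.softplus(raw_alpha) + 1e-6  # strictly positive Alpha_1
    decoder_beta = F.softplus(raw_beta) + 1e-6    # strictly positive Beta_1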

Here is my implementation:

    # per-component sums of the Beta parameters
    decoderParamSum = decoder_alpha + decoder_beta
    priorParamSum = prior_alpha + prior_beta
    alphaParamsDiff = decoder_alpha - prior_alpha

    # first term: log-ratio of the per-component Beta normalising constants
    numerator = torch.lgamma(decoderParamSum) + torch.lgamma(prior_alpha) + torch.lgamma(prior_beta)
    denominator = torch.lgamma(decoder_alpha) + torch.lgamma(decoder_beta) + torch.lgamma(priorParamSum)
    firstTerm = torch.sum(numerator - denominator, dim=1)

    # second term: digamma differences weighted by (Alpha_1 - Alpha_2)
    secondTerm = torch.sum(torch.digamma(decoderParamSum) - torch.digamma(decoder_beta), dim=1)
    secondTerm = torch.reshape(secondTerm, (decoder_alpha.shape[0], 1))  # one value per batch row
    secondTerm = torch.digamma(decoder_alpha) - torch.digamma(decoder_beta) - secondTerm
    secondTerm = torch.sum(alphaParamsDiff * secondTerm, dim=1)

    # third term: cumulative digamma differences, weighted by the difference of the gamma_k coefficients;
    # the cumulative sum keeps shape (batch, K) and is multiplied element-wise below
    thirdTerm = torch.cumsum(torch.digamma(decoderParamSum) - torch.digamma(decoder_beta), dim=1)
    v1 = torch.cat([decoder_beta[:, :-1] - decoder_alpha[:, 1:] - decoder_beta[:, 1:], decoder_beta[:, -1:] - 1], dim=-1)
    v2 = torch.cat([prior_beta[:, :-1] - prior_alpha[:, 1:] - prior_beta[:, 1:], prior_beta[:, -1:] - 1], dim=-1)
    thirdTerm = torch.sum((v1 - v2) * thirdTerm, dim=1)

    KLD = firstTerm - secondTerm + thirdTerm

However, the loss values I get are far from what I expect: they are large, negative, and unstable. So I suspect there is something wrong with my implementation.

Could anyone check my implementation of the KL divergence, or point me to an existing Python implementation (I have searched online but found nothing)? Thanks in advance.
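
For comparison, below is a minimal reference sketch of the closed-form KL divergence between two Generalized Dirichlet distributions. It uses the fact that a Generalized Dirichlet variable is a fixed, invertible transform of K independent Beta(alpha_k, beta_k) variables, so the KL reduces to a sum of per-component Beta-vs-Beta KL terms; whether this matches Equation 11 exactly is an assumption here, and the name gd_kl_divergence is only a placeholder:

    def gd_kl_divergence(a1, b1, a2, b2):
        """KL( GD(a1, b1) || GD(a2, b2) ), one value per row.

        All four tensors have shape (batch, K) and must be strictly positive.
        """
        # log Beta function, log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b), per component
        log_norm1 = torch.lgamma(a1) + torch.lgamma(b1) - torch.lgamma(a1 + b1)
        log_norm2 = torch.lgamma(a2) + torch.lgamma(b2) - torch.lgamma(a2 + b2)
        # expected sufficient statistics under GD(a1, b1)
        dg_sum = torch.digamma(a1 + b1)
        kl = (log_norm2 - log_norm1
              + (a1 - a2) * (torch.digamma(a1) - dg_sum)
              + (b1 - b2) * (torch.digamma(b1) - dg_sum))
        return kl.sum(dim=1)  # shape (batch,), non-negative

    KLD_ref = gd_kl_divergence(decoder_alpha, decoder_beta, prior_alpha, prior_beta)

    # cross-check against PyTorch's built-in Beta-vs-Beta KL
    from torch.distributions import Beta, kl_divergence
    KLD_check = kl_divergence(Beta(decoder_alpha, decoder_beta),
                              Beta(prior_alpha, prior_beta)).sum(dim=1)
    print(torch.allclose(KLD_ref, KLD_check))  # expected: True

Since a KL divergence is non-negative by definition, large negative values point to a sign or indexing error in the three terms rather than a purely numerical problem.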

python pytorch distribution autoencoder