我正在写这篇论文链接到论文。我想在本文的等式 11 中实现两个广义 Dirichlet 分布之间的 KL-Divergence,请参见下面的屏幕截图:
Alpha_1 和 Beta_1 是从解码器网络中估计出来的,例如:
decoder_alpha = torch.tensor([[0.2, 0.3, 0.3, 0.4],[0.2, 0.3, 0.3, 0.4],[0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4]])
decoder_beta = torch.tensor([[0.3, 0.6, 0.4, 0.8],[0.2, 0.3, 0.3, 0.4],[0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4], [0.2, 0.3, 0.3, 0.4]])
此外,Alpha_2 和 Beta_2 是先验的:
prior_alpha = torch.tensor([[0.1, 0.1, 0.4, 0.1],[0.8, 0.7, 0.1, 0.4],[0.2, 0.8, 0.9, 0.1], [0.1, 0.5, 0.2, 0.4], [0.1, 0.2, 0.1, 0.4], [0.2, 0.1, 0.3, 0.3]])
prior_beta = torch.tensor([[0.7, 0.6, 0.1, 0.2],[0.5, 0.8, 0.1, 0.2],[0.2, 0.8, 0.5, 0.4], [0.2, 0.6, 0.1, 0.4], [0.6, 0.8, 0.3, 0.2], [0.2, 0.6, 0.3, 0.9]])
这是我的实现:
decoderParamSum = decoder_alpha + decoder_beta
priorParamSum = prior_alpha + prior_beta
alphaParamsDiff = decoder_alpha - prior_alpha
numerator = torch.lgamma(decoderParamSum) + torch.lgamma(prior_alpha) + torch.lgamma(prior_beta)
denomirator = torch.lgamma(decoder_alpha) + torch.lgamma(decoder_beta) + torch.lgamma(priorParamSum)
firstTerm = torch.sum((numerator - denomirator),dim=1)
secondTerm = torch.sum((torch.digamma(decoderParamSum)-torch.digamma(decoder_beta)), dim=1)
secondTerm = torch.reshape(secondTerm, (input.shape[0], 1))
secondTerm = torch.digamma(decoder_alpha) - torch.digamma(decoder_beta) - secondTerm
secondTerm = torch.sum(torch.multiply(alphaParamsDiff, secondTerm), dim=1)
thirdTerm = torch.cumsum((torch.digamma(decoderParamSum)-torch.digamma(decoder_beta)), dim=1)
thirdTerm = torch.reshape(thirdTerm,(input.shape[0], 1))
v1 = torch.cat([decoder_beta[:,:-1] -decoder_alpha[:, 1:] - decoder_beta[:, 1:], decoder_beta[:, -1:] - 1], dim=-1)
v2 = torch.cat([prior_beta[:,:-1] -prior_alpha[:, 1:] - prior_beta[:, 1:], prior_beta[:, -1:] - 1], dim=-1)
thirdTerm = torch.sum((torch.multiply((v1-v2), thirdTerm)), dim=1)
KLD = firstTerm - secondTerm + thirdTerm
然而,我得到的损失值与我的预期相去甚远,负值很大,而且不稳定。所以,我猜我的实现有问题。
任何人都可以检查一下我对 KL 散度的实现,或者是否已经存在 python 实现(我已经上网,但没有任何结果)。提前谢谢你。