关于计算log_sum_exp的一个代码段

问题描述 投票:0回答:1

在此tutorial on using Pytorch to implement BiLSTM-CRF中,作者实现了以下功能。具体而言,我不太了解max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])试图做什么?或者它对应于哪种数学公式?

# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + \
        torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
deep-learning pytorch lstm
1个回答
0
投票

查看代码,似乎vec的形状为(1, n)。现在我们可以逐行遵循代码:

max_score = vec[0, argmax(vec)]

在位置vec中使用0, argmax(v)仅是获取vec最大值的一种理想方法。因此,max_score是(顾名思义)vec的最大值。

max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])

接下来,我们要从max_score的每个元素中减去vec。为此,代码创建一个与shape相同的vec的向量,并且所有元素都等于max_score。首先,使用max_score命令将view整形为二维,然后使用view命令将扩展的2d向量“拉伸”为长度n

最后,对数总和exp进行了稳健的计算:

expand

此计算的有效性可以在这张图片中看到:expand

其基本原理是 return max_score + \ torch.log(torch.sum(torch.exp(vec - max_score_broadcast))) 可以“爆炸” enter image description here,因此,为了保持数值稳定性,最好减去最大值之前

exp(x)
© www.soinside.com 2019 - 2024. All rights reserved.