How do I sample from a log-probability distribution?


I have some code that works with log-probabilities. When I want to draw a sample from the probability distribution, I use

import numpy as np

probs = np.exp(logprobs)
probs /= probs.sum()
sample = np.random.choice(X, p=probs, size=1)[0]

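But the exponentiation and the division here add some overhead, and numpy's random.choice function requires probabilities that lie between 0 and 1 and sum to 1.

Is there a quick trick that lets me sample directly from an array of unnormalized log-probabilities? I only need one sample at a time, and the frequency with which each item is drawn only needs to be proportional to its exponentiated log-probability.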

python numpy statistics probability sampling
1 answer

Use the Gumbel-max trick. See this answer on Cross Validated for more explanation and references. Here is a minimal code example:

import numpy as np

# Assume we only have log-probabilities (for sampling, even logits will do)
log_probs = np.log([0.1, 0.2, 0.3, 0.4])
num_categories = len(log_probs)

# Sample a single category
gumbels = np.random.gumbel(size=num_categories)
sample = np.argmax(log_probs + gumbels)
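The trick relies on the fact that adding independent standard Gumbel noise to each log-probability and taking the argmax yields index k with probability exp(l_k) / sum_j exp(l_j). As a sketch, here is the same step written with NumPy's newer `Generator` API, plus an equivalent construction of the Gumbel noise from uniform draws. The seed and the variable names are illustrative assumptions, not part of the original example.

```python
import numpy as np

# Seeded generator so the draw is reproducible (the seed value is arbitrary)
rng = np.random.default_rng(0)

log_probs = np.log([0.1, 0.2, 0.3, 0.4])

# One standard Gumbel(0, 1) draw per category, then take the argmax
gumbels = rng.gumbel(size=log_probs.shape)
sample = int(np.argmax(log_probs + gumbels))

# Equivalent noise built from uniforms: if U ~ Uniform(0, 1),
# then -log(-log(U)) follows a standard Gumbel distribution
uniforms = rng.uniform(size=log_probs.shape)
gumbels_from_uniforms = -np.log(-np.log(uniforms))
sample_from_uniforms = int(np.argmax(log_probs + gumbels_from_uniforms))

print(sample, sample_from_uniforms)
```

Either way, no exponentiation or normalization of `log_probs` is needed. The cost moves into generating one Gumbel variate per category.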

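Vectorized implementation

Note that this function does not currently have the same interface as np.random.choice. It can only sample with replacement, and it returns only indices.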

from typing import Union

import numpy as np


def random_choice_log_space(
    logits: np.ndarray,
    size: int = 1,
    random_state: Union[np.random.RandomState, int, None] = None,
) -> np.ndarray:
    """
    Sample (with replacement) from a categorical distribution parametrized by logits or
    log-probabilities.

    Parameters
    ----------
    logits : np.ndarray
        the last dimension contains log-probabilities (e.g., out of a log-softmax
        function) or unnormalized logits corresponding to the categorical
        distribution(s)
    size : int, optional
        sample size, by default 1
    random_state : Union[np.random.RandomState, int], optional
        ``np.random.RandomState`` object or an integer seed, by default None

    Returns
    -------
    np.ndarray
        sampled indexes

    Raises
    ------
    ValueError
        if `size` is less than 1
    """
    if size < 1:
        raise ValueError("size must be at least 1.")
    # Independently sample as many Gumbels as needed. During addition, they'll be
    # broadcasted
    _gumbels_shape = (size,) + logits.shape if size > 1 else logits.shape
    # Create a RandomState if needed
    if not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(seed=random_state)
    gumbels = random_state.gumbel(size=_gumbels_shape)
    gumbels_rescaled: np.ndarray = logits + gumbels
    return gumbels_rescaled.argmax(axis=-1)

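As the docstring says, you can pass either log-probabilities or unnormalized logits as the logits input. That's because the two differ only by an additive constant (specifically, the log-sum-exp of the logits), and a constant shift doesn't change the argmax.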
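The shift invariance can be checked directly. The sketch below adds an arbitrary constant to a set of log-probabilities, applies the same Gumbel noise to both versions, and compares the selected indices. It also recovers the log-probabilities by subtracting the log-sum-exp. The shift value `5.0` and the variable names are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

probs = np.array([0.1, 0.2, 0.3, 0.4])
log_probs = np.log(probs)
logits = log_probs + 5.0  # arbitrary constant shift: unnormalized logits

# Use the same Gumbel noise for both inputs so the draws are comparable
gumbels = rng.gumbel(size=(1000, probs.size))
idx_from_log_probs = np.argmax(log_probs + gumbels, axis=-1)
idx_from_logits = np.argmax(logits + gumbels, axis=-1)

# The constant shifts every category equally, so the argmax is unchanged
same_indices = bool(np.array_equal(idx_from_log_probs, idx_from_logits))

# Subtracting log-sum-exp(logits) recovers the normalized log-probabilities
recovered_log_probs = logits - np.log(np.sum(np.exp(logits)))

print(same_indices)
print(np.allclose(recovered_log_probs, log_probs))
```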

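Quick and dirty statistical check

For random_choice_log_space to be correct, it needs to sample independently from the probability distribution implied by logits. The independence part is already clear, because fresh Gumbel noise is drawn for every sample. So we only need to compare the empirical distribution of the samples with the true distribution.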

import numpy as np
import pandas as pd
from scipy.special import logsumexp, softmax


_probs = np.array([0.1, 0.2, 0.3, 0.4])

log_probs = np.log(_probs)
logits = np.log(_probs) + logsumexp(_probs, axis=-1)
# You start out with access to log_probs or logits

num_categories = len(_probs)

sample_size = 500_000
seed = 123
random_state = np.random.RandomState(seed)


# helper function
def empirical_distr(discrete_samples):
    return (pd.Series(discrete_samples)
            .value_counts(normalize=True)
            .sort_index()
            .to_numpy())


# np.random.choice (select one at a time) AKA vanilla sampling
def random_choice_log_space_vanilla(logits, size, random_state=None):
    probs = softmax(logits, axis=-1)
    if not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(seed=random_state)
    return random_state.choice(len(probs), p=probs, size=size, replace=True)

samples = random_choice_log_space_vanilla(logits, size=sample_size, random_state=random_state)
distr_vanilla = empirical_distr(samples)


# random_choice_log_space for log-probabilities input
samples = random_choice_log_space(log_probs, size=sample_size, random_state=random_state)
distr_log_probs = empirical_distr(samples)


# random_choice_log_space for logits input
samples = random_choice_log_space(logits, size=sample_size, random_state=random_state)
distr_logits = empirical_distr(samples)
print(pd.DataFrame({'rel error (vanilla)': (distr_vanilla - _probs)/_probs,
                    'rel error (log-probs)': (distr_log_probs - _probs)/_probs,
                    'rel error (logits)': (distr_logits - _probs)/_probs},
                   index=pd.Index(range(num_categories), name='category')))
          rel error (vanilla)  rel error (log-probs)  rel error (logits)
category                                                                
0                   -0.005760               0.000560           -0.002960
1                    0.004730               0.000380           -0.002560
2                   -0.002367              -0.000567            0.004273
3                    0.000850               0.000095           -0.001185
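Relative errors are informal. A chi-square goodness-of-fit test is one way to put a number on the agreement. The sketch below is self-contained: it re-implements the Gumbel-max draw inline rather than calling the function above, and the seed, sample size, and logit shift are arbitrary assumptions.

```python
import numpy as np
from scipy.stats import chisquare

rng = np.random.default_rng(123)

probs = np.array([0.1, 0.2, 0.3, 0.4])
logits = np.log(probs) + 2.5  # unnormalized logits (arbitrary shift)
sample_size = 100_000

# Gumbel-max sampling: one row of fresh Gumbel noise per draw
gumbels = rng.gumbel(size=(sample_size, probs.size))
samples = np.argmax(logits + gumbels, axis=-1)

# Observed counts per category versus the counts expected under probs
observed = np.bincount(samples, minlength=probs.size)
expected = sample_size * probs

statistic, p_value = chisquare(f_obs=observed, f_exp=expected)
print(f"chi-square statistic = {statistic:.3f}, p-value = {p_value:.3f}")
```

A small p-value would suggest that the sampler does not match the target distribution. A large one only means the test found no evidence of a mismatch at this sample size.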

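Efficiency

If random_choice_log_space were always slower than softmaxing and calling np.random.choice, none of this work would matter. Fortunately, there is a common enough problem where it is not. It's the one in your question: you have logits, and you want to sample a single element.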

from time import perf_counter

from scipy.stats import trim_mean


def time_func(func, *args, num_replications: int=50, **kwargs) -> list[float]:
    '''
    Returns a list, `times`, where `times[i]` is the time it took to run
    `func(*args, **kwargs)` at replication `i` for `i in range(num_replications)`.
    '''
    times = []
    for _ in range(num_replications):
        # perf_counter is a high-resolution clock intended for short timings
        time_start = perf_counter()
        _ = func(*args, **kwargs)
        time_end = perf_counter()
        times.append(time_end - time_start)
    return times


category_sizes = np.power(2, np.arange(1, 14+1))
num_replications = 100

times_vanilla = []
times_gumbel = []
for size in category_sizes:
    logits = np.random.normal(size=size)
    times_vanilla.append(time_func(random_choice_log_space_vanilla, logits, size=1,
                                   num_replications=num_replications))
    times_gumbel.append(time_func(random_choice_log_space, logits, size=1,
                                  num_replications=num_replications))


(pd.DataFrame({'vanilla': trim_mean(times_vanilla, 0.1, axis=1),
               'Gumbel': trim_mean(times_gumbel, 0.1, axis=1)},
              index=pd.Index(category_sizes, name='# categories'))
 .plot.bar(title='Categorical sampling',
           figsize=(8,5),
           ylabel='mean wall-clock time (sec)'));

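The plot (I ran it several times) shows that the comparison becomes unstable for larger numbers of categories.

Note: precomputing Gumbel samples and reusing them across draws would speed up the trick considerably. But I'm worried that the drawn samples would then be dependent through the shared Gumbel samples. (It is not enough for the Gumbel samples to be independent of the logits data, which is what the linked comment says.) I'll look into this idea further and update here.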
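The dependence concern is easiest to see in the extreme case where the logits stay fixed. The sketch below reuses one precomputed Gumbel vector for every draw and compares it with drawing fresh noise each time. It only illustrates the worry under that assumption and does not settle what happens when the logits change between draws. The names and the seed are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

logits = np.log([0.1, 0.2, 0.3, 0.4])
num_draws = 1000

# Reusing a single precomputed Gumbel vector: with fixed logits, every
# "draw" is the argmax of the same array, so all draws are identical
shared_gumbels = rng.gumbel(size=logits.shape)
reused = np.array([np.argmax(logits + shared_gumbels) for _ in range(num_draws)])

# Fresh Gumbel noise per draw: the draws are independent and vary
fresh = np.argmax(logits + rng.gumbel(size=(num_draws, logits.size)), axis=-1)

print("distinct indices with reused Gumbels:", np.unique(reused).size)
print("distinct indices with fresh Gumbels:", np.unique(fresh).size)
```

With reused noise the draws are perfectly correlated, so any speedup from precomputation comes at the cost of independent sampling, at least in this fixed-logits setting.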
