贝叶斯 MMM 用 Pytensor 代替 Theano

问题描述 投票:0回答:1

我有一段使用 Theano 执行几何广告库存衰减的代码。 这是一段旧代码,我需要使用最新版本的 PyTensor 来更新它。 有人可以帮我转换一下吗?

def adstock_geometric_theano_pymc3(x, theta):
    x = tt.as_tensor_variable(x)
    
    def adstock_geometric_recurrence_theano(index, 
                                            input_x, 
                                            decay_x,   
                                            theta):
        return tt.set_subtensor(decay_x[index], 
               tt.sum(input_x + theta * decay_x[index - 1]))
    len_observed = x.shape[0]
    x_decayed = tt.zeros_like(x)
    x_decayed = tt.set_subtensor(x_decayed[0], x[0])
    output, _ = theano.scan(
        fn = adstock_geometric_recurrence_theano, 
        sequences = [tt.arange(1, len_observed), x[1:len_observed]], 
        outputs_info = x_decayed,
        non_sequences = theta, 
        n_steps = len_observed - 1
    )
    
    return output[-1]
theano pymc
1个回答
0
投票

首先我将分享转换后的代码,然后解释一切如何以及为何有效:

假设

x
是随时间变化的一系列广告支出
theta
是衰减率。我将使用一小组数字表示 x 和 theta 的假设值。

示例数据:

x
:10个时间段内的广告支出, 例如,
[100, 120, 90, 110, 95, 105, 115, 100, 130, 125]
theta
: 衰减率,比方说
0.5

import torch

def adstock_geometric_pytensor(x, theta):
    x = torch.tensor(x, dtype=torch.float32)
    theta = torch.tensor(theta, dtype=torch.float32)

    def adstock_geometric_recurrence_pytensor(index, input_x, decay_x, theta):
        decay_x[index] = input_x + theta * decay_x[index - 1]
        return decay_x

    len_observed = x.shape[0]
    x_decayed = torch.zeros_like(x)
    x_decayed[0] = x[0]

    for index in range(1, len_observed):
        x_decayed = adstock_geometric_recurrence_pytensor(index, x[index], x_decayed, theta)

    return x_decayed

# Example usage
x_data = [100, 120, 90, 110, 95, 105, 115, 100, 130, 125] # Advertising expenditures
theta_value = 0.5 # Decay rate
output = adstock_geometric_pytensor(x_data, theta_value)
print(output)

原始代码问题

您的原始 Theano 代码使用

theano.scan
,这是一个强大的工具,用于以针对并行计算优化的方式循环序列。这是
Theano
中有效处理递归操作的常用方法。但是,当切换到
PyTorch
(
PyTensor
) 时,没有与
theano.scan
直接等效的内容。 PyTorch 倾向于支持 Python 中的显式循环,这种循环优化程度较低,但更简单。

修改代码说明

在修改后的

PyTorch
代码中,我用标准 Python for 循环替换了
theano.scan
。此循环迭代地应用 adstock 转换。此更改牺牲了一些计算效率,但保留了核心功能。

© www.soinside.com 2019 - 2024. All rights reserved.