如何在使用自定义对数似然的 PyMC V5 模型上应用 Blackjax 采样器

问题描述 投票:0回答:1

我正在使用

PyMC v5
在模型中执行哈密顿蒙特卡罗。我可以运行下面的代码,但即使有多个核心,它也非常慢。我有一个函数
applyMCMC
用于此目的:

# define a pytensor Op for our likelihood function
import pytensor.tensor as pt
import scipy.optimize
import numpy as np
from scipy.optimize import approx_fprime

# define a pytensor Op for our likelihood function
class LogLikeWithGrad(pt.Op):
    """Pytensor Op wrapping a black-box Python log-likelihood.

    The Op takes one vector of parameter values and returns a scalar
    log-likelihood.  Its gradient is delegated to :class:`LogLikeGrad`, which
    differentiates the *same* callable numerically.

    Parameters
    ----------
    loglike : callable
        Function called as ``loglike(theta)`` with a 1-D parameter array and
        returning a scalar.
    """

    itypes = [pt.dvector]  # expects a vector of parameter values when called
    otypes = [pt.dscalar]  # outputs a single scalar value (the log likelihood)

    def __init__(self, loglike):
        # Keep the callable so perform() can evaluate it.
        self.likelihood = loglike

        # The gradient Op must differentiate the same function that perform()
        # evaluates, so hand it the callable instead of relying on a global.
        self.loglike_grad = LogLikeGrad(loglike)

    def perform(self, node, inputs, outputs):
        """Evaluate the wrapped log-likelihood for the given parameter vector."""
        (theta,) = inputs
        logl = self.likelihood(theta)
        # otypes declares a float64 scalar; force the dtype so an int or
        # float32 return value cannot violate the declared output type.
        outputs[0][0] = np.asarray(logl, dtype=np.float64)

    def grad(self, inputs, grad_outputs):
        """Return the vector-Jacobian product using the numerical gradient Op."""
        (theta,) = inputs
        grads = self.loglike_grad(theta)
        return [grad_outputs[0] * grads]


class LogLikeGrad(pt.Op):
    """
    This Op will be called with a vector of values and also return a vector of
    values - the gradients in each dimension.

    Parameters
    ----------
    loglike : callable, optional
        Function to differentiate.  When omitted, the module-level
        ``applyMCMC`` is looked up at call time, which keeps the previous
        zero-argument construction ``LogLikeGrad()`` working.
    """

    itypes = [pt.dvector]
    otypes = [pt.dvector]

    def __init__(self, loglike=None):
        self.likelihood = loglike

    def perform(self, node, inputs, outputs):
        """Compute a forward finite-difference gradient at ``theta``."""
        (theta,) = inputs
        func = self.likelihood if self.likelihood is not None else applyMCMC
        # approx_fprime costs one extra likelihood evaluation per dimension.
        grads = approx_fprime(theta, func, epsilon=1e-8)
        outputs[0][0] = np.asarray(grads, dtype=np.float64)

#############################################

import sys
import pymc as pm
import pytensor
import arviz as az
import nutpie
from utils_scipy_GRADIENT import *
import os,sys
from getdist import plots, MCSamples
import getdist
import matplotlib.pyplot as plt
import numpy as np
# Names of the sampled parameters; the order must match the columns of
# lower_boundaries / upper_boundaries and the unpacking order in applyMCMC.
param_names = ["Omega_m", "Omega_k", "H0", "Psi_0", "dPsi_0_dt", "omega_BD"]
logl = LogLikeWithGrad(applyMCMC)

# Allocate variable for model
model = pm.Model()

# NOTE(review): these starting values are defined but never passed to the
# sampler in this script - confirm whether they should be given as `initvals`.
initial_values = {
    'Omega_m': 0.3,
    'Omega_k': 1e-3,
    'H0' : 67.4,
    'Psi_0' : 1.0,
    'dPsi_0_dt' : 0.5e-3,
    'omega_BD' : 0.5e5,
}

if __name__ == '__main__':
    # Define model
    with model:
        # One flat prior per parameter, bounded by the limits from the utils module.
        for name, lower, upper in zip(
            param_names, lower_boundaries[0], upper_boundaries[0]
        ):
            pm.Uniform(name, lower=lower, upper=upper)

        # Stack the scalar random variables into the vector the Op expects.
        theta = pt.as_tensor_variable([model[param] for param in param_names])

        # Add the black-box log-likelihood to the model's log-probability.
        pm.Potential("likelihood", logl(theta))

        niter = 1000
        # The former `start = pm.find_MAP()` call was removed: its result was
        # never used, yet it ran a full optimisation over the expensive
        # likelihood before sampling began.
        step = pm.NUTS()  # Hamiltonian MCMC with No U-Turn Sampler
        trace = pm.sample(
            draws=niter,
            step=step,
            tune=500,
            cores=64,
            init="jitter+adapt_diag",
            progressbar=True,
        )
        print(pm.summary(trace).to_string())
        trace_df = trace.to_dataframe()

其中导入的源文件 utils_scipy_GRADIENT.py 内容如下:

    import sys
    import math
    import numpy as np
    from hi_classy import Class  # Importing Hi-CLASS
    import clik  # For Planck's likelihood
    import pandas as pd
    from scipy.interpolate import CubicSpline
    from scipy.linalg import pinvh
    # Initializing Planck's likelihood
    # The clik object is created once at import time and reused for every
    # likelihood evaluation.
    path_to_planck_likelihood = "../..//baseline/plc_3.0/hi_l/plik_lite/plik_lite_v22_TT.clik"
    planck_likelihood = clik.clik(path_to_planck_likelihood)
    # First entry of get_lmax(); used below as CLASS's l_max_scalars and, in
    # log_likelihood, as the expected spectrum length minus one.
    lmax = planck_likelihood.get_lmax()[0]
    # Value appended after the TT spectrum in the vector given to planck_likelihood.
    A_planck = 1.0
    
    # Limits of range of parameters respectively
    # Shape (1, 6); the driver script indexes row 0 positionally, so the column
    # order must follow its param_names list.
    lower_boundaries = np.array([[0.05, -1e-2, 50.0, 0.9, 0.0, 4e4]])
    upper_boundaries = np.array([[0.35, 1e-2, 80.0, 1.1, 1e-2, 1e5]])
    
    
    # Base settings handed to Class().set() in get_spectrum.  get_spectrum
    # overwrites "parameters_smg" and "omega_cdm" and adds "H0" on every call,
    # so the values given here for those keys only act as placeholders.
    params = {
        "Omega_Lambda": "0.0",
        "Omega_fld": "0.0",
        "Omega_smg": "-1.0",
        "gravity_model": "brans dicke",
        "parameters_smg": "0.0,   800,       1.0,        1e-3",
        "M_pl_today_smg": "1.0",
        "a_min_stability_test_smg": "1e-6",
        "root": "output/brans_dicke_",
        "output": "tCl, pCl, lCl, mPk",
        "lensing": "yes",
        "l_max_scalars": str(lmax),
        "output_background_smg": "10",
        "write parameters": "no",
        "write background": "no",
        "write thermodynamics": "no",
        "input_verbose": "0",
        "background_verbose": "0",
        "output_verbose": "0",
        "thermodynamics_verbose": "0",
        "perturbations_verbose": "0",
        "spectra_verbose": "0",
        "omega_b": "0.022032",
        "omega_cdm": "0.12038",
    }
    
    ####      Solve ordinary differential equation        ####
    
    
    # Parameters
    # Redshift grids built once at import time.
    # NOTE(review): none of the z_line_* arrays is referenced by the code shown
    # here (RK4Method is called with the separate `zLine` grid) - confirm they
    # are still needed before keeping three large arrays alive.
    z0 = 0
    z_past = 100 - 1  # a_min = 1 / 100
    z_future = 1 / 4 - 1  # a_max = 4
    n = 100000
    z_line_past = np.linspace(z0, z_past, num=n)
    z_line_future = np.linspace(z0, z_future, num=n)
    z_line_all = np.linspace(z_past, z_future, num=2 * n)
    
    def dH(Rho_m, Phi, u, omega_BD, Omega_k, z):
        """Evaluate the closed-form expression used as dH/dz by the integrator.

        Parameters
        ----------
        Rho_m, Phi, u : float
            Current state values (matter density, scalar field and its
            derivative, as named by the caller).
        omega_BD, Omega_k : float
            Model parameters.
        z : float
            Redshift at which the expression is evaluated.

        Returns
        -------
        float or None
            The derivative value, or ``None`` when ``val`` (presumably H**2) is
            not strictly positive, in which case its square root cannot be used
            as a divisor.  Callers treat ``None`` as "integration failed".
        """
        val = (-16 * math.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi) / (
            6 * (1 + z) * u + ((1 + z) ** 2 * omega_BD * u**2) / Phi - 6 * Phi
        )
        # sqrt(val) sits in the denominator below, so val == 0 must be rejected
        # as well as negative values; `not val > 0` also rejects NaN.
        if not val > 0:
            return None

        # du() is pure, so evaluate it once instead of twice per call.
        du_val = du(Rho_m, Phi, u, omega_BD, Omega_k, z)

        return -(
            (
                (1 + z)
                * (16 * math.pi * Rho_m + 6 * (1 + z) ** 2 * Omega_k * Phi)
                * (
                    (1 + z) * omega_BD * u**3
                    - 2
                    * omega_BD
                    * u
                    * ((1 + z) * du_val + u)
                    * Phi
                    - 6 * du_val * Phi**2
                )
                + (
                    6
                    * Phi
                    * (
                        -8 * math.pi * Rho_m
                        + (1 + z) ** 2 * Omega_k * ((1 + z) * u + 2 * Phi)
                    )
                    * (6 * Phi**2 - (1 + z) * u * ((1 + z) * omega_BD * u + 6 * Phi))
                )
                / (1 + z)
            )
            / (
                2
                * math.sqrt(val)
                * (
                    (1 + z) ** 2 * omega_BD * u**2
                    + 6 * (1 + z) * u * Phi
                    - 6 * Phi**2
                )
                ** 2
            )
        )
    
    
    def d_Rho_m(Rho_m, Phi, u, z):
        """Return 3 / (1 + z) * Rho_m; ``Phi`` and ``u`` are accepted but unused."""
        return 3 / (1 + z) * Rho_m

    def d_Phi(Rho_m, Phi, u, z):
        """Return ``u`` unchanged; the other arguments are accepted but unused."""
        return u
    
    
    def du(Rho_m, Phi, u, omega_BD, Omega_k, z):
        """Evaluate the closed-form expression used as the derivative of ``u``.

        The numerator is assembled from four named terms that are combined in
        the same order as a single flat expression would be; the result is the
        numerator divided by a common denominator.
        """
        # Term independent of u.
        term_const = 24 * math.pi * Rho_m * Phi**3

        # Term linear in u.
        term_linear = (
            (1 + z)
            * u
            * Phi**2
            * (
                8 * math.pi * (-3 + omega_BD) * Rho_m
                - 3 * (1 + z) ** 2 * (3 + 2 * omega_BD) * Omega_k * Phi
            )
        )

        # Term quadratic in u.
        term_quadratic = (
            3
            * (1 + z) ** 2
            * u**2
            * Phi
            * (
                -4 * math.pi * omega_BD * Rho_m
                + (1 + z) ** 4 * (3 + 2 * omega_BD) * Omega_k * Phi
            )
        )

        # Term cubic in u.
        term_cubic = (
            omega_BD
            * u**3
            * (
                4 * math.pi * (1 + z) ** 3 * (1 + omega_BD) * Rho_m
                + (1 + z) ** 5 * (3 + 2 * omega_BD) * Omega_k * Phi
            )
        )

        numerator = term_const + term_linear - term_quadratic - term_cubic
        denominator = (
            (1 + z) ** 2
            * (3 + 2 * omega_BD)
            * Phi**2
            * (8 * math.pi * Rho_m + 3 * (1 + z) ** 2 * Omega_k * Phi)
        )
        return numerator / denominator
    
    def dzeta(H):
        """Return the reciprocal of ``H`` (the integrand accumulated as zeta)."""
        inverse_H = 1 / H
        return inverse_H
    
    
    def RK4Method(Omega_m, Omega_k, H0, Phi_0, dPhi_0, omega_BD, zLine):
        """Integrate H, Rho_m, Phi, u and zeta over ``zLine`` with classic RK4.

        A starting node at z = 1e-5 is prepended to ``zLine``; the state there
        is H = H0, zeta = 0, Phi = Phi_0, u = dPhi_0 and
        Rho_m = 3 * H0**2 * Phi_0 * Omega_m / (8 * pi).

        Returns
        -------
        (Htable, zeta_table)
            Arrays of H and zeta at the nodes of the supplied ``zLine`` (the
            prepended starting node is dropped), or ``(None, None)`` as soon as
            ``dH`` returns ``None`` at any Runge-Kutta stage.
        """

        def slopes(rho, phi, vel, z):
            # Derivatives of (H, Phi, Rho_m, u) at one stage, evaluated in the
            # order dH, d_Phi, d_Rho_m, du.
            return (
                dH(rho, phi, vel, omega_BD, Omega_k, z),
                d_Phi(rho, phi, vel, z),
                d_Rho_m(rho, phi, vel, z),
                du(rho, phi, vel, omega_BD, Omega_k, z),
            )

        zLine = np.concatenate(([1e-5], zLine))
        z_length = len(zLine)

        Htable = np.zeros(z_length)
        zeta_table = np.zeros(z_length)

        # Initial state at the prepended node.
        Hval = H0
        zeta = 0.0
        Htable[0] = Hval
        zeta_table[0] = zeta

        Rho_m = 3 * H0 * H0 * Phi_0 * Omega_m / (8 * math.pi)
        Phi = Phi_0
        u = dPhi_0

        for i in range(1, z_length):
            z_start = zLine[i - 1]
            z_end = zLine[i]
            h = z_end - z_start
            half = h / 2
            z_mid = z_start + half

            # Stage 1: slopes at the start of the step.
            H_k1, Phi_k1, Rho_m_k1, u_k1 = slopes(Rho_m, Phi, u, z_start)
            zeta_k1 = dzeta(Hval)
            if H_k1 is None:
                return None, None

            # Stage 2: midpoint, using the stage-1 slopes.
            H_k2, Phi_k2, Rho_m_k2, u_k2 = slopes(
                Rho_m + half * Rho_m_k1,
                Phi + half * Phi_k1,
                u + half * u_k1,
                z_mid,
            )
            zeta_k2 = dzeta(Hval + half * H_k1)
            if H_k2 is None:
                return None, None

            # Stage 3: midpoint again, using the stage-2 slopes.
            H_k3, Phi_k3, Rho_m_k3, u_k3 = slopes(
                Rho_m + half * Rho_m_k2,
                Phi + half * Phi_k2,
                u + half * u_k2,
                z_mid,
            )
            zeta_k3 = dzeta(Hval + half * H_k2)
            if H_k3 is None:
                return None, None

            # Stage 4: end of the step, using the stage-3 slopes.
            H_k4, Phi_k4, Rho_m_k4, u_k4 = slopes(
                Rho_m + h * Rho_m_k3,
                Phi + h * Phi_k3,
                u + h * u_k3,
                z_end,
            )
            zeta_k4 = dzeta(Hval + h * H_k3)
            if H_k4 is None:
                return None, None

            # Weighted RK4 combination for every state variable.
            Hval = Hval + h * (H_k1 + 2 * H_k2 + 2 * H_k3 + H_k4) / 6
            Rho_m = Rho_m + h * (Rho_m_k1 + 2 * Rho_m_k2 + 2 * Rho_m_k3 + Rho_m_k4) / 6
            Phi = Phi + h * (Phi_k1 + 2 * Phi_k2 + 2 * Phi_k3 + Phi_k4) / 6
            u = u + h * (u_k1 + 2 * u_k2 + 2 * u_k3 + u_k4) / 6
            zeta = zeta + h * (zeta_k1 + 2 * zeta_k2 + 2 * zeta_k3 + zeta_k4) / 6

            Htable[i] = Hval
            zeta_table[i] = zeta

        # Drop the prepended starting node so outputs align with the input grid.
        return Htable[1:], zeta_table[1:]
    
###########################    End  Solve ordinary differential equation     
    
#############################################
    def get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt):
        """Run CLASS for one parameter point and return the lensed TT spectrum.

        The module-level ``params`` dict is updated in place with the
        point-specific entries ("parameters_smg", "H0", "omega_cdm") and then
        passed to a fresh ``Class`` instance.

        Returns
        -------
        The ``"tt"`` entry of ``lensed_cl(lmax)`` multiplied by
        ``10**12 * 2.7255**2``.

        NOTE(review): ``Omega_k`` is accepted but never written into ``params``
        here - confirm whether CLASS should receive the curvature.
        """
        # repr(float(...)) round-trips the full double.  The previous ".5f"
        # format rounded each value to 1e-5, which erased the 1e-8 steps that
        # approx_fprime takes and quantised the sampled values.
        params['parameters_smg'] = (
            f"0.0, {float(omega_BD)!r}, {float(Psi)!r}, {float(dPsi_dt)!r}"
        )
        params['H0'] = H0
        omega_b = float(params['omega_b'])

        # Derive the cold-dark-matter density from Omega_m and the fixed omega_b.
        Omega_b = omega_b/(H0/100)**2
        params['omega_cdm'] = float((Omega_m - Omega_b)*(H0/100)**2)

        # Set up and run CLASS
        cosmology = Class()
        cosmology.set(params)
        cosmology.compute()

        # Obtain the relevant spectra
        cl = cosmology.lensed_cl(lmax)

        # The multiplication creates a new array, so the result is still
        # valid after the cleanup calls below.
        tt = cl["tt"] * 10**12 * 2.7255**2

        cosmology.struct_cleanup()
        cosmology.empty()

        return tt

######################  LogLikelihood 
    
    def log_likelihood(x):
        """Return the Planck log-likelihood for each parameter set in ``x``.

        ``x`` is a sequence of 6-element parameter sets ordered as
        ``(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt)``.  An empty ``x``
        gives an empty array.  A parameter set whose spectrum is not a numpy
        array of length ``lmax + 1`` contributes ``-inf``.
        """
        if len(x) == 0:
            return np.array([])

        def evaluate(point):
            # Log-likelihood of a single parameter set.
            Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt = point
            spectrum = get_spectrum(Omega_m, Omega_k, H0, omega_BD, Psi, dPsi_dt)

            # Reject anything that is not a full-length spectrum.
            if not (isinstance(spectrum, np.ndarray) and len(spectrum) == lmax + 1):
                return -np.inf

            # The likelihood takes the spectrum followed by A_planck.
            return planck_likelihood(np.concatenate([spectrum, [A_planck]])).squeeze()

        return np.array([evaluate(point) for point in x])
    
###################### END LogLikelihood 
    
###################### Apply MCMC        
    ######################################################################
    def chi2(x, mu, err):
        """Return the sum of squared residuals ``x - mu`` weighted by ``1 / err**2``."""
        residual = x - mu
        weighted = residual**2 / err**2
        return np.sum(weighted)
    
    def chi2_nondiag(x, mu, cov_inv):
        """Return the quadratic form ``(x - mu)^T cov_inv (x - mu)``."""
        residual = x - mu
        # A single einsum contraction over both indices of cov_inv.
        quadratic_form = np.einsum("i,ij,j", residual, cov_inv, residual)
        return quadratic_form
    
    def sinn(x, Omega_k):
        """Return ``sinh(x)`` when ``Omega_k >= 0.0`` and ``sin(x)`` otherwise."""
        if Omega_k >= 0.0:
            return np.sinh(x)
        return np.sin(x)
    
    
    # H(z) measurements.  applyMCMC reads column 0 as the redshift, column 1 as
    # the measured value and column 2 as its error.
    dataH = np.loadtxt("./H_All.txt")
    
    # Module-level constants.  G and Omega_r are read by applyMCMC; applyMCMC
    # unpacks its own H0 from the parameter vector, so this H0 is not the one
    # used there.
    G = 6.674e-11
    Omega_r = 1e-4
    H0 = 67.4
    
    # Distance-modulus table: the first five rows are skipped and column 0 is
    # replaced by 0 on load (presumably a non-numeric label column - TODO confirm).
    data_mu = np.loadtxt(
        "./mu.txt", 
        skiprows=5,
        converters={0: lambda s: 0}
    )
    # Keep file columns 1..3; applyMCMC then uses new column 0 as the redshift
    # and new column 1 as the observed value.
    data_mu = data_mu[:, 1:4]
    # Integration grid: 100 nodes from 1e-4 up to the largest redshift found in
    # either data set.
    z_max = max(np.max(dataH[:, 0]), np.max(data_mu[:, 0]))
    zLine = np.linspace(1e-4, z_max, 100)
    
    # Covariance of the distance-modulus data and its pseudo-inverse, computed
    # once at import time.
    cov_mu = np.loadtxt("./mu_cov.txt")
    cov_mu_inv = pinvh(cov_mu)
    
    def applyMCMC(x):
        """Return the combined log-likelihood for one parameter vector.

        Parameters
        ----------
        x : sequence of 6 floats
            ``(Omega_m, Omega_k, H0, Psi_0, dPsi_0_dt, omega_BD)``.

        Returns
        -------
        float
            ``-0.5 * chi2_H / len(dataH) - 0.5 * chi2_mu / len(data_mu)
            + log_likelihood_planck / (lmax + 1)``, or ``-inf`` when the
            integration fails or any intermediate quantity is not finite.
        """
        Omega_m, Omega_k, H0, Psi_0, dPsi_0_dt, omega_BD = x

        Phi_0 = Psi_0 / G
        dPhi_0_dz = dPsi_0_dt / (-H0 * G)

        Hvals, zeta_vals = RK4Method(Omega_m, Omega_k, H0, Phi_0, dPhi_0_dz, omega_BD, zLine)
        if Hvals is None:
            return -np.inf
        # Reject NaN and +/-inf anywhere in the integrated solution.
        if not (np.all(np.isfinite(Hvals)) and np.all(np.isfinite(zeta_vals))):
            return -np.inf

        # Floor zeta so the logarithms below never see zero or negative input.
        zeta_vals = np.maximum(zeta_vals, 1e-30)

        H_sol_fun = CubicSpline(zLine, Hvals)

        if Omega_k == 0.0:
            mu_sol = (
                25 + 5 * np.log10(299792.458) + 5 * np.log10((1 + zLine) * zeta_vals)
            )
        else:
            mu_sol = (
                25 + 5 * np.log10(299792.458) + 5 * np.log10(
                    (1 + zLine) / H0 * np.abs(Omega_k) ** (-0.5) * sinn(
                        np.abs(Omega_k) ** 0.5 * H0 * zeta_vals, Omega_k
                    )
                )
            )
        # sinn() can be non-positive for Omega_k < 0, giving NaN/-inf in mu_sol;
        # CubicSpline raises on non-finite input, so return -inf like the other
        # failure paths instead of aborting the sampler.
        if not np.all(np.isfinite(mu_sol)):
            return -np.inf
        mu_sol_fun = CubicSpline(zLine, mu_sol)

        chi2_H = chi2(dataH[:, 1], H_sol_fun(dataH[:, 0]), dataH[:, 2])
        chi2_mu = chi2_nondiag(data_mu[:, 1], mu_sol_fun(data_mu[:, 0]), cov_mu_inv)

        # Check the cheap terms first: if they are already invalid the result
        # is -inf regardless, so the expensive CLASS/Planck run can be skipped.
        if not (np.isfinite(chi2_H) and np.isfinite(chi2_mu)):
            return -np.inf

        # Compute log-likelihood of Planck
        log_likelihood_planck = log_likelihood([[Omega_m, Omega_k, H0, omega_BD, Psi_0, dPsi_0_dt]])

        # Ensure that log_likelihood_planck is a scalar
        if isinstance(log_likelihood_planck, np.ndarray):
            log_likelihood_planck = log_likelihood_planck.item(0) if log_likelihood_planck.size > 0 else -np.inf

        if not np.isfinite(log_likelihood_planck):
            return -np.inf

        # Each term is divided by the size of its data set before summing.
        return -0.5 * chi2_H / len(dataH) - 0.5 * chi2_mu / len(data_mu) + log_likelihood_planck / (lmax + 1)
        # Variant without the Planck term:
        #return -0.5 * chi2_H / len(dataH) - 0.5 * chi2_mu / len(data_mu)
    
    ## end MCMC ##

我想优化它以实现更快的采样,因为我已经安装了带有 GPU 后端的 blackjax。所以,我尝试使用:

# Ask pm.sample to run NUTS through the blackjax backend; every other argument
# is left at its default.
trace = pm.sample(nuts_sampler="blackjax", progressbar=True)

但我收到以下错误:

    return fgraph_to_python(
  File "/opt/intel/oneapi/intelpython/python3.9/lib/python3.9/site-packages/pytensor/link/utils.py", line 734, in fgraph_to_python
    compiled_func = op_conversion_fn(
  File "/opt/intel/oneapi/intelpython/python3.9/lib/python3.9/functools.py", line 888, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/opt/intel/oneapi/intelpython/python3.9/lib/python3.9/site-packages/pytensor/link/jax/dispatch/basic.py", line 41, in jax_funcify
    raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
NotImplementedError: No JAX conversion for the given `Op`: LogLikeWithGrad

有人用 PyMC 5.10 实验过这种优化问题吗?

python bayesian pymc mcmc hierarchical-bayesian
1个回答
0
投票

抱歉,JAX 方面我帮不上忙。不过你可以重构 dH() 和 du() 中的计算,尽量让每个子表达式只计算一次。例如:

-16 * math.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi

16 * math.pi * Rho_m + 6 * (1 + z) ** 2 * Omega_k * Phi

这两个表达式只相差一个符号,计算一次即可。另外,你还多次重复计算这样的幂:

(1 + z) ** 2
(1 + z) ** 4

如果改写成下面这样:

# Compute each power of (1 + z) once and build the higher powers by
# multiplication instead of repeating the ** operator.
one_plus_z_pow2 = (1 + z)**2
one_plus_z_pow4 = one_plus_z_pow2 * one_plus_z_pow2
one_plus_z_pow5 = one_plus_z_pow4 * (1 + z)

我的速度几乎快了两倍(0.7 秒 vs 0.4 秒)

# Timing loop from the answer: builds (1 + z) to the powers 2, 3, 4 and 6 by
# multiplication.  Assumes `i` and `z` were initialised beforehand.
# The loop body is indented again; the pasted version had lost its indentation
# and was not valid Python.
while i < 1000000:
    a = (1 + z) ** 2
    b = a * (1 + z)
    c = a * a
    d = c * a
    i += 1
© www.soinside.com 2019 - 2024. All rights reserved.