使用 ScipyBoundedMinimize 和 Optax 进行慢速 JAX 优化 - 寻求加速策略

问题描述 投票:0回答:1

我正在致力于优化

jax
中的模型,该模型涉及将大型观测数据集(4800 个数据点)与包含插值的复杂模型进行拟合。当前使用
jaxopt.ScipyBoundedMinimize
的优化过程需要大约30秒进行100次迭代,其中大部分时间似乎花费在第一次迭代开始期间或之前。您可以在下面找到相关的代码片段。您可以在以下链接找到相关代码的必要数据。

必要的数据(idc、sg和cpcs)

import jax.numpy as jnp
import time as ela_time
from jaxopt import ScipyBoundedMinimize
import optax
import jax
import pickle


file1 = open('idc.pkl', 'rb')
idc = pickle.load(file1)
file1.close()

file2 = open('sg.pkl', 'rb')
sg = pickle.load(file2)
file2.close()

file3 = open('cpcs.pkl', 'rb')
cpcs = pickle.load(file3)
file3.close()


def model(fssc, fssh, time, rv, amp):

    fssp = 1.0 - (fssc + fssh)

    ivis = cpcs['common'][time]['ivis']
    areas = cpcs['common'][time]['areas']
    mus = cpcs['common'][time]['mus']

    vels = idc['vels'].copy()

    ldfs_phot = cpcs['line'][time]['ldfs_phot']
    ldfs_cool = cpcs['line'][time]['ldfs_cool']
    ldfs_hot = cpcs['line'][time]['ldfs_hot']

    lps_phot = cpcs['line'][time]['lps_phot']
    lps_cool = cpcs['line'][time]['lps_cool']
    lps_hot = cpcs['line'][time]['lps_hot']

    lis_phot = cpcs['line'][time]['lis_phot']
    lis_cool = cpcs['line'][time]['lis_cool']
    lis_hot = cpcs['line'][time]['lis_hot']

    coeffs_phot = lis_phot * ldfs_phot * areas * mus
    wgt_phot = coeffs_phot * fssp[ivis]
    wgtn_phot = jnp.sum(wgt_phot)

    coeffs_cool = lis_cool * ldfs_cool * areas * mus
    wgt_cool = coeffs_cool * fssc[ivis]
    wgtn_cool = jnp.sum(wgt_cool)

    coeffs_hot = lis_hot * ldfs_hot * areas * mus
    wgt_hot = coeffs_hot * fssh[ivis]
    wgtn_hot = jnp.sum(wgt_hot)

    prf = jnp.sum(wgt_phot[:, None] * lps_phot + wgt_cool[:, None] * lps_cool + wgt_hot[:, None] * lps_hot, axis=0)
    prf /= wgtn_phot + wgtn_cool + wgtn_hot

    prf = jnp.interp(vels, vels + rv, prf)

    prf = prf + amp

    avg = jnp.mean(prf)

    prf = prf / avg

    return prf


def loss(x0s, lmbd):

    noes = sg['noes']

    noo = len(idc['times'])

    fssc = x0s[:noes]
    fssh = x0s[noes: 2 * noes]
    fssp = 1.0 - (fssc + fssh)
    rv = x0s[2 * noes: 2 * noes + noo]
    amp = x0s[2 * noes + noo: 2 * noes + 2 * noo]

    chisq = 0
    for i, itime in enumerate(idc['times']):
        oprf = idc['data'][itime]['prf']
        oprf_errs = idc['data'][itime]['errs']

        nop = len(oprf)

        sprf = model(fssc=fssc, fssh=fssh, time=itime, rv=rv[i], amp=amp[i])

        chisq += jnp.sum(((oprf - sprf) / oprf_errs) ** 2) / (noo * nop)

    wp = sg['grid_areas'] / jnp.max(sg['grid_areas'])

    mem = jnp.sum(wp * (fssc * jnp.log(fssc / 1e-5) + fssh * jnp.log(fssh / 1e-5) +
                    (1.0 - fssp) * jnp.log((1.0 - fssp) / (1.0 - 1e-5)))) / sg['noes']

    ftot = chisq + lmbd * mem

    return ftot


if __name__ == '__main__':

    # idc: a dictionary containing observational data (150 x 32)
    # sg and cpcs: dictionaries with related coefficients

    noes = sg['noes']
    lmbd = 1.0
    maxiter = 1000
    tol = 1e-5

    fss = jnp.ones(2 * noes) * 1e-5
    x0s = jnp.hstack((fss, jnp.zeros(len(idc['times']) * 2)))

    minx0s = [1e-5] * (2 * noes) + [-jnp.inf] * len(idc['times']) * 2
    maxx0s = [1.0 - 1e-5] * (2 * noes) + [jnp.inf] * len(idc['times']) * 2

    bounds = (minx0s, maxx0s)

    start = ela_time.time()

    optimizer = ScipyBoundedMinimize(fun=loss, maxiter=maxiter, tol=tol, method='L-BFGS-B',
                                 options={'disp': True})
    x0s, info = optimizer.run(x0s, bounds,  lmbd)

    # optimizer = optax.adam(learning_rate=0.1)
    # optimizer_state = optimizer.init(x0s)
    #
    # for i in range(1, maxiter + 1):
    #
    #     print('ITERATION -->', i)
    #
    #     gradients = jax.grad(loss)(x0s, lmbd)
    #     updates, optimizer_state = optimizer.update(gradients, optimizer_state, x0s)
    #     x0s = optax.apply_updates(x0s, updates)
    #     x0s = jnp.clip(x0s, jnp.array(minx0s), jnp.array(maxx0s))
    #     print('Objective function: {:.3E}'.format(loss(x0s, lmbd)))

    end = ela_time.time()

    print(end - start)   # total elapsed time: ~30 seconds

以下是相关方面的细分:

  • 自由参数数量(
    x0s
    ): 5263
  • 数据:存储在
    idc
    字典中的观测数据(4800个数据点)
  • 模型:
    model
    函数中定义,也利用插值
  • 尝试过的优化方法:
    • jaxopt.ScipyBoundedMinimize
      使用
      L-BFGS-B
      方法(慢约 30 秒,大部分时间花在第一次迭代期间或之前)
    • optax.adam(太慢了~200秒)
  • 尝试并行化:我尝试并行化
    optax.adam
    ,但由于建模的固有性质,我无法成功,因为
    x0s
    无法分割。 (假设我正确理解并行化)

问题:

  1. ScipyBoundedMinimize
    中第一次迭代之前或期间缓慢的潜在原因是什么?
  2. jax
    中是否有其他优化算法对于我的场景可能更快(大量自由参数和数据点,带有插值的复杂模型)?
  3. 我是否误解了
    optax.adam
    的并行化?在这种情况下,有什么潜在并行化的策略吗?
  4. 所提供的代码片段中是否有任何可以提高性能的代码优化(例如矢量化)?

附加信息:

  • 硬件: Intel® Core™ i7-9750H CPU @ 2.60GHz × 12、16 GiB RAM(笔记本电脑)
  • 软件:操作系统Ubuntu 22.04,Python 3.10.12,JAX 0.4.25,optax 0.2.1

如果您有任何关于提高优化性能的见解或建议,我将不胜感激。

python performance optimization jax
1个回答
0
投票

JAX 代码是即时 (JIT) 编译的,这意味着第一步的持续时间较长可能与编译成本有关。代码越长,编译所需的时间就越长。

导致编译时间过长的一个常见问题是使用 Python 控制流,例如

for
循环。 JAX 的跟踪机制本质上使这些循环变平(请参阅JAX Sharp Bits:控制流)。在您的例子中,您在数据结构中循环了 4800 多个条目,因此创建了一个非常长且低效的程序。

此类情况的典型解决方案是使用

jax.vmap
重写程序。与大多数 JAX 构造一样,这最适合使用 struct-of-arrays 模式,而不是数据中使用的 array-of-structs 模式。因此,使用
vmap
的第一步是以 JAX 可以使用的方式重组数据;它可能看起来像这样:

itimes = jnp.arange(len(idc['times']))
prf = jnp.array([idc['data'][i]['prf'] for i in itimes])
errs = jnp.array([idc['data'][i]['errs'] for i in itimes])

sprf = jax.vmap(model, in_axes=[None, None, 0, 0, 0])(fssc, fssh, itimes, rv, amp)
chi2 = jnp.sum((oprf - sprf) / oprf_errs) ** 2) / len(times) / oprf.shape[1]

请注意,这假设

idc['data'][i]['prf']
idc['data'][i]['errs']
的每个条目都具有相同的形状。如果没有,那么我担心您的问题不是特别适合 JAX 的 SPMD 编程模型,并且没有一种简单的方法可以解决长时间编译的需要。

© www.soinside.com 2019 - 2024. All rights reserved.