I am working on optimizing a model in JAX that involves fitting a large observational dataset (4800 data points) with a complex model that includes interpolation. The current optimization using `jaxopt.ScipyBoundedMinimize` takes about 30 seconds for 100 iterations, and most of that time appears to be spent during or before the first iteration. The relevant code snippet is below; the data required to run it can be found at the following link.
import jax.numpy as jnp
import time as ela_time
from jaxopt import ScipyBoundedMinimize
import optax
import jax
import pickle
with open('idc.pkl', 'rb') as file1:
    idc = pickle.load(file1)
with open('sg.pkl', 'rb') as file2:
    sg = pickle.load(file2)
with open('cpcs.pkl', 'rb') as file3:
    cpcs = pickle.load(file3)
def model(fssc, fssh, time, rv, amp):
    fssp = 1.0 - (fssc + fssh)
    ivis = cpcs['common'][time]['ivis']
    areas = cpcs['common'][time]['areas']
    mus = cpcs['common'][time]['mus']
    vels = idc['vels'].copy()
    ldfs_phot = cpcs['line'][time]['ldfs_phot']
    ldfs_cool = cpcs['line'][time]['ldfs_cool']
    ldfs_hot = cpcs['line'][time]['ldfs_hot']
    lps_phot = cpcs['line'][time]['lps_phot']
    lps_cool = cpcs['line'][time]['lps_cool']
    lps_hot = cpcs['line'][time]['lps_hot']
    lis_phot = cpcs['line'][time]['lis_phot']
    lis_cool = cpcs['line'][time]['lis_cool']
    lis_hot = cpcs['line'][time]['lis_hot']
    coeffs_phot = lis_phot * ldfs_phot * areas * mus
    wgt_phot = coeffs_phot * fssp[ivis]
    wgtn_phot = jnp.sum(wgt_phot)
    coeffs_cool = lis_cool * ldfs_cool * areas * mus
    wgt_cool = coeffs_cool * fssc[ivis]
    wgtn_cool = jnp.sum(wgt_cool)
    coeffs_hot = lis_hot * ldfs_hot * areas * mus
    wgt_hot = coeffs_hot * fssh[ivis]
    wgtn_hot = jnp.sum(wgt_hot)
    # weighted sum of the component profiles, normalized by the total weight
    prf = jnp.sum(wgt_phot[:, None] * lps_phot + wgt_cool[:, None] * lps_cool + wgt_hot[:, None] * lps_hot, axis=0)
    prf /= wgtn_phot + wgtn_cool + wgtn_hot
    # shift the profile by rv via interpolation, then apply the offset amp
    prf = jnp.interp(vels, vels + rv, prf)
    prf = prf + amp
    avg = jnp.mean(prf)
    prf = prf / avg
    return prf
def loss(x0s, lmbd):
    noes = sg['noes']
    noo = len(idc['times'])
    fssc = x0s[:noes]
    fssh = x0s[noes: 2 * noes]
    fssp = 1.0 - (fssc + fssh)
    rv = x0s[2 * noes: 2 * noes + noo]
    amp = x0s[2 * noes + noo: 2 * noes + 2 * noo]
    chisq = 0
    for i, itime in enumerate(idc['times']):
        oprf = idc['data'][itime]['prf']
        oprf_errs = idc['data'][itime]['errs']
        nop = len(oprf)
        sprf = model(fssc=fssc, fssh=fssh, time=itime, rv=rv[i], amp=amp[i])
        chisq += jnp.sum(((oprf - sprf) / oprf_errs) ** 2) / (noo * nop)
    wp = sg['grid_areas'] / jnp.max(sg['grid_areas'])
    mem = jnp.sum(wp * (fssc * jnp.log(fssc / 1e-5) + fssh * jnp.log(fssh / 1e-5) +
                        (1.0 - fssp) * jnp.log((1.0 - fssp) / (1.0 - 1e-5)))) / sg['noes']
    ftot = chisq + lmbd * mem
    return ftot
if __name__ == '__main__':
    # idc: a dictionary containing observational data (150 x 32)
    # sg and cpcs: dictionaries with related coefficients
    noes = sg['noes']
    lmbd = 1.0
    maxiter = 1000
    tol = 1e-5
    fss = jnp.ones(2 * noes) * 1e-5
    x0s = jnp.hstack((fss, jnp.zeros(len(idc['times']) * 2)))
    minx0s = [1e-5] * (2 * noes) + [-jnp.inf] * len(idc['times']) * 2
    maxx0s = [1.0 - 1e-5] * (2 * noes) + [jnp.inf] * len(idc['times']) * 2
    bounds = (minx0s, maxx0s)
    start = ela_time.time()
    optimizer = ScipyBoundedMinimize(fun=loss, maxiter=maxiter, tol=tol, method='L-BFGS-B',
                                     options={'disp': True})
    x0s, info = optimizer.run(x0s, bounds, lmbd)
    # optimizer = optax.adam(learning_rate=0.1)
    # optimizer_state = optimizer.init(x0s)
    #
    # for i in range(1, maxiter + 1):
    #
    #     print('ITERATION -->', i)
    #
    #     gradients = jax.grad(loss)(x0s, lmbd)
    #     updates, optimizer_state = optimizer.update(gradients, optimizer_state, x0s)
    #     x0s = optax.apply_updates(x0s, updates)
    #     x0s = jnp.clip(x0s, jnp.array(minx0s), jnp.array(maxx0s))
    #     print('Objective function: {:.3E}'.format(loss(x0s, lmbd)))
    end = ela_time.time()
    print(end - start)  # total elapsed time: ~30 seconds
Here is a breakdown of the relevant aspects:

- Number of free parameters (`x0s`): 5263
- Observational data in the `idc` dictionary: 4800 data points
- The model is defined in the `model` function and also makes use of interpolation
- `jaxopt.ScipyBoundedMinimize` with the `L-BFGS-B` method is slow (~30 seconds, with most of the time spent during or before the first iteration)
- I also tried parallelizing `optax.adam`, but could not succeed because, due to the inherent nature of the modeling, `x0s` cannot be split (assuming I understand parallelization correctly)

Questions:

1. What are potential causes of the slowness before or during the first iteration in `ScipyBoundedMinimize`?
2. Is there another optimization algorithm in `jax` that might be faster for my scenario (a large number of free parameters and data points, a complex model with interpolation)?
3. Is parallelization of `optax.adam` possible? If so, what are some potential parallelization strategies in this case?

Additional information: I would appreciate any insights or suggestions you may have regarding improving the performance of the optimization.
JAX code is just-in-time (JIT) compiled, which means a long first step is likely due to compilation cost. The longer the traced program, the longer it takes to compile.
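You can confirm this by timing the first call of a jitted function (which traces, compiles, and runs) against a second call (which reuses the cached executable). A minimal sketch with a toy function (`f` below is hypothetical, not your model):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.ones(1000)

t0 = time.perf_counter()
f(x).block_until_ready()  # first call: trace + compile + run
first = time.perf_counter() - t0

t0 = time.perf_counter()
f(x).block_until_ready()  # second call: runs the cached executable
second = time.perf_counter() - t0

print(f'first call: {first:.4f}s, second call: {second:.4f}s')
```

`block_until_ready()` matters here: JAX dispatches asynchronously, so without it you would be timing dispatch rather than execution.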
A common cause of long compilation times is the use of Python control flow, such as `for` loops: JAX's tracing machinery essentially flattens (unrolls) these loops (see JAX Sharp Bits: Control Flow). In your case, you are looping over 4800+ entries in your data structures, which creates a very long and inefficient program.
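You can see this unrolling directly by inspecting the jaxpr of a toy function (a sketch, not your model): the Python loop emits one set of operations per iteration, whereas a structured primitive like `jax.lax.fori_loop` stages a single loop operation:

```python
import jax
import jax.numpy as jnp

def unrolled(x):
    # Python loop: flattened at trace time into 100 copies of the body
    total = 0.0
    for i in range(100):
        total = total + x[i] ** 2
    return total

def rolled(x):
    # lax.fori_loop stages one loop primitive instead of unrolling
    return jax.lax.fori_loop(0, 100, lambda i, t: t + x[i] ** 2, 0.0)

n_unrolled = len(jax.make_jaxpr(unrolled)(jnp.ones(100)).jaxpr.eqns)
n_rolled = len(jax.make_jaxpr(rolled)(jnp.ones(100)).jaxpr.eqns)
print(n_unrolled, n_rolled)  # the unrolled version has far more equations
```

Compile time grows with the size of the jaxpr, which is why the unrolled version becomes expensive as the trip count grows.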
The best way to fix this is to rewrite your program using `jax.vmap`. As with most JAX constructs, this works best with a struct-of-arrays pattern, rather than the array-of-structs pattern used in your data. So the first step in using `vmap` is to restructure your data in a way JAX can work with; it might look something like this:
itimes = jnp.arange(len(idc['times']))
prf = jnp.array([idc['data'][i]['prf'] for i in itimes])
errs = jnp.array([idc['data'][i]['errs'] for i in itimes])
sprf = jax.vmap(model, in_axes=[None, None, 0, 0, 0])(fssc, fssh, itimes, rv, amp)
chi2 = jnp.sum(((prf - sprf) / errs) ** 2) / len(itimes) / prf.shape[1]
Note that this assumes every entry of `idc['data'][i]['prf']` and `idc['data'][i]['errs']` has the same shape. If not, then I'm afraid your problem is not particularly well suited to JAX's SPMD programming model, and there is no easy way to avoid the long compilation times.
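For reference, here is a self-contained toy sketch of the array-of-structs to struct-of-arrays restructuring described above (the data and the `per_obs_chi2` helper are made up for illustration, not taken from the question):

```python
import jax
import jax.numpy as jnp

# array-of-structs: one dict per observation (like idc['data'])
records = [{'prf': jnp.full(8, float(i)), 'errs': jnp.ones(8)} for i in range(5)]

# struct-of-arrays: one stacked array per field
prf = jnp.stack([r['prf'] for r in records])    # shape (5, 8)
errs = jnp.stack([r['errs'] for r in records])  # shape (5, 8)

def per_obs_chi2(prf_row, errs_row, model_row):
    # chi-square contribution of a single observation
    return jnp.sum(((prf_row - model_row) / errs_row) ** 2)

# one vmap over the leading (observation) axis replaces the Python loop
model_rows = jnp.zeros((5, 8))
chi2_per_obs = jax.vmap(per_obs_chi2)(prf, errs, model_rows)  # shape (5,)
chi2 = jnp.sum(chi2_per_obs) / prf.shape[0] / prf.shape[1]
```

The traced program now contains the per-observation computation once, regardless of how many observations there are, so compilation time no longer scales with the size of the dataset.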