jax 中的向量化最小化和求根

问题描述 投票:0回答:1

我有一系列由

args

参数化的函数
f(x, args)

并且想要确定

f
x
之间的最小值,以获得
N = 1000
args
值。我可以访问该函数及其导数。我的第一次尝试是循环遍历
args
的不同值并在每次迭代时使用 scipy.optimizer,但它花费的时间太长。我相信矢量化可以加快操作速度。我的下一次尝试是在
jax.vmap
jax.scipy.optimize.minimize
中使用
jaxopt.ScipyMinimize
,但我似乎无法为
args
传递多个值。

或者,我可以编写自己的矢量化优化方法,例如二分,其中向量化的意思是对数组进行固定次数的迭代操作,并且如果其中一个优化问题已提前达到特定的容错级别,则不会提前停止。我希望使用一些优化的现成算法。

如果 jax 中有可用的实现,我希望使用一些已经优化的现成算法。this 线程相关,但

args
没有改变。

python optimization scipy vectorization jax
1个回答
0
投票

您可以定义一个函数来查找给定特定

args
的最小值,然后将其包装在
jax.vmap
中以自动对其进行矢量化。例如:

import jax
import jax.numpy as jnp
from jax.scipy import optimize

def f(x, args):
  a, b = args
  return jnp.sum(a + (x - b) ** 2)

def find_min(a, b):
  x0 = jnp.array([1.0])
  args = (a, b)
  return optimize.minimize(f, x0, (args,), method="BFGS")

a_grid, b_grid = jnp.meshgrid(jnp.arange(5.0), jnp.arange(5.0))

results = jax.vmap(find_min)(a_grid.ravel(), b_grid.ravel())

print(results.success)
# [ True  True  True  True  True  True  True  True  True  True  True  True
#   True  True  True  True  True  True  True  True  True  True  True  True
#   True]

print(results.x.T)
# [[0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 2. 2. 2. 2. 2.
#   3. 3. 3. 3. 3. 4. 4. 4. 4. 4.]]
© www.soinside.com 2019 - 2024. All rights reserved.