我正在开发一个用 JAX 编写的微分方程求解器。我遇到的常见工作流程是这样的:
import jax.numpy as jnp
from jax import jit
# Function to integrate.
@jit
def dxdt(t, x):
return -x**2
# Euler method for simplicity.
@jit
def integrator(f, t, x, dt):
return x + f(t, x) * dt
t_arr = jnp.linspace(0, 10, 100)
dt = t_arr[1] - t_arr[0]
x_list = []
# initialize x.
x = 0.
for t in t_arr:
x_list.append(x)
x = integrator(f, t, x, dt)
x_arr = jnp.array(x_list)
我的问题是是否有一种方法可以使用 JAX 来“矢量化”for 循环?
我认识到
jax.vmap()
在这里不合适,因为变量 x 在每次 for 循环迭代中都会发生变化。此工作流程是否有更 JAX 友好的方法?
jax.lax.scan
支持。以下是使用 scan
进行等效计算的方法:
import jax
def scan_body(carry, t):
x, dt = carry
new_x = integrator(dxdt, t, x, dt)
return (new_x, dt), x
_, x_arr = jax.lax.scan(scan_body, (0., dt), t_arr)