追加列出 JAX 的替代工作流程

问题描述 投票:0回答:1

我正在开发一个用 JAX 编写的微分方程求解器。我遇到的常见工作流程是这样的:

import jax.numpy as jnp
from jax import jit

# Function to integrate.
@jit
def dxdt(t, x):
   return -x**2

# Euler method for simplicity.
@jit
def integrator(f, t, x, dt):
    return x + f(t, x) * dt

t_arr = jnp.linspace(0, 10, 100)
dt = t_arr[1] - t_arr[0]

x_list = []

# initialize x.
x = 0.

for t in t_arr:
    x_list.append(x)
    x = integrator(f, t, x, dt)

x_arr = jnp.array(x_list)

我的问题是是否有一种方法可以使用 JAX 来“矢量化”for 循环?

我认识到

jax.vmap()
在这里不合适,因为变量 x 在每次 for 循环迭代中都会发生变化。此工作流程是否有更 JAX 友好的方法?

python numerical-methods differential-equations jax
1个回答
0
投票

此类操作通过

jax.lax.scan
支持。以下是使用
scan
进行等效计算的方法:

import jax

def scan_body(carry, t):
  x, dt = carry
  new_x = integrator(dxdt, t, x, dt)
  return (new_x, dt), x

_, x_arr = jax.lax.scan(scan_body, (0., dt), t_arr)
© www.soinside.com 2019 - 2024. All rights reserved.