我正在尝试在 JAX 中创建一个物理信息神经网络 (PINN)。我想通过输入 (x) 区分定义的模型(神经网络)。如果我将
model
设置为 jax.grad(params)
,则会出现错误。model
设置为 jax.grad(model)
,我不会收到错误,但我不知道是否能够通过 x 区分神经网络的模型。
class MLP(fnn.Module):
@fnn.compact
def __call__(self, x):
x = fnn.Dense(128)(x)
x = fnn.relu(x)
x = fnn.Dense(256)(x)
x = fnn.relu(x)
x = fnn.Dense(10)(x)
return x
model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones([1]))['params']
tx = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
您可以通过以下方式区分 JAX 中的模型:(1) 定义要区分的函数,(2) 根据您的应用程序使用
jax.grad
、jax.jacrev
、jax.jacfwd
等对其进行转换,以及 (3 )将数据传递给转换后的函数。
从您的问题中并不完全清楚您希望区分什么操作,但这是一个计算相对于参数的训练状态创建的前向模式雅可比的示例:
def f(params):
return TrainState.create(apply_fn=model.apply, params=params, tx=tx)
result = jax.jacfwd(f)(params)
如果这没有帮助,我建议编辑你的问题以明确你有兴趣区分什么操作。
请检查此存储库https://github.com/PredictiveIntelligenceLab/jaxpi,其中包含使用 JAX 全面实现 PINN。