Is there a way to accept a function as an argument when taking gradients with jax.grad?


I'm trying to build a neural-network-based solver for the differential equation

y' + 2xy = 0.

import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

def softplus(x):
    # Softplus activation: log(1 + exp(x))
    return jnp.log(1 + jnp.exp(x))

def init_params():
    # 80 hidden weights + 80 hidden biases + 80 output weights + 1 output bias = 241
    params = jax.random.normal(key, shape=(241,))
    return params

def linear_model(params, x):
    # One-hidden-layer network mapping a scalar x to a scalar output
    w0 = params[:80]
    b0 = params[80:160]
    w1 = params[160:240]
    b1 = params[240]
    h = softplus(x*w0 + b0)
    o = jnp.sum(h*w1) + b1
    return o

def loss(derivative, initial_condition, params, model, x):
    dfdx = jax.grad(model, 1)              # derivative of the model output w.r.t. x
    dfdx_vect = jax.vmap(dfdx, (None, 0))  # vectorize over the batch of x values
    model_vect = jax.vmap(model, (None, 0))
    # ODE residual at each x; the model must be vectorized here too, since x is an array
    eq_difference = dfdx_vect(params, x) - derivative(x, model_vect(params, x))
    condition_difference = model(params, 0.) - initial_condition
    # Penalize (don't reward) the initial-condition error, hence "+" rather than "-"
    return jnp.mean(eq_difference ** 2) + condition_difference ** 2

def dfdx(x, y):
    # Right-hand side of y' = -2xy, i.e. y' + 2xy = 0
    return -2. * x * y

key = jax.random.PRNGKey(0)
inputs = np.linspace(0, 1, num=401)
params = init_params()

epochs = 2000
learning_rate = 0.0005

# Training Neural Network

grad_loss = jax.grad(loss)  # jax.grad differentiates w.r.t. the first positional argument

for epoch in tqdm(range(epochs)):
    gradient = grad_loss(dfdx, 1., params, linear_model, inputs)  # raises the TypeError below
    params -= learning_rate*gradient

model_vect = jax.vmap(linear_model, (None, 0))
preds = model_vect(params, inputs)

# The exact solution of y' + 2xy = 0 with y(0) = 1 is y = exp(-x**2)
plt.plot(inputs, jnp.exp(-inputs**2), label='exact')
plt.plot(inputs, preds, label='approx')
plt.legend()
plt.show()

The problem is that JAX doesn't like taking the gradient of a function that receives another function as an argument:

TypeError: Argument '<function dfdx at 0x7fce88340af0>' of type <class 'function'> is not a valid JAX type.
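
The same error can be reproduced in isolation (f and g here are hypothetical stand-ins for loss and dfdx):

import jax
import jax.numpy as jnp

def f(g, x):
    return g(x) ** 2

# jax.grad differentiates with respect to argument 0, which here is the function g
jax.grad(f)(jnp.sin, 1.0)  # TypeError: ... of type <class 'function'> is not a valid JAX type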

Is there any workaround for this?

neural-network differential-equations jax
1 Answer

You've simply ordered the arguments incorrectly. JAX is particular about this: by default, jax.grad differentiates with respect to the first argument. You don't want to differentiate with respect to your function, but with respect to the parameters, so make params the first argument:

def loss(params, derivative, initial_condition, model, x):
    dfdx = jax.grad(model, 1)
    dfdx_vect = jax.vmap(dfdx, (None, 0))
    model_vect = jax.vmap(model, (None, 0))
    eq_difference = dfdx_vect(params, x) - derivative(x, model_vect(params, x))
    condition_difference = model(params, 0.) - initial_condition
    return jnp.mean(eq_difference ** 2) + condition_difference ** 2
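
The call in the training loop changes to match, with params passed first; a minimal sketch of the updated loop (hoisting jax.grad out of the loop just avoids re-wrapping it every epoch):

grad_loss = jax.grad(loss)  # now differentiates w.r.t. params (argument 0)

for epoch in tqdm(range(epochs)):
    gradient = grad_loss(params, dfdx, 1., linear_model, inputs)
    params -= learning_rate * gradient

Alternatively, you can keep the original argument order and tell jax.grad which argument to differentiate: jax.grad(loss, argnums=2) takes the gradient with respect to params (the third argument) and happily passes the function arguments through untouched.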
