Zygote 不支持变异数组

问题描述 投票:0回答:1

我正在尝试使用 科学机器学习框架在 Julia 中编写来自 example

的螺旋示例

但是,当代码通过

get_batch
函数进行区分时,我收到错误:

ERROR: Mutating arrays is not supported -- called setindex!(Vector{Int64}, ...) This error occurs when you ask Zygote to differentiate operations that change the elements of arrays in place (e.g. setting values with x .= ...)

这是一个最小的可重现示例:

  using Lux, DiffEqFlux, DifferentialEquations, ComponentArrays, Random, StatsBase, MLUtils
  using Zygote

  true_y0 = [2., 0.];
  true_A = [-0.1 2.; -2. -0.1];

  data_size = 1000;
  times = LinRange(0, 25, data_size);

  function ground_truth!(du, u, p, t)
      du .= true_A * (u.^3) 
  end

  ground_truth_odeProb = ODEProblem(ground_truth!, true_y0, (0, times[end]));

  sol_ode = Array(
                    solve(ground_truth_odeProb,
                          Tsit5(),
                          abstol = 1e-10, reltol = 1e-10,
                          saveat = times)
                  );

  batch_time = 10
  batch_size = 20

  function get_batch()

    s = sample(range(1, data_size - batch_time), batch_size, replace=false)

    batch_y0 = sol_ode[ :, s];
    
    batch_t = times[1:batch_time];

    batch_y = stack([sol_ode[:, s .+ i] for i in range(1, batch_time)], dims=3);

    return [batch_y0, batch_t, batch_y];

  end

  const neural_net = Lux.Chain(Lux.Dense(2, 50, tanh),
                        Lux.Dense(50, 2))     

  rng = Random.default_rng();
  p, st = Lux.setup(rng, neural_net)

  const _st = st;

  p_init = ComponentArray(p);

  function neural_net_func!(du, u, p, t)
    du .= neural_net(u.^3, p, st)[1];
  end

  prob_nn = ODEProblem(neural_net_func!, [0., 0.], [0., 0.], p);

  function predict(θ, y0s, ts)
    
    _prob = remake( prob_nn, 
                    u0 = y0s, 
                    tspan = (ts[1], ts[end]), 
                    p = θ
                  );
    
    Array(solve(_prob, Tsit5(), saveat = ts,
              abstol = 1e-5, reltol = 1e-5)); 
  end

  function test(θ, y0s, ts)
    y0s, ts, targets = get_batch();
    pred = predict(θ, y0s, ts);
    loss = sum(abs2, targets .- pred)
  end

  x, lambda = pullback((θ, y0s, ts) -> test(θ, y0s, ts), p_init, y0s, ts);

  lambda(x) # give the error
julia
1个回答
0
投票

标题已经告诉您问题所在:Zygote,自动微分 (AD) 库,不支持变异数组,但您可以在 RHS 中变异数组。

事实上许多AD库不支持突变。由于 AD 依赖于跟踪您的方法执行的每个操作,因此支持变异操作在技术上非常具有挑战性。

你的例子并不是那么简单,所以我没有运行它。但对我来说,似乎有一个简单的解决方法:如果你看一下 RHS

neural_net_func!
的定义,很容易避免突变:只需以不合适的样式定义它,所以编写一个函数

neural_net_func(u,p,t) = neural_net(u.^3, p, st)[1];

我怀疑这已经可以解决您的问题。如果您查看错误消息,这也已经是错误消息建议您执行的操作!

如果这没有帮助,还有更普遍的其他方法可以解决这个问题:

© www.soinside.com 2019 - 2024. All rights reserved.