如何使用 scipy.integrate.solve_ivp 以向量化方式处理具有耦合微分方程的 numpy 数组输入

问题描述 投票:0回答:1

假设我想求解一系列方程,但它们没有耦合,只是由 numpy 数组定义不同。比如我有

y'(t) = 2*t + c

其中 c 只是 numpy 数组中给出的已知常量,例如

c = np.array([1, 2])
。这意味着我需要求解两个方程,可能还具有不同的初始值。这当然可以通过分析轻松解决。

y(t) = t^2 + c*t + C

其中 C 是另一个常数,取决于

y(0)
处的初始值。例如,第一个方程有
C = y(0) = 0
,第二个方程有
C = y(0) = 1
。然后我就可以得到精确的解决方案:

y =  t^2 + 1*t
y =  t^2 + 2*t + 1

如果我想使用

scipy.integrate.solve_ivp
来获取
t=1
处的值,那么我可以“假装”它们是耦合的。具体来说,代码如下

def test_fun(t, y):
    c = np.array([1, 2])
    dy1 = 2*t + c
    return dy1
sol = solve_ivp(test_fun, [0, 1], np.array([0, 1]))
y = sol.y[:, -1]

这实际上给了我

y = np.array([2., 4.])
,这与分析给出的结果相匹配。

但是现在假设我有一个方程列表,其中的常数与 numpy 数组不同,而且每个方程都由 3 个耦合微分方程组成。让我们仍然使用一个玩具示例,它们并不是真正的“耦合”,而是只需要在函数中返回多个值,就像求解耦合微分方程时所做的那样


def test_fun(t, y):
    c = np.array([1, 2])
    dy1 = 2*t + c
    dy2 = 3*t + c
    dy3 = 4*t + c
    return dy1, dy2, dy3

sol = solve_ivp(test_fun, [0, 1], np.array([[1, 2], [2, 2], [3, 2]]))

但现在它会抱怨

ValueError: y0 must be 1-dimensional.
Flatten y0 似乎没有意义,因为我希望函数的输入是大小为 2 的数组,但对于每个元素,它将返回一个由 3 个耦合方程组组成的系统。如果我只是展平,它只会假设它是 6 个耦合方程组的输入,然后抱怨
ValueError: could not broadcast input array from shape (3,2) into shape (6,)

当然我可以只使用 for 循环,但一般的想法是避免它,因为它在 python 中非常慢,而且我的

c
数组非常大。

numpy math scipy vectorization differential-equations
1个回答
0
投票

“展平 y0 似乎没有意义,因为我希望函数的输入是大小为 2 的数组,但对于每个元素,它将返回一个由 3 个耦合方程组组成的系统。如果我只是展平,它将假设它是 6 个耦合方程组的输入,然后抱怨 ValueError: Could not Broadcast input array from shape (3,2) into shape (6,)"

在将

y0
传递给求解器之前,您必须将其展平。然后
test_fun
将获得展平的向量,因此在
test_fun
中,您将重塑
y
,使用数组进行计算,然后展平导数数组,然后从
test_fun
返回它。当求解器返回时,您还必须重新调整结果,使其看起来像数组的数组。

我创建了一个名为

odeintw
的包来为您执行此操作,但它使用
scipy.integrate.odeint
,而不是
solve_ivp
。如果您使用
odeintw
,您的脚本可能如下所示:

import numpy as np
from odeintw import odeintw


def test_fun(t, y):
    c = np.array([1, 2])
    dy1 = 2*t + c
    dy2 = 3*t + c
    dy3 = 4*t + c
    return dy1, dy2, dy3


n = 250
t = np.linspace(0, 1, n)
y0 = np.array([[1, 2], [2, 2], [3, 2]])
sol = odeintw(test_fun, y0, t, tfirst=True)
# The numerical solution `sol`` is an array with shape (n, 3, 2).

您必须运行一些测试,看看这种方法是否比在各个系统上运行 Python 循环更快。 (为了提高性能,您可以尝试实现

Dfun
odeint
参数。请参阅
odeintw
文档字符串了解如何实现它。)

© www.soinside.com 2019 - 2024. All rights reserved.