我正在阅读 devito_book/fdm-jupyter-book/notebooks/01_vib/vib_undamped.ipynb,其中的代码似乎与 devito 4.8.3 不兼容。
所以我尝试将其重写为:
import numpy as np
from devito import Constant, TimeFunction, Eq, solve, Operator, Grid, TimeDimension
def solver(I, w, dt, T):
dt = float(dt)
Nt = int(round(T / dt))
t = TimeDimension('t', spacing=Constant('h_t'))
grid = Grid(shape=(Nt + 1, ), time_dimension=t)
u = TimeFunction(name='u', grid=grid, time_order=2)
u.data[:] = I
eqn = u.dt2 + w ** 2 * u
stencil = Eq(u.forward, solve(eqn, u.forward))
op = Operator(stencil)
op.apply(h_t=dt, t_M=Nt - 1)
return u.data, np.linspace(0, Nt * dt, Nt + 1)
I = 1
w = 2 * np.pi
dt = 0.05
num_periods = 5
P = 2 * np.pi / w
T = P * num_periods
u, t = solver(I, w, dt, T)
但是返回的数据
u
是:
[[0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113, 0.95374113 0.95374113 0.95374113 0.95374113 0.95374113], [1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983, 1.0121983 1.0121983 1.0121983 1.0121983 1.0121983 ], [0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535, 0.8011535 0.8011535 0.8011535 0.8011535 0.8011535 ]]
该值不会从一个时间节点更新到下一个时间节点。谁能帮我解决这个问题吗?
提前谢谢您!
import numpy as np
import matplotlib.pyplot as plt
from devito import Constant, TimeFunction, Eq, solve, Operator, Grid, TimeDimension
def solver(I, w, dt, T):
dt = float(dt)
Nt = int(round(T / dt))
t = TimeDimension('t', spacing=Constant('h_t'))
grid = Grid(shape=(2, ), time_dimension=t)
u = TimeFunction(name='u', grid=grid, time_order=2, save=Nt + 1)
u.data[0:] = I
u.data[1:] = (1 - 0.5 * dt ** 2 * w ** 2) * I
eqn = u.dt2 + w ** 2 * u
stencil = Eq(u.forward, solve(eqn, u.forward))
op = Operator(stencil)
op.apply(h_t=dt, t_M=Nt - 1)
return np.array(u.data)[:, 0], np.linspace(0, Nt * dt, Nt + 1)
def u_exact(t, I, w):
return I * np.cos(w * t)
def visualize(u, t, I, w):
plt.plot(t, u, 'r--o')
t_fine = np.linspace(0, t[-1], 1001)
u_e = u_exact(t_fine, I, w)
plt.plot(t_fine, u_e, 'b-')
plt.legend(['numerical', 'exact'], loc='upper left')
plt.xlabel('t')
plt.ylabel('u')
dt = t[1] - t[0]
plt.title('dt=%g' % dt)
umin = 1.2 * u.min()
umax = -umin
plt.axis((t[0], t[-1], umin, umax))
plt.savefig('tmp.png')
plt.savefig('tmp.pdf')
I = 1
w = 2 * np.pi
dt = 0.05
num_periods = 5
P = 2 * np.pi / w
T = P * num_periods
u, t = solver(I, w, dt, T)
visualize(u, t, I, w)