diff --git a/src/vib/vib.py b/src/vib/vib.py index 6e0ce1f7..5b386e90 100644 --- a/src/vib/vib.py +++ b/src/vib/vib.py @@ -1,4 +1,6 @@ import numpy as np +import sympy as sp +from devito import Dimension, Constant, TimeFunction, Eq, solve, Operator #import matplotlib.pyplot as plt import scitools.std as plt @@ -11,27 +13,33 @@ def solver(I, V, m, b, s, F, dt, T, damping='linear'): 'quadratic', f(u')=b*u'*abs(u'). F(t) and s(u) are Python functions. """ - dt = float(dt); b = float(b); m = float(m) # avoid integer div. + dt = float(dt) + b = float(b) + m = float(m) Nt = int(round(T/dt)) - u = np.zeros(Nt+1) - t = np.linspace(0, Nt*dt, Nt+1) + t = Dimension('t', spacing=Constant('h_t')) + + u = TimeFunction(name='u', dimensions=(t,), + shape=(Nt+1,), space_order=2) + + u.data[0] = I - u[0] = I if damping == 'linear': - u[1] = u[0] + dt*V + dt**2/(2*m)*(-b*V - s(u[0]) + F(t[0])) + # dtc for central difference (default for time is forward, 1st order) + eqn = m*u.dt2 + b*u.dtc + s(u) - F(u) + stencil = Eq(u.forward, solve(eqn, u.forward)) elif damping == 'quadratic': - u[1] = u[0] + dt*V + \ - dt**2/(2*m)*(-b*V*abs(V) - s(u[0]) + F(t[0])) + # fd_order set as backward derivative used is 1st order + eqn = m*u.dt2 + b*u.dt*sp.Abs(u.dtl(fd_order=1)) + s(u) - F(u) + stencil = Eq(u.forward, solve(eqn, u.forward)) + # First timestep needs to have the backward timestep substituted + stencil_init = stencil.subs(u.backward, u.forward-2*t.spacing*V) + op_init = Operator(stencil_init, name='first_timestep') + op = Operator(stencil, name='main_loop') + op_init.apply(h_t=dt, t_M=1) + op.apply(h_t=dt, t_m=1, t_M=Nt-1) - for n in range(1, Nt): - if damping == 'linear': - u[n+1] = (2*m*u[n] + (b*dt/2 - m)*u[n-1] + - dt**2*(F(t[n]) - s(u[n])))/(m + b*dt/2) - elif damping == 'quadratic': - u[n+1] = (2*m*u[n] - m*u[n-1] + b*u[n]*abs(u[n] - u[n-1]) - + dt**2*(F(t[n]) - s(u[n])))/\ - (m + b*abs(u[n] - u[n-1])) - return u, t + return u.data, np.linspace(0, Nt*dt, Nt+1) def visualize(u, t, title='', filename='tmp'): plt.plot(t, u, 'b-') @@ -46,8 +54,6 @@ def visualize(u, t, title='', filename='tmp'): plt.savefig(filename + '.pdf') plt.show() -import sympy as sym - def test_constant(): """Verify a constant solution.""" u_exact = lambda t: I @@ -68,24 +74,24 @@ def test_constant(): def lhs_eq(t, m, b, s, u, damping='linear'): """Return lhs of differential equation as sympy expression.""" - v = sym.diff(u, t) + v = sp.diff(u, t) if damping == 'linear': - return m*sym.diff(u, t, t) + b*v + s(u) + return m*sp.diff(u, t, t) + b*v + s(u) else: - return m*sym.diff(u, t, t) + b*v*sym.Abs(v) + s(u) + return m*sp.diff(u, t, t) + b*v*sp.Abs(v) + s(u) def test_quadratic(): """Verify a quadratic solution.""" I = 1.2; V = 3; m = 2; b = 0.9 s = lambda u: 4*u - t = sym.Symbol('t') + t = sp.Symbol('t') dt = 0.2 T = 2 q = 2 # arbitrary constant u_exact = I + V*t + q*t**2 - F = sym.lambdify(t, lhs_eq(t, m, b, s, u_exact, 'linear')) - u_exact = sym.lambdify(t, u_exact, modules='numpy') + F = sp.lambdify(t, lhs_eq(t, m, b, s, u_exact, 'linear')) + u_exact = sp.lambdify(t, u_exact, modules='numpy') u1, t1 = solver(I, V, m, b, s, F, dt, T, 'linear') diff = np.abs(u_exact(t1) - u1).max() tol = 1E-13 @@ -94,8 +100,8 @@ def test_quadratic(): # In the quadratic damping case, u_exact must be linear # in order exactly recover this solution u_exact = I + V*t - F = sym.lambdify(t, lhs_eq(t, m, b, s, u_exact, 'quadratic')) - u_exact = sym.lambdify(t, u_exact, modules='numpy') + F = sp.lambdify(t, lhs_eq(t, m, b, s, u_exact, 'quadratic')) + u_exact = sp.lambdify(t, u_exact, modules='numpy') u2, t2 = solver(I, V, m, b, s, F, dt, T, 'quadratic') diff = np.abs(u_exact(t2) - u2).max() assert diff < tol @@ -127,11 +133,11 @@ def test_mms(): """Use method of manufactured solutions.""" m = 4.; b = 1 w = 1.5 - t = sym.Symbol('t') - u_exact = 3*sym.exp(-0.2*t)*sym.cos(1.2*t) + t = sp.Symbol('t') + u_exact = 3*sp.exp(-0.2*t)*sp.cos(1.2*t) I = u_exact.subs(t, 0).evalf() - V = sym.diff(u_exact, t).subs(t, 0).evalf() - u_exact_py = sym.lambdify(t, u_exact, modules='numpy') + V = sp.diff(u_exact, t).subs(t, 0).evalf() + u_exact_py = sp.lambdify(t, u_exact, modules='numpy') s = lambda u: u**3 dt = 0.2 T = 6 @@ -140,14 +146,14 @@ def test_mms(): # Run grid refinements and compute exact error for i in range(5): F_formula = lhs_eq(t, m, b, s, u_exact, 'linear') - F = sym.lambdify(t, F_formula) + F = sp.lambdify(t, F_formula) u1, t1 = solver(I, V, m, b, s, F, dt, T, 'linear') error = np.sqrt(np.sum((u_exact_py(t1) - u1)**2)*dt) errors_linear.append((dt, error)) F_formula = lhs_eq(t, m, b, s, u_exact, 'quadratic') #print sym.latex(F_formula, mode='plain') - F = sym.lambdify(t, F_formula) + F = sp.lambdify(t, F_formula) u2, t2 = solver(I, V, m, b, s, F, dt, T, 'quadratic') error = np.sqrt(np.sum((u_exact_py(t2) - u2)**2)*dt) errors_quadratic.append((dt, error))