Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to vib_gen notebook and accompanying python files #65

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 52 additions & 43 deletions fdm-devito-notebooks/01_vib/vib_gen.ipynb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"cells": [
{
"cell_type": "raw",
"cell_type": "markdown",
"metadata": {},
"source": [
"<!-- Equation labels as ordinary links -->\n",
Expand Down Expand Up @@ -510,6 +510,11 @@
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import sympy as sp\n",
"from devito import Dimension, Constant, TimeFunction, Eq, solve, Operator\n",
"\n",
"\n",
"def solver(I, V, m, b, s, F, dt, T, damping='linear'):\n",
" \"\"\"\n",
" Solve m*u'' + f(u') + s(u) = F(t) for t in (0,T],\n",
Expand All @@ -519,27 +524,40 @@
" 'quadratic', f(u')=b*u'*abs(u').\n",
" F(t) and s(u) are Python functions.\n",
" \"\"\"\n",
" dt = float(dt); b = float(b); m = float(m) # avoid integer div.\n",
" dt = float(dt)\n",
" b = float(b)\n",
" m = float(m)\n",
" Nt = int(round(T/dt))\n",
" u = np.zeros(Nt+1)\n",
" t = np.linspace(0, Nt*dt, Nt+1)\n",
" t = Dimension('t', spacing=Constant('h_t'))\n",
"\n",
" u = TimeFunction(name='u', dimensions=(t,),\n",
" shape=(Nt+1,), space_order=2)\n",
"\n",
" u.data[0] = I\n",
"\n",
" u[0] = I\n",
" if damping == 'linear':\n",
" u[1] = u[0] + dt*V + dt**2/(2*m)*(-b*V - s(u[0]) + F(t[0]))\n",
" # dtc for central difference (default for time is forward, 1st order)\n",
" eqn = m*u.dt2 + b*u.dtc + s(u) - F(u)\n",
" stencil = Eq(u.forward, solve(eqn, u.forward))\n",
" elif damping == 'quadratic':\n",
" u[1] = u[0] + dt*V + \\\n",
" dt**2/(2*m)*(-b*V*abs(V) - s(u[0]) + F(t[0]))\n",
" # fd_order set as backward derivative used is 1st order\n",
" eqn = m*u.dt2 + b*u.dt*sp.Abs(u.dtl(fd_order=1)) + s(u) - F(u)\n",
" stencil = Eq(u.forward, solve(eqn, u.forward))\n",
"\n",
" for n in range(1, Nt):\n",
" if damping == 'linear':\n",
" u[n+1] = (2*m*u[n] + (b*dt/2 - m)*u[n-1] +\n",
" dt**2*(F(t[n]) - s(u[n])))/(m + b*dt/2)\n",
" elif damping == 'quadratic':\n",
" u[n+1] = (2*m*u[n] - m*u[n-1] + b*u[n]*abs(u[n] - u[n-1])\n",
" + dt**2*(F(t[n]) - s(u[n])))/\\\n",
" (m + b*abs(u[n] - u[n-1]))\n",
" return u, t"
" # First timestep needs to have the backward timestep substituted\n",
" # Has to be done to the equation otherwise the stencil will have\n",
" # forward timestep on both sides\n",
" # FIXME: Doesn't look like you can do subs or solve on anything inside an Abs\n",
" eqn_init = eqn.subs(u.backward, u.forward-2*t.spacing*V)\n",
" stencil_init = Eq(u.forward, solve(eqn_init, u.forward))\n",
" # stencil_init = stencil.subs(u.backward, u.forward-2*t.spacing*V)\n",
"\n",
" op_init = Operator(stencil_init, name='first_timestep')\n",
" op = Operator(stencil, name='main_loop')\n",
" op_init.apply(h_t=dt, t_M=1)\n",
" op.apply(h_t=dt, t_m=1, t_M=Nt-1)\n",
"\n",
" return u.data, np.linspace(0, Nt*dt, Nt+1)"
]
},
{
Expand Down Expand Up @@ -619,27 +637,6 @@
"the quadratic polynomial is reproduced by the numerical method (to\n",
"machine precision).\n",
"\n",
"### Catching bugs\n",
"\n",
"How good are the constant and quadratic solutions at catching\n",
"bugs in the implementation? Let us check that by introducing some bugs.\n",
"\n",
" * Use `m` instead of `2*m` in the denominator of `u[1]`: code works for constant\n",
" solution, but fails (as it should) for a quadratic one.\n",
"\n",
" * Use `b*dt` instead of `b*dt/2` in the updating formula for `u[n+1]`\n",
" in case of linear damping: constant and quadratic both fail.\n",
"\n",
" * Use `F[n+1]` instead of `F[n]` in case of linear or quadratic damping:\n",
" constant solution works, quadratic fails.\n",
"\n",
"We realize that the constant solution is very useful for catching certain bugs because\n",
"of its simplicity (easy to predict what the different terms in the\n",
"formula should evaluate to), while the quadratic solution seems\n",
"capable of detecting all (?) other kinds of typos in the scheme.\n",
"These results demonstrate why we focus so much on exact, simple polynomial\n",
"solutions of the numerical schemes in these writings.\n",
"\n",
"<!-- More: classes, cases with pendulum approx u vs sin(u), -->\n",
"<!-- making UI via parampool -->\n",
"\n",
Expand Down Expand Up @@ -1470,7 +1467,19 @@
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'scitools'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-3-b591f37272a6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mscitools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstd\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msympy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0msym\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mvib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msolver\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mvib_solver\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'scitools'"
]
}
],
"source": [
"# Reimplementation of vib.py using classes\n",
"\n",
Expand Down Expand Up @@ -1858,7 +1867,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -2191,9 +2200,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "devito",
"language": "python",
"name": "python3"
"name": "devito"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -2205,7 +2214,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
"version": "3.8.1"
}
},
"nbformat": 4,
Expand Down
100 changes: 60 additions & 40 deletions src/vib/vib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
#import matplotlib.pyplot as plt
import scitools.std as plt
import sympy as sp
from devito import Dimension, Constant, TimeFunction, Eq, solve, Operator
import matplotlib.pyplot as plt
# import scitools.std as plt


def solver(I, V, m, b, s, F, dt, T, damping='linear'):
"""
Expand All @@ -11,27 +14,41 @@ def solver(I, V, m, b, s, F, dt, T, damping='linear'):
'quadratic', f(u')=b*u'*abs(u').
F(t) and s(u) are Python functions.
"""
dt = float(dt); b = float(b); m = float(m) # avoid integer div.
dt = float(dt)
b = float(b)
m = float(m)
Nt = int(round(T/dt))
u = np.zeros(Nt+1)
t = np.linspace(0, Nt*dt, Nt+1)
t = Dimension('t', spacing=Constant('h_t'))

u = TimeFunction(name='u', dimensions=(t,),
shape=(Nt+1,), space_order=2)

u.data[0] = I

u[0] = I
if damping == 'linear':
u[1] = u[0] + dt*V + dt**2/(2*m)*(-b*V - s(u[0]) + F(t[0]))
# dtc for central difference (default for time is forward, 1st order)
eqn = m*u.dt2 + b*u.dtc + s(u) - F(u)
stencil = Eq(u.forward, solve(eqn, u.forward))
elif damping == 'quadratic':
u[1] = u[0] + dt*V + \
dt**2/(2*m)*(-b*V*abs(V) - s(u[0]) + F(t[0]))

for n in range(1, Nt):
if damping == 'linear':
u[n+1] = (2*m*u[n] + (b*dt/2 - m)*u[n-1] +
dt**2*(F(t[n]) - s(u[n])))/(m + b*dt/2)
elif damping == 'quadratic':
u[n+1] = (2*m*u[n] - m*u[n-1] + b*u[n]*abs(u[n] - u[n-1])
+ dt**2*(F(t[n]) - s(u[n])))/\
(m + b*abs(u[n] - u[n-1]))
return u, t
# fd_order set as backward derivative used is 1st order
eqn = m*u.dt2 + b*u.dt*sp.Abs(u.dtl(fd_order=1)) + s(u) - F(u)
stencil = Eq(u.forward, solve(eqn, u.forward))

# First timestep needs to have the backward timestep substituted
# Has to be done to the equation otherwise the stencil will have
# forward timestep on both sides
# FIXME: Doesn't look like you can do subs or solve on anything inside an Abs
eqn_init = eqn.subs(u.backward, u.forward-2*t.spacing*V)
stencil_init = Eq(u.forward, solve(eqn_init, u.forward))
# stencil_init = stencil.subs(u.backward, u.forward-2*t.spacing*V)

op_init = Operator(stencil_init, name='first_timestep')
op = Operator(stencil, name='main_loop')
op_init.apply(h_t=dt, t_M=1)
op.apply(h_t=dt, t_m=1, t_M=Nt-1)

return u.data, np.linspace(0, Nt*dt, Nt+1)


def visualize(u, t, title='', filename='tmp'):
plt.plot(t, u, 'b-')
Expand All @@ -46,12 +63,14 @@ def visualize(u, t, title='', filename='tmp'):
plt.savefig(filename + '.pdf')
plt.show()

import sympy as sym

def test_constant():
"""Verify a constant solution."""
u_exact = lambda t: I
I = 1.2; V = 0; m = 2; b = 0.9
I = 1.2
V = 0
m = 2
b = 0.9
w = 1.5
s = lambda u: w**2*u
F = lambda t: w**2*u_exact(t)
Expand All @@ -66,26 +85,27 @@ def test_constant():
difference = np.abs(u_exact(t) - u).max()
assert difference < tol


def lhs_eq(t, m, b, s, u, damping='linear'):
"""Return lhs of differential equation as sympy expression."""
v = sym.diff(u, t)
v = sp.diff(u, t)
if damping == 'linear':
return m*sym.diff(u, t, t) + b*v + s(u)
return m*sp.diff(u, t, t) + b*v + s(u)
else:
return m*sym.diff(u, t, t) + b*v*sym.Abs(v) + s(u)
return m*sp.diff(u, t, t) + b*v*sp.Abs(v) + s(u)

def test_quadratic():
"""Verify a quadratic solution."""
I = 1.2; V = 3; m = 2; b = 0.9
s = lambda u: 4*u
t = sym.Symbol('t')
t = sp.Symbol('t')
dt = 0.2
T = 2

q = 2 # arbitrary constant
u_exact = I + V*t + q*t**2
F = sym.lambdify(t, lhs_eq(t, m, b, s, u_exact, 'linear'))
u_exact = sym.lambdify(t, u_exact, modules='numpy')
F = sp.lambdify(t, lhs_eq(t, m, b, s, u_exact, 'linear'))
u_exact = sp.lambdify(t, u_exact, modules='numpy')
u1, t1 = solver(I, V, m, b, s, F, dt, T, 'linear')
diff = np.abs(u_exact(t1) - u1).max()
tol = 1E-13
Expand All @@ -94,8 +114,8 @@ def test_quadratic():
# In the quadratic damping case, u_exact must be linear
# in order exactly recover this solution
u_exact = I + V*t
F = sym.lambdify(t, lhs_eq(t, m, b, s, u_exact, 'quadratic'))
u_exact = sym.lambdify(t, u_exact, modules='numpy')
F = sp.lambdify(t, lhs_eq(t, m, b, s, u_exact, 'quadratic'))
u_exact = sp.lambdify(t, u_exact, modules='numpy')
u2, t2 = solver(I, V, m, b, s, F, dt, T, 'quadratic')
diff = np.abs(u_exact(t2) - u2).max()
assert diff < tol
Expand Down Expand Up @@ -127,11 +147,11 @@ def test_mms():
"""Use method of manufactured solutions."""
m = 4.; b = 1
w = 1.5
t = sym.Symbol('t')
u_exact = 3*sym.exp(-0.2*t)*sym.cos(1.2*t)
t = sp.Symbol('t')
u_exact = 3*sp.exp(-0.2*t)*sp.cos(1.2*t)
I = u_exact.subs(t, 0).evalf()
V = sym.diff(u_exact, t).subs(t, 0).evalf()
u_exact_py = sym.lambdify(t, u_exact, modules='numpy')
V = sp.diff(u_exact, t).subs(t, 0).evalf()
u_exact_py = sp.lambdify(t, u_exact, modules='numpy')
s = lambda u: u**3
dt = 0.2
T = 6
Expand All @@ -140,14 +160,14 @@ def test_mms():
# Run grid refinements and compute exact error
for i in range(5):
F_formula = lhs_eq(t, m, b, s, u_exact, 'linear')
F = sym.lambdify(t, F_formula)
F = sp.lambdify(t, F_formula)
u1, t1 = solver(I, V, m, b, s, F, dt, T, 'linear')
error = np.sqrt(np.sum((u_exact_py(t1) - u1)**2)*dt)
errors_linear.append((dt, error))

F_formula = lhs_eq(t, m, b, s, u_exact, 'quadratic')
#print sym.latex(F_formula, mode='plain')
F = sym.lambdify(t, F_formula)
F = sp.lambdify(t, F_formula)
u2, t2 = solver(I, V, m, b, s, F, dt, T, 'quadratic')
error = np.sqrt(np.sum((u_exact_py(t2) - u2)**2)*dt)
errors_quadratic.append((dt, error))
Expand Down Expand Up @@ -205,10 +225,10 @@ def plot_empirical_freq_and_amplitude(u, t):
a = amplitudes(minima, maxima)
plt.figure()
from math import pi
w = 2*pi/p
plt.plot(range(len(p)), w, 'r-')
w = 2*pi/p
plt.plot(list(range(len(p))), w, 'r-')
plt.hold('on')
plt.plot(range(len(a)), a, 'b-')
plt.plot(list(range(len(a))), a, 'b-')
ymax = 1.1*max(w.max(), a.max())
ymin = 0.9*min(w.min(), a.min())
plt.axis([0, max(len(p), len(a)), ymin, ymax])
Expand Down Expand Up @@ -241,7 +261,7 @@ def visualize_front(u, t, window_width, savefig=False):
axis=plot_manager.axis(),
show=not savefig) # drop window if savefig
if savefig:
print 't=%g' % t[n]
print('t=%g' % t[n])
st.savefig('tmp_vib%04d.png' % n)
plot_manager.update(n)

Expand All @@ -257,7 +277,7 @@ def visualize_front_ascii(u, t, fps=10):

p = Plotter(ymin=umin, ymax=umax, width=60, symbols='+o')
for n in range(len(u)):
print p.plot(t[n], u[n]), '%.2f' % (t[n])
print(p.plot(t[n], u[n]), '%.2f' % (t[n]))
time.sleep(1/float(fps))

def minmax(t, u):
Expand Down